Overview
TslibBaseModel is the base class for tslib-style forecasting models, extending BaseModel with metadata-driven feature configuration.
Description
TslibBaseModel extends BaseModel and provides a standardized foundation for models ported from the Time Series Library (tslib). It parses a metadata dictionary to extract the context length, prediction length, feature indices (continuous, categorical, known, unknown, target), feature dimensions, and feature names. The class is part of the v2 API, which is still under development, so it emits an experimental warning at initialization. Subclasses must implement the _init_network and forward methods. The class also provides predict_step for inference and transform_output for inverse-scaling predictions back to the original target scale.
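The metadata handling described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: parse_tslib_metadata and ParsedMetadata are hypothetical names, and the fallback values for missing keys (0 for lengths and dimensions, empty lists for indices, "MS" for the feature mode) are assumptions.

```python
from __future__ import annotations

import warnings
from dataclasses import dataclass, field

# The five feature groups named in the description above.
FEATURE_GROUPS = ("continuous", "categorical", "known", "unknown", "target")


@dataclass
class ParsedMetadata:
    """Hypothetical container for the fields read from the metadata dict."""
    context_length: int
    prediction_length: int
    feature_indices: dict = field(default_factory=dict)
    n_features: dict = field(default_factory=dict)
    feature_names: dict = field(default_factory=dict)
    features: str = "MS"


def parse_tslib_metadata(metadata: dict | None) -> ParsedMetadata:
    # Mirror the experimental-API warning emitted at initialization.
    warnings.warn("experimental v2 API: interface may change", UserWarning)
    metadata = metadata or {}
    indices = metadata.get("feature_indices", {})
    dims = metadata.get("n_features", {})
    return ParsedMetadata(
        context_length=metadata.get("context_length", 0),
        prediction_length=metadata.get("prediction_length", 0),
        # Missing groups fall back to an empty index list (assumed default).
        feature_indices={g: list(indices.get(g, [])) for g in FEATURE_GROUPS},
        # Missing groups fall back to dimension 0 (assumed default).
        n_features={g: dims.get(g, 0) for g in FEATURE_GROUPS},
        feature_names=dict(metadata.get("feature_names", {})),
        features=metadata.get("features", "MS"),
    )
```

Normalizing every feature group up front means downstream code can index any group without checking whether the key was present.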
Usage
Use TslibBaseModel as the parent class when implementing new tslib-style forecasting models in pytorch-forecasting. It handles the shared setup (parsing metadata and feature indices, configuring the optimizer) as well as output transformation, so subclasses can focus on defining the network architecture and the forward pass.
Code Reference
Source Location
Signature
```python
class TslibBaseModel(BaseModel):
    def __init__(
        self,
        loss: Metric,
        logging_metrics: list[nn.Module] | None = None,
        optimizer: Optimizer | str | None = "adam",
        optimizer_params: dict | None = None,
        lr_scheduler: str | None = None,
        lr_scheduler_params: dict | None = None,
        metadata: dict | None = None,
    ): ...

    def _init_network(self): ...

    def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: ...

    def predict_step(
        self,
        batch: tuple[dict[str, torch.Tensor]],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> torch.Tensor: ...

    def transform_output(
        self,
        y_hat: torch.Tensor | list[torch.Tensor],
        target_scale: dict[str, torch.Tensor] | None,
    ) -> torch.Tensor | list[torch.Tensor]: ...
```
Import
```python
from pytorch_forecasting.models.base._tslib_base_model_v2 import TslibBaseModel
```
I/O Contract
Inputs
__init__
| Name | Type | Required | Description |
|---|---|---|---|
| loss | Metric | Yes | Loss function used for training (an instance of a pytorch_forecasting.metrics.Metric subclass). |
| logging_metrics | list[nn.Module] or None | No | Metrics to log during training, validation, and testing. Defaults to None. |
| optimizer | Optimizer or str or None | No | Optimizer to use for training, given as an Optimizer or by name. Defaults to "adam". |
| optimizer_params | dict or None | No | Parameters for the optimizer. Defaults to None. |
| lr_scheduler | str or None | No | Name of the learning rate scheduler to use. Defaults to None. |
| lr_scheduler_params | dict or None | No | Parameters for the learning rate scheduler. Defaults to None. |
| metadata | dict or None | No | Metadata dictionary from TslibDataModule containing context_length, prediction_length, feature_indices, n_features, feature_names, and features (the feature mode, e.g. "MS"). Defaults to None. |
forward
| Name | Type | Required | Description |
|---|---|---|---|
| x | dict[str, torch.Tensor] | Yes | Dictionary of input tensors. |
predict_step
| Name | Type | Required | Description |
|---|---|---|---|
| batch | tuple[dict[str, torch.Tensor]] | Yes | Batch of data containing input tensors. |
| batch_idx | int | Yes | Index of the batch. |
| dataloader_idx | int | No | Index of the dataloader. Defaults to 0. |
transform_output
| Name | Type | Required | Description |
|---|---|---|---|
| y_hat | torch.Tensor or list[torch.Tensor] | Yes | Model predictions to inverse-transform. |
| target_scale | dict[str, torch.Tensor] or None | Yes (may be None) | Dictionary with "scale" and "center" tensors used for the inverse transformation. |
Outputs
forward
| Name | Type | Description |
|---|---|---|
| output | dict[str, torch.Tensor] | Dictionary of output tensors, including "predictions" of shape (batch_size, prediction_length, target_dim) and, optionally, "attention_weights". |
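To make the forward output contract concrete, the following sketch implements a naive last-value forecaster. It is an illustration under stated assumptions: nested Python lists stand in for torch.Tensor so it runs without PyTorch, and the input key "history_target" and the function name naive_forward are hypothetical. Only the "predictions" key and its (batch_size, prediction_length, target_dim) shape come from the table above.

```python
from __future__ import annotations


def naive_forward(x: dict[str, list], prediction_length: int) -> dict[str, list]:
    """Hypothetical forward pass that repeats the last observed target values.

    x["history_target"] is assumed to be nested lists shaped
    (batch_size, context_length, target_dim); the key name is illustrative only.
    """
    predictions = []
    for series in x["history_target"]:  # one sample: (context_length, target_dim)
        last_step = series[-1]  # last observed values, length target_dim
        # Repeat a copy of the last step for every forecast step.
        predictions.append([list(last_step) for _ in range(prediction_length)])
    # Output shape: (batch_size, prediction_length, target_dim)
    return {"predictions": predictions}
```

A real subclass would compute predictions with the layers created in _init_network; repeating the last value is only a baseline that satisfies the output shape.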
predict_step
| Name | Type | Description |
|---|---|---|
| y_hat | dict[str, torch.Tensor] | Dictionary containing the predicted output tensor and, optionally, the original target. Note that the signature annotates the return type as torch.Tensor. |
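The following rough sketch shows one way a predict step can produce the output described above. Everything beyond the documented behavior is an assumption: the (x, y) batch layout, the "target" key, the keep_target flag, and the function name predict_step_sketch are hypothetical, and plain Python objects stand in for tensors.

```python
from __future__ import annotations

from typing import Any, Callable


def predict_step_sketch(
    forward: Callable[[dict], dict],
    batch: tuple,
    batch_idx: int,
    dataloader_idx: int = 0,
    keep_target: bool = True,
) -> dict[str, Any]:
    """Hypothetical predict step: run forward and optionally attach the target."""
    # Assumed batch layout: (x, y); the real structure may differ.
    x = batch[0]
    # Copy the forward outputs so the dict returned by forward is not mutated.
    y_hat = dict(forward(x))
    if keep_target and len(batch) > 1:
        # Hypothetical key holding the original target.
        y_hat["target"] = batch[1]
    # batch_idx and dataloader_idx are unused; they mirror the Lightning-style hook signature.
    return y_hat
```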
transform_output
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor or list[torch.Tensor] | Inverse-transformed output, computed as y_hat * scale + center. |
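The inverse transformation y_hat * scale + center can be sketched as follows. This is a minimal stand-in, not the library code: nested Python lists replace tensors, scale and center are plain floats rather than broadcastable tensors, and returning y_hat unchanged when target_scale is None is an assumption, since the documentation does not say what happens in that case. The function name transform_output_sketch is hypothetical.

```python
from __future__ import annotations


def transform_output_sketch(y_hat, target_scale: dict | None):
    """Inverse-scale predictions elementwise: y_hat * scale + center."""
    if target_scale is None:
        # Assumed behavior: without scaling info, return predictions unchanged.
        return y_hat
    scale = target_scale["scale"]
    center = target_scale["center"]

    def _apply(value):
        if isinstance(value, list):
            # Recurse so a single "tensor" and a list of them are both handled.
            return [_apply(v) for v in value]
        return value * scale + center

    return _apply(y_hat)
```

For example, transform_output_sketch([[0.5], [1.5]], {"scale": 4.0, "center": 1.0}) returns [[3.0], [7.0]]. Recursing through nested lists lets one helper cover both a single output and a list of outputs, mirroring the torch.Tensor | list[torch.Tensor] signature.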
Usage Examples
```python
from pytorch_forecasting.metrics import MAE
from pytorch_forecasting.models.base._tslib_base_model_v2 import TslibBaseModel


# TslibBaseModel is abstract; subclass it to define a concrete model.
class MyTslibModel(TslibBaseModel):
    def _init_network(self):
        # Define the model layers here.
        pass

    def forward(self, x):
        # Implement the forward pass; it should return a dict with a
        # "predictions" tensor of shape (batch_size, prediction_length, target_dim).
        pass


# Minimal hand-written metadata; in practice this dictionary comes from TslibDataModule.
metadata = {
    "context_length": 96,
    "prediction_length": 24,
    "feature_indices": {"continuous": [0, 1], "target": [0]},
    "n_features": {"continuous": 2, "target": 1},
    "features": "MS",
}

# Instantiation emits the experimental-API warning.
model = MyTslibModel(loss=MAE(), metadata=metadata)
```
Related Pages