Overview
Base_pkg is a high-level wrapper class that manages the Lightning model, data module, and trainer lifecycle for the v2 pytorch-forecasting architecture.
Description
Base_pkg extends _BasePtForecasterV2 and acts as a forecaster package: a single container that keeps the model, data module, and trainer configurations behind one interface. Its fit and predict methods cover the whole workflow. They build the model from a config or restore it from a checkpoint, construct a data module from a TimeSeries dataset, set up a Lightning Trainer with optional checkpointing, and generate predictions from a TimeSeries, a LightningDataModule, or a DataLoader. The class can also save and load its configuration artifacts (model_cfg, datamodule_cfg, metadata) as pickle files alongside checkpoints.
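As a rough illustration of the artifact round trip described above, the following standard-library sketch pickles each configuration into its own file inside a checkpoint directory and reads it back. The helper names save_artifacts and load_artifacts and the <name>.pkl file naming are hypothetical; Base_pkg's actual file names and directory layout may differ.

```python
from __future__ import annotations

import pickle
import tempfile
from pathlib import Path
from typing import Any

# Hypothetical artifact names; Base_pkg's real file names may differ.
ARTIFACT_NAMES = ("model_cfg", "datamodule_cfg", "metadata")


def save_artifacts(ckpt_dir: str | Path, artifacts: dict[str, Any]) -> list[Path]:
    """Pickle each known artifact to <ckpt_dir>/<name>.pkl; return the paths written."""
    ckpt_dir = Path(ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    written = []
    for name in ARTIFACT_NAMES:
        if name in artifacts:
            path = ckpt_dir / f"{name}.pkl"
            with path.open("wb") as f:
                pickle.dump(artifacts[name], f)
            written.append(path)
    return written


def load_artifacts(ckpt_dir: str | Path) -> dict[str, Any]:
    """Load whichever known artifacts are present in ckpt_dir."""
    ckpt_dir = Path(ckpt_dir)
    loaded = {}
    for name in ARTIFACT_NAMES:
        path = ckpt_dir / f"{name}.pkl"
        if path.exists():
            with path.open("rb") as f:
                loaded[name] = pickle.load(f)
    return loaded


with tempfile.TemporaryDirectory() as tmp:
    written = save_artifacts(tmp, {
        "model_cfg": {"hidden_size": 256},
        "datamodule_cfg": {"context_length": 96, "prediction_length": 24},
        "metadata": {"n_features": 3},
    })
    restored = load_artifacts(tmp)

print(sorted(restored))  # ['datamodule_cfg', 'metadata', 'model_cfg']
```

Keeping each artifact in a separate file means a loader can restore whatever subset is present without knowing in advance which configs were saved.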
Usage
To support a specific model architecture, subclass Base_pkg and implement two class methods: get_cls, which returns the LightningModule class, and get_datamodule_cls, which returns the LightningDataModule class. Use Base_pkg when you want a single entry point for the full training-to-prediction workflow instead of managing the Lightning Trainer, DataModule, and model separately.
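The subclassing contract can be sketched in plain Python. PackageBase, ToyModel, ToyDataModule, ToyPkg, build_model and build_datamodule are hypothetical stand-ins, not pytorch-forecasting classes; ToyModel.load_from_checkpoint only mimics the name of Lightning's classmethod. The sketch shows how class-level hooks let one base class construct model and data module objects it does not know in advance, and how a checkpoint path can take precedence over model_cfg.

```python
from __future__ import annotations

from typing import Any


class PackageBase:
    """Hypothetical stand-in for Base_pkg's hook-based construction."""

    def __init__(self, model_cfg: dict[str, Any] | None = None,
                 datamodule_cfg: dict[str, Any] | None = None,
                 ckpt_path: str | None = None):
        self.model_cfg = model_cfg or {}
        self.datamodule_cfg = datamodule_cfg or {}
        self.ckpt_path = ckpt_path

    @classmethod
    def get_cls(cls) -> type:
        raise NotImplementedError("subclasses must return the model class")

    @classmethod
    def get_datamodule_cls(cls) -> type:
        raise NotImplementedError("subclasses must return the data module class")

    def build_model(self):
        model_cls = self.get_cls()
        # A checkpoint, when given, takes precedence and model_cfg is ignored.
        if self.ckpt_path is not None:
            return model_cls.load_from_checkpoint(self.ckpt_path)
        return model_cls(**self.model_cfg)

    def build_datamodule(self, data):
        return self.get_datamodule_cls()(data, **self.datamodule_cfg)


class ToyModel:
    def __init__(self, hidden_size: int = 16):
        self.hidden_size = hidden_size
        self.source = "config"

    @classmethod
    def load_from_checkpoint(cls, path: str) -> "ToyModel":
        model = cls()
        model.source = f"checkpoint:{path}"
        return model


class ToyDataModule:
    def __init__(self, data, context_length: int = 8, prediction_length: int = 2):
        self.data = data
        self.context_length = context_length
        self.prediction_length = prediction_length


class ToyPkg(PackageBase):
    @classmethod
    def get_cls(cls):
        return ToyModel

    @classmethod
    def get_datamodule_cls(cls):
        return ToyDataModule


pkg = ToyPkg(model_cfg={"hidden_size": 256}, datamodule_cfg={"context_length": 96})
model = pkg.build_model()
dm = pkg.build_datamodule([1.0, 2.0, 3.0])
print(model.hidden_size, model.source, dm.context_length)  # 256 config 96
```

Because the hooks are class methods, the base class never imports a concrete model; each subclass only names the two classes it wraps.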
Code Reference
Source Location
Signature
```python
class Base_pkg(_BasePtForecasterV2):
    def __init__(
        self,
        model_cfg: dict[str, Any] | str | Path | None = None,
        trainer_cfg: dict[str, Any] | str | Path | None = None,
        datamodule_cfg: dict[str, Any] | str | Path | None = None,
        ckpt_path: str | Path | None = None,
    ):
        ...
```
fit
```python
def fit(
    self,
    data: TimeSeries | LightningDataModule,
    save_ckpt: bool = True,
    ckpt_dir: str | Path = "checkpoints",
    ckpt_kwargs: dict[str, Any] | None = None,
    **trainer_fit_kwargs,
):
    ...
```
predict
```python
def predict(
    self,
    data: TimeSeries | LightningDataModule | DataLoader,
    output_dir: str | Path | None = None,
    **kwargs,
) -> dict[str, torch.Tensor] | None:
    ...
```
Import
```python
from pytorch_forecasting.base._base_pkg import Base_pkg
```
I/O Contract
Constructor Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_cfg | dict, str, Path, or None | No | Model configuration as a dict or a path to a YAML/PKL file. Ignored if ckpt_path is provided. |
| trainer_cfg | dict, str, Path, or None | No | Configuration used to initialize the Lightning Trainer. |
| datamodule_cfg | dict, str, Path, or None | No | Configuration used to initialize the LightningDataModule. |
| ckpt_path | str, Path, or None | No | Path to a checkpoint file from which to load a pre-trained model. |
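Each *_cfg argument is typed to accept either a dict or a file path, and model_cfg is documented as accepting a YAML or PKL file. A minimal loader for that convention might look like the following. load_cfg is a hypothetical helper, not necessarily how Base_pkg resolves its configs, and the YAML branch assumes PyYAML (yaml.safe_load) is installed.

```python
from __future__ import annotations

import pickle
import tempfile
from pathlib import Path
from typing import Any


def load_cfg(cfg: dict[str, Any] | str | Path | None) -> dict[str, Any]:
    """Resolve a config given as None, a dict, or a path to a .pkl/.yaml/.yml file."""
    if cfg is None:
        return {}
    if isinstance(cfg, dict):
        # Shallow copy so adding or removing keys does not touch the caller's dict.
        return dict(cfg)
    path = Path(cfg)
    suffix = path.suffix.lower()
    if suffix == ".pkl":
        with path.open("rb") as f:
            return pickle.load(f)
    if suffix in (".yaml", ".yml"):
        import yaml  # PyYAML, imported lazily so the other branches work without it

        with path.open("r", encoding="utf-8") as f:
            return yaml.safe_load(f) or {}
    raise ValueError(f"Unsupported config file type: {path.suffix!r}")


with tempfile.TemporaryDirectory() as tmp:
    pkl_path = Path(tmp) / "model_cfg.pkl"
    with pkl_path.open("wb") as f:
        pickle.dump({"hidden_size": 256}, f)
    from_file = load_cfg(pkl_path)

from_dict = load_cfg({"max_epochs": 10})
print(from_file, from_dict, load_cfg(None))  # {'hidden_size': 256} {'max_epochs': 10} {}
```

Dispatching on the file suffix keeps the constructor signature uniform: callers pass a dict during experimentation and a saved file when reproducing a run.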
fit Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| data | TimeSeries or LightningDataModule | Yes | Training data at the D1 layer (TimeSeries) or the D2 layer (LightningDataModule). |
| save_ckpt | bool | No | Whether to save the best model checkpoint. Default: True. |
| ckpt_dir | str or Path | No | Directory for saving the checkpoint and configuration artifacts. Default: "checkpoints". |
| ckpt_kwargs | dict or None | No | Additional keyword arguments passed to ModelCheckpoint. |
fit Outputs
| Name | Type | Description |
|---|---|---|
| best_model_path | Path or None | Path to the best model checkpoint if save_ckpt=True; otherwise None. |
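The fit contract in the two tables above (D1 or D2 input, optional checkpointing, best checkpoint path or None returned) can be summarized in a dependency-free sketch. TimeSeriesLike, FakeDataModule, FakeCheckpoint, FakeTrainer and SketchPkg are hypothetical stand-ins for TimeSeries, a LightningDataModule, Lightning's ModelCheckpoint and Trainer, and Base_pkg; the "best.ckpt" file name is made up. Only the control flow is meant to mirror the documented behavior.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any


@dataclass
class TimeSeriesLike:
    """Stand-in for a D1 TimeSeries dataset."""
    values: list[float]


@dataclass
class FakeDataModule:
    """Stand-in for a D2 LightningDataModule built from D1 data."""
    dataset: TimeSeriesLike
    cfg: dict[str, Any] = field(default_factory=dict)


class FakeCheckpoint:
    """Stand-in for ModelCheckpoint; exposes best_model_path like the real callback."""

    def __init__(self, dirpath: Path, **kwargs: Any):
        self.dirpath = dirpath
        self.kwargs = kwargs
        self.best_model_path = ""


class FakeTrainer:
    """Stand-in for a Lightning Trainer; 'training' only fills in the checkpoint path."""

    def __init__(self, callbacks: list | None = None):
        self.callbacks = callbacks or []
        self.fitted_on = None

    def fit(self, model: Any, datamodule: Any = None, **kwargs: Any) -> None:
        self.fitted_on = datamodule
        for cb in self.callbacks:
            if isinstance(cb, FakeCheckpoint):
                cb.best_model_path = str(cb.dirpath / "best.ckpt")


class SketchPkg:
    def __init__(self, datamodule_cfg: dict[str, Any] | None = None):
        self.datamodule_cfg = dict(datamodule_cfg or {})
        self.model = object()  # placeholder for the built LightningModule
        self.datamodule = None
        self.trainer = None

    def fit(self, data, save_ckpt=True, ckpt_dir="checkpoints", ckpt_kwargs=None,
            **trainer_fit_kwargs):
        # D1 input is wrapped into a D2 data module; D2 input is used as-is.
        if isinstance(data, TimeSeriesLike):
            self.datamodule = FakeDataModule(data, self.datamodule_cfg)
        else:
            self.datamodule = data
        # The checkpoint callback is only attached when save_ckpt is True.
        callbacks = [FakeCheckpoint(Path(ckpt_dir), **(ckpt_kwargs or {}))] if save_ckpt else []
        self.trainer = FakeTrainer(callbacks)
        self.trainer.fit(self.model, datamodule=self.datamodule, **trainer_fit_kwargs)
        return Path(callbacks[0].best_model_path) if save_ckpt else None


pkg = SketchPkg({"context_length": 96})
best = pkg.fit(TimeSeriesLike([1.0, 2.0, 3.0]))
print(best.as_posix(), type(pkg.datamodule).__name__)  # checkpoints/best.ckpt FakeDataModule
print(pkg.fit(pkg.datamodule, save_ckpt=False))  # None
```

Reading the best path from the checkpoint callback after training, rather than from the fit call itself, matches how Lightning exposes it through ModelCheckpoint.best_model_path.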
predict Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| data | TimeSeries, LightningDataModule, or DataLoader | Yes | Data to predict on: a D1 TimeSeries, a D2 LightningDataModule, or a raw DataLoader. |
| output_dir | str, Path, or None | No | If provided, predictions are saved to predictions.pkl in this directory. |
predict Outputs
| Name | Type | Description |
|---|---|---|
| predictions | dict[str, torch.Tensor] or None | Dictionary of prediction tensors, or None when output_dir is specified (the predictions are written to disk instead). |
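predict accepts three input layers and either returns the predictions or writes them to disk. The following stdlib-only sketch shows that dispatch. TimeSeriesLike, DataModuleLike, LoaderLike and run_model are hypothetical stand-ins (a real run would yield torch tensors from a Lightning predict loop, not Python lists); only the predictions.pkl file name comes from the table above.

```python
from __future__ import annotations

import pickle
import tempfile
from pathlib import Path


class TimeSeriesLike:  # stand-in for a D1 TimeSeries
    def __init__(self, values):
        self.values = list(values)


class LoaderLike:  # stand-in for a torch DataLoader
    def __init__(self, batches):
        self.batches = list(batches)

    def __iter__(self):
        return iter(self.batches)


class DataModuleLike:  # stand-in for a D2 LightningDataModule
    def __init__(self, dataset: TimeSeriesLike, batch_size: int = 2):
        self.dataset = dataset
        self.batch_size = batch_size

    def predict_dataloader(self) -> LoaderLike:
        v = self.dataset.values
        return LoaderLike(v[i:i + self.batch_size] for i in range(0, len(v), self.batch_size))


def run_model(batch):
    # Hypothetical stand-in for a forward pass: returns the last value in the batch.
    return batch[-1]


def predict(data, output_dir=None):
    # Normalize D1 -> D2 -> loader, then iterate over the batches.
    if isinstance(data, TimeSeriesLike):
        data = DataModuleLike(data)
    if isinstance(data, DataModuleLike):
        data = data.predict_dataloader()
    preds = {"prediction": [run_model(batch) for batch in data]}
    if output_dir is None:
        return preds
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    with (out / "predictions.pkl").open("wb") as f:
        pickle.dump(preds, f)
    return None  # the predictions are on disk, so nothing is returned


series = TimeSeriesLike([1.0, 2.0, 3.0, 4.0, 5.0])
print(predict(series))  # {'prediction': [2.0, 4.0, 5.0]}
with tempfile.TemporaryDirectory() as tmp:
    result = predict(series, output_dir=tmp)
    with open(Path(tmp) / "predictions.pkl", "rb") as f:
        saved = pickle.load(f)
print(result, saved)  # None {'prediction': [2.0, 4.0, 5.0]}
```

Normalizing every input down to an iterable of batches means the prediction loop itself only has to handle one input type.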
Usage Examples
```python
from pytorch_forecasting.base._base_pkg import Base_pkg
from pytorch_forecasting.metrics import MAE

# MyLightningModel and MyDataModule are placeholders for your own
# LightningModule and LightningDataModule classes.
class MyModel_pkg(Base_pkg):
    @classmethod
    def get_cls(cls):
        return MyLightningModel

    @classmethod
    def get_datamodule_cls(cls):
        return MyDataModule

# Train from config (training_data and test_data are placeholder TimeSeries datasets)
pkg = MyModel_pkg(
    model_cfg={"hidden_size": 256, "loss": MAE()},
    trainer_cfg={"max_epochs": 10, "accelerator": "auto"},
    datamodule_cfg={"context_length": 96, "prediction_length": 24},
)
best_ckpt = pkg.fit(training_data, save_ckpt=True)

# Predict in memory (no output_dir, so a dict of tensors is returned)
predictions = pkg.predict(test_data)

# Reload from the best checkpoint returned by fit and predict again
pkg_loaded = MyModel_pkg(ckpt_path=best_ckpt)
predictions = pkg_loaded.predict(test_data)
```
Related Pages