Implementation:Sktime Pytorch forecasting Base Pkg

From Leeroopedia
Revision as of 16:41, 16 February 2026 by Admin (auto-imported from implementations/Sktime_Pytorch_forecasting_Base_Pkg.md)


Knowledge Sources
Domains: Time_Series, Forecasting, Deep_Learning
Last Updated: 2026-02-08 08:00 GMT

Overview

Base_pkg is a high-level wrapper class that manages the Lightning model, data module, and trainer lifecycle for the v2 pytorch-forecasting architecture.

Description

Base_pkg extends _BasePtForecasterV2 and acts as a forecaster package: a container that manages model configuration, data module configuration, and trainer configuration behind a single interface. Its fit and predict methods handle building the model from a config or a checkpoint, constructing data modules from TimeSeries datasets, setting up a Lightning Trainer with optional checkpointing, and generating predictions from several input types (TimeSeries, LightningDataModule, or DataLoader). The class also saves and loads configuration artifacts (model_cfg, datamodule_cfg, metadata) as pickle files alongside checkpoints.
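The artifact round trip described above can be pictured with the standard library alone. This is a minimal sketch, not the library's code: the helper names save_artifacts and load_artifacts and the file names model_cfg.pkl, datamodule_cfg.pkl, and metadata.pkl are hypothetical, chosen only to illustrate storing each configuration as its own pickle file next to a checkpoint.

```python
from __future__ import annotations

import pickle
import tempfile
from pathlib import Path
from typing import Any

# Hypothetical artifact names; the file names Base_pkg actually uses may differ.
ARTIFACT_NAMES = ("model_cfg", "datamodule_cfg", "metadata")


def save_artifacts(ckpt_dir: str | Path, **artifacts: Any) -> dict[str, Path]:
    """Pickle each recognized artifact into ckpt_dir as <name>.pkl (hypothetical helper)."""
    ckpt_dir = Path(ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    written: dict[str, Path] = {}
    for name in ARTIFACT_NAMES:
        if name in artifacts:  # names outside ARTIFACT_NAMES are ignored
            path = ckpt_dir / f"{name}.pkl"
            with path.open("wb") as f:
                pickle.dump(artifacts[name], f)
            written[name] = path
    return written


def load_artifacts(ckpt_dir: str | Path) -> dict[str, Any]:
    """Load whichever artifact pickles exist in ckpt_dir (hypothetical helper)."""
    ckpt_dir = Path(ckpt_dir)
    loaded: dict[str, Any] = {}
    for name in ARTIFACT_NAMES:
        path = ckpt_dir / f"{name}.pkl"
        if path.exists():
            with path.open("rb") as f:
                loaded[name] = pickle.load(f)
    return loaded


# Usage: round-trip two artifacts through a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    save_artifacts(tmp, model_cfg={"hidden_size": 256}, metadata={"version": 1})
    print(load_artifacts(tmp))
```

Keeping each configuration in a separate file lets a loader restore only the pieces that were saved, which is why load_artifacts skips missing files instead of failing.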

Usage

Subclass Base_pkg for specific model architectures by implementing get_cls and get_datamodule_cls. Use it when you want a single entry point that handles the full training-to-prediction workflow without manually managing the Lightning Trainer, DataModule, and Model separately.
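The two hooks follow a classmethod-factory pattern: the base class owns the workflow and asks the subclass which classes to instantiate. The stdlib-only sketch below is hypothetical: PkgSketch, build_model, build_datamodule, DummyModel, DummyDataModule, and Dummy_pkg are illustrative stand-ins, not pytorch-forecasting classes, and it assumes the configs are passed to those classes as keyword arguments.

```python
from __future__ import annotations

from typing import Any


class PkgSketch:
    """Hypothetical stand-in for the role Base_pkg plays (not its real code)."""

    def __init__(self, model_cfg: dict[str, Any] | None = None,
                 datamodule_cfg: dict[str, Any] | None = None):
        self.model_cfg = dict(model_cfg or {})
        self.datamodule_cfg = dict(datamodule_cfg or {})

    @classmethod
    def get_cls(cls) -> type:
        # Subclasses say which model class to build.
        raise NotImplementedError("subclasses must implement get_cls")

    @classmethod
    def get_datamodule_cls(cls) -> type:
        # Subclasses say which data module class to build.
        raise NotImplementedError("subclasses must implement get_datamodule_cls")

    def build_model(self) -> Any:
        # The base class decides when to build; the subclass decides what.
        return self.get_cls()(**self.model_cfg)

    def build_datamodule(self, data: Any) -> Any:
        return self.get_datamodule_cls()(data, **self.datamodule_cfg)


class DummyModel:
    def __init__(self, hidden_size: int = 16):
        self.hidden_size = hidden_size


class DummyDataModule:
    def __init__(self, data: Any, context_length: int = 8):
        self.data = data
        self.context_length = context_length


class Dummy_pkg(PkgSketch):
    @classmethod
    def get_cls(cls):
        return DummyModel

    @classmethod
    def get_datamodule_cls(cls):
        return DummyDataModule
```

With this split, Dummy_pkg(model_cfg={"hidden_size": 256}).build_model() yields a DummyModel with hidden_size 256, while calling build_model on the bare base class raises NotImplementedError.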

Code Reference

Source Location

Signature

class Base_pkg(_BasePtForecasterV2):
    def __init__(
        self,
        model_cfg: dict[str, Any] | str | Path | None = None,
        trainer_cfg: dict[str, Any] | str | Path | None = None,
        datamodule_cfg: dict[str, Any] | str | Path | None = None,
        ckpt_path: str | Path | None = None,
    ):

fit

def fit(
    self,
    data: TimeSeries | LightningDataModule,
    save_ckpt: bool = True,
    ckpt_dir: str | Path = "checkpoints",
    ckpt_kwargs: dict[str, Any] | None = None,
    **trainer_fit_kwargs,
):

predict

def predict(
    self,
    data: TimeSeries | LightningDataModule | DataLoader,
    output_dir: str | Path | None = None,
    **kwargs,
) -> dict[str, torch.Tensor] | None:

Import

from pytorch_forecasting.base._base_pkg import Base_pkg

I/O Contract

Constructor Inputs

model_cfg (dict, str, Path, or None; optional): Model configuration dictionary or path to a YAML/PKL file. Ignored if ckpt_path is provided.
trainer_cfg (dict, str, Path, or None; optional): Configuration used to initialize the Lightning Trainer.
datamodule_cfg (dict, str, Path, or None; optional): Configuration used to initialize the LightningDataModule.
ckpt_path (str, Path, or None; optional): Path to a checkpoint file from which to load a pre-trained model.
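Each configuration argument may be an inline dict or a path to a file. A minimal sketch of that normalization, assuming a .pkl file holds a pickled dict and a .yaml/.yml file is parsed with PyYAML (imported only when needed); the helper name resolve_cfg and its error behavior are hypothetical, not the library's API.

```python
from __future__ import annotations

import pickle
from pathlib import Path
from typing import Any


def resolve_cfg(cfg: dict[str, Any] | str | Path | None) -> dict[str, Any]:
    """Normalize a dict | str | Path | None config argument into a dict (hypothetical)."""
    if cfg is None:
        return {}
    if isinstance(cfg, dict):
        return dict(cfg)  # shallow copy so later edits don't mutate the caller's dict
    path = Path(cfg)
    suffix = path.suffix.lower()
    if suffix == ".pkl":
        with path.open("rb") as f:
            loaded = pickle.load(f)
    elif suffix in (".yaml", ".yml"):
        import yaml  # PyYAML; assumed to be installed when YAML configs are used
        with path.open("r", encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
    else:
        raise ValueError(f"Unsupported config file type: {path.suffix!r}")
    if not isinstance(loaded, dict):
        raise TypeError(f"Config file {path} did not contain a mapping")
    return loaded
```

Because model_cfg is ignored when ckpt_path is provided, a caller following this sketch would only resolve model_cfg on the no-checkpoint path.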

fit Inputs

data (TimeSeries or LightningDataModule; required): Training data, either a TimeSeries dataset (D1 layer) or a LightningDataModule (D2 layer).
save_ckpt (bool; optional): Whether to save the best model checkpoint. Defaults to True.
ckpt_dir (str or Path; optional): Directory for saving artifacts. Defaults to "checkpoints".
ckpt_kwargs (dict or None; optional): Additional arguments passed to ModelCheckpoint.
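One way ckpt_dir and ckpt_kwargs could combine into ModelCheckpoint arguments is a plain dict merge in which user-supplied keys win. The default values below (monitor="val_loss", mode="min", save_top_k=1, filename="best-{epoch}-{step}") are assumptions for illustration, as is the helper name build_checkpoint_kwargs; the keys themselves are real parameters of Lightning's ModelCheckpoint callback, and the filename template is consistent with the checkpoint name in the Usage Examples section.

```python
from __future__ import annotations

from pathlib import Path
from typing import Any


def build_checkpoint_kwargs(
    ckpt_dir: str | Path = "checkpoints",
    ckpt_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Merge assumed defaults with user overrides (hypothetical helper)."""
    defaults: dict[str, Any] = {
        "dirpath": str(Path(ckpt_dir)),
        "filename": "best-{epoch}-{step}",  # assumed template
        "monitor": "val_loss",              # assumed metric name
        "mode": "min",
        "save_top_k": 1,
    }
    # In a dict merge later keys win, so ckpt_kwargs overrides any default.
    return {**defaults, **(ckpt_kwargs or {})}
```

With Lightning installed, the result would be unpacked as ModelCheckpoint(**build_checkpoint_kwargs(...)); keeping the merge separate makes the override order easy to test without a Trainer.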

fit Outputs

best_model_path (Path or None): Path to the best model checkpoint if save_ckpt=True, otherwise None.
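The input dispatch and return contract of fit can be outlined without Lightning. In this hypothetical sketch, TimeSeriesStub and DataModuleStub stand in for the D1 and D2 types, run_training stands in for Trainer.fit plus checkpointing, and the TypeError for other inputs is an assumption of the sketch rather than documented behavior.

```python
from __future__ import annotations

from pathlib import Path
from typing import Any, Callable


class TimeSeriesStub:
    """Stand-in for a D1 TimeSeries dataset (hypothetical)."""


class DataModuleStub:
    """Stand-in for a D2 LightningDataModule (hypothetical)."""

    def __init__(self, data: Any = None):
        self.data = data


def fit_sketch(
    data: Any,
    make_datamodule: Callable[[Any], Any],
    run_training: Callable[[Any], str | None],
    save_ckpt: bool = True,
) -> Path | None:
    """Hypothetical outline of fit's input dispatch and return value."""
    if isinstance(data, TimeSeriesStub):
        datamodule = make_datamodule(data)  # D1: wrap the dataset in a data module
    elif isinstance(data, DataModuleStub):
        datamodule = data                   # D2: use the data module as given
    else:
        raise TypeError(f"Unsupported training data type: {type(data).__name__}")
    best_path = run_training(datamodule)    # stands in for Trainer.fit + ModelCheckpoint
    if save_ckpt and best_path:
        return Path(best_path)
    return None
```

The point of the outline is that both input layers converge on a single data module before training, so the trainer code path is shared.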

predict Inputs

Name Type Required Description
data (TimeSeries, LightningDataModule, or DataLoader; required): Data to predict on, given at the D1 layer, the D2 layer, or as a raw DataLoader.
output_dir (str, Path, or None; optional): If provided, predictions are saved to this directory as predictions.pkl.

predict Outputs

predictions (dict[str, torch.Tensor] or None): Dictionary of prediction tensors, or None if output_dir is specified.
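The two output modes of predict reduce to a simple branch. In this stdlib-only sketch, plain lists stand in for torch.Tensor values; the helper name finalize_predictions and the creation of a missing directory are assumptions, while the predictions.pkl file name comes from the predict input description.

```python
from __future__ import annotations

import pickle
import tempfile
from pathlib import Path
from typing import Any


def finalize_predictions(
    predictions: dict[str, Any],
    output_dir: str | Path | None = None,
) -> dict[str, Any] | None:
    """Return predictions, or write them to output_dir/predictions.pkl and return None."""
    if output_dir is None:
        return predictions
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)  # assumed: create the directory if missing
    with (out / "predictions.pkl").open("wb") as f:
        pickle.dump(predictions, f)
    return None


# Usage: the in-memory call returns the dict; the on-disk call returns None.
preds = {"prediction": [[0.1, 0.2], [0.3, 0.4]]}  # lists stand in for tensors
print(finalize_predictions(preds) is preds)
with tempfile.TemporaryDirectory() as tmp:
    print(finalize_predictions(preds, tmp))
    print(sorted(p.name for p in Path(tmp).iterdir()))
```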

Usage Examples

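from pytorch_forecasting.base._base_pkg import Base_pkg
from pytorch_forecasting.metrics import MAE

# MyLightningModel and MyDataModule are placeholders for a concrete v2 model
# class and its matching data module; training_data and test_data are
# placeholders for TimeSeries datasets or LightningDataModules.


# Subclass for a specific model
class MyModel_pkg(Base_pkg):
    @classmethod
    def get_cls(cls):
        # Lightning model class built from model_cfg
        return MyLightningModel

    @classmethod
    def get_datamodule_cls(cls):
        # Data module class built from datamodule_cfg
        return MyDataModule


# Train from config
pkg = MyModel_pkg(
    model_cfg={"hidden_size": 256, "loss": MAE()},
    trainer_cfg={"max_epochs": 10, "accelerator": "auto"},
    datamodule_cfg={"context_length": 96, "prediction_length": 24},
)
# Returns the best checkpoint path because save_ckpt=True
best_ckpt = pkg.fit(training_data, save_ckpt=True)

# Predict; returns a dict of tensors because output_dir is not set
predictions = pkg.predict(test_data)

# Load from checkpoint (model_cfg is ignored when ckpt_path is given);
# a literal path such as "checkpoints/best-epoch=5-step=1000.ckpt" also works
pkg_loaded = MyModel_pkg(ckpt_path=best_ckpt)
predictions = pkg_loaded.predict(test_data)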
