Principle:Sktime Pytorch forecasting V2 Model Base

From Leeroopedia


Knowledge Sources
Domains: Time_Series, Forecasting, Deep_Learning, Software_Architecture
Last Updated: 2026-02-08 09:00 GMT

Overview

The V2 model base class hierarchy defines the foundational architecture for all experimental pytorch-forecasting v2 models. It comprises three layers: BaseModel (a LightningModule providing training, validation, and prediction lifecycle), TslibBaseModel (an adapter that extracts metadata-driven feature dimensions for tslib-style transformer models), and Base_pkg (an sktime-compatible forecaster package interface that wraps the entire Lightning workflow).

Description

BaseModel (V2): The root base class inherits from LightningModule and provides the full training lifecycle for forecasting models. It accepts a loss function (any pytorch_forecasting.metrics.Metric descendant), optional logging metrics, and optimizer/scheduler configuration specified either as string shortcuts (adam, sgd) or direct Optimizer instances. The class implements training_step, validation_step, and test_step methods that follow a uniform pattern: unpack the batch into (x, y), run the forward pass to get a prediction dictionary, compute loss via self.loss(y_hat, y), log the loss, and invoke log_metrics to record all auxiliary metrics. The predict method orchestrates inference using a Lightning Trainer and a PredictCallback that collects results in a specified mode (prediction, quantiles, or raw). The class also provides to_prediction and to_quantiles methods that delegate to the loss function for output conversion. Optimizer and LR scheduler configuration is handled via configure_optimizers, supporting ReduceLROnPlateau and StepLR schedulers. A pkg class property provides access to the corresponding package class for sktime integration.
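The mode-based output conversion described above can be illustrated with a minimal sketch. It uses plain Python lists instead of tensors. `MedianQuantileLoss` and `convert_output` are hypothetical stand-ins written for this page, not part of pytorch-forecasting. The only ideas carried over from the text are that `to_prediction` and `to_quantiles` live on the loss object and that the model delegates to them for each prediction mode.

```python
from typing import List, Sequence


class MedianQuantileLoss:
    """Hypothetical stand-in for a quantile-style loss (not the library class)."""

    def __init__(self, quantiles: Sequence[float] = (0.1, 0.5, 0.9)):
        self.quantiles = list(quantiles)

    def to_quantiles(self, raw: List[List[float]]) -> List[List[float]]:
        # Each time step already holds one value per quantile.
        return [list(step) for step in raw]

    def to_prediction(self, raw: List[List[float]]) -> List[float]:
        # Point forecast = value at the middle quantile of each time step.
        mid = len(self.quantiles) // 2
        return [step[mid] for step in raw]


def convert_output(raw, loss, mode: str = "prediction"):
    """Dispatch on the prediction mode, delegating conversion to the loss."""
    if mode == "raw":
        return raw
    if mode == "quantiles":
        return loss.to_quantiles(raw)
    if mode == "prediction":
        return loss.to_prediction(raw)
    raise ValueError(f"unknown mode: {mode!r}")


raw_output = [[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]]  # 2 time steps x 3 quantiles
loss = MedianQuantileLoss()
point_forecast = convert_output(raw_output, loss, mode="prediction")
```

Keeping the conversion on the loss means the model code does not need to know whether its raw output encodes point values, quantiles, or distribution parameters.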

TslibBaseModel: This class extends BaseModel specifically for tslib-style models. During initialization, it reads metadata from the TslibDataModule to extract feature indices (continuous, categorical, known, unknown, target), feature dimensions, feature names, context and prediction lengths, and the feature mode (S/MS/M). It saves hyperparameters (excluding non-serializable objects like loss and metrics) for Lightning checkpoint compatibility. Subclasses must implement _init_network to define their specific architecture and forward to produce prediction dictionaries. The class also provides a transform_output method for inverse-scaling predictions back to original scale using stored center and scale parameters. Its predict_step passes the full input dictionary alongside predictions so downstream callbacks can access original targets and features.

Base_pkg: This class implements the _BasePtForecasterV2 interface, providing a high-level package wrapper that simplifies the Lightning workflow into fit and predict methods. It manages three configuration dictionaries: model_cfg, trainer_cfg, and datamodule_cfg, all of which can be loaded from dictionaries, YAML files, or pickle files. The fit method accepts either a D1 TimeSeries dataset or a D2 LightningDataModule, builds the data module and model, configures a ModelCheckpoint callback, runs training via a Lightning Trainer, and saves configuration artifacts alongside the best checkpoint for reproducibility. The predict method accepts D1, D2, or raw DataLoader inputs, resolves them into a prediction dataloader, and delegates to the model. Loading from checkpoints is supported by reading artifact files (model_cfg.pkl, datamodule_cfg.pkl, metadata.pkl) from the checkpoint directory. Subclasses must implement get_cls (returning the Lightning model class) and get_datamodule_cls (returning the data module class).
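The artifact round trip mentioned above (configuration files saved next to the best checkpoint, then read back from the checkpoint directory) might look roughly like the following. This is a sketch under stated assumptions: `save_artifacts` and `load_artifacts` are hypothetical helpers, the checkpoint path is illustrative, and only the three file names (`model_cfg.pkl`, `datamodule_cfg.pkl`, `metadata.pkl`) come from the text.

```python
import pickle
import tempfile
from pathlib import Path

ARTIFACTS = ("model_cfg", "datamodule_cfg", "metadata")  # file stems named in the text


def save_artifacts(ckpt_dir: Path, **objects) -> None:
    """Pickle each artifact into the checkpoint directory (hypothetical helper)."""
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    for name in ARTIFACTS:
        with open(ckpt_dir / f"{name}.pkl", "wb") as fh:
            pickle.dump(objects[name], fh)


def load_artifacts(ckpt_path: Path) -> dict:
    """Read the artifacts back from the directory that holds the checkpoint."""
    loaded = {}
    for name in ARTIFACTS:
        with open(ckpt_path.parent / f"{name}.pkl", "rb") as fh:
            loaded[name] = pickle.load(fh)
    return loaded


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "checkpoints" / "best.ckpt"  # illustrative path
    save_artifacts(
        ckpt.parent,
        model_cfg={"d_model": 64, "n_heads": 4},
        datamodule_cfg={"context_length": 96, "prediction_length": 24},
        metadata={"context_length": 96, "prediction_length": 24},  # illustrative content
    )
    restored = load_artifacts(ckpt)
```

Storing the configuration beside the checkpoint lets a later session rebuild the same model and data module without the original construction code.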

Usage

Use BaseModel as the parent class for any new V2 forecasting model that needs the standard training/validation/test lifecycle with configurable loss and optimizer. Use TslibBaseModel when building a model that integrates with the tslib-style data pipeline and needs automatic metadata-driven layer sizing. Use Base_pkg to wrap a model into a user-friendly package with simple fit/predict methods, checkpoint management, and compatibility with the sktime estimator interface.

Theoretical Basis

Model Lifecycle Pattern:

The V2 model base follows the standard Lightning training pattern with forecasting-specific adaptations:

# Unified step pattern (training shown; validation and test follow the same shape)
def training_step(self, batch, batch_idx):
    x, y = batch                      # Unpack the input dict and the targets
    y_hat_dict = self(x)              # Forward pass returns a prediction dict
    y_hat = y_hat_dict["prediction"]  # Extract the prediction tensor
    loss = self.loss(y_hat, y)        # Compute the loss
    self.log("train_loss", loss)      # Log the loss (key name is illustrative)
    self.log_metrics(y_hat, y)        # Log auxiliary metrics
    return loss

Metadata-Driven Architecture Configuration:

TslibBaseModel uses metadata from the data module to automatically configure layer dimensions:

# Pseudo-code: metadata-driven model initialization
metadata = data_module.metadata
cont_dim = metadata["n_features"]["continuous"]
cat_dim = metadata["n_features"]["categorical"]
target_dim = metadata["n_features"]["target"]
context_length = metadata["context_length"]
prediction_length = metadata["prediction_length"]

# Model layers are sized according to metadata
encoder_layer = Linear(cont_dim + cat_dim, d_model)
decoder_layer = Linear(target_dim, d_model)

Package Pattern (fit/predict):

# Pseudo-code: simplified user workflow via Base_pkg
pkg = ModelPackage(
    model_cfg={"d_model": 64, "n_heads": 4},
    datamodule_cfg={"context_length": 96, "prediction_length": 24},
    trainer_cfg={"max_epochs": 10},
)

# fit accepts D1 TimeSeries or D2 DataModule
best_ckpt = pkg.fit(training_data, save_ckpt=True)

# predict accepts D1, D2, or raw DataLoader
predictions = pkg.predict(test_data)
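The subclass contract behind this workflow (a package supplies `get_cls` and `get_datamodule_cls`, and the base class does the wiring) can be sketched without Lightning. Everything below is a hypothetical, heavily reduced stand-in: `PkgSketch`, `ToyModel`, `ToyDataModule`, and the `build` method are invented for illustration, and declaring the two hooks as class methods is an assumption rather than something the text states.

```python
from abc import ABC, abstractmethod


class PkgSketch(ABC):
    """Hypothetical, reduced stand-in for Base_pkg (no Lightning involved)."""

    def __init__(self, model_cfg=None, datamodule_cfg=None):
        self.model_cfg = dict(model_cfg or {})
        self.datamodule_cfg = dict(datamodule_cfg or {})

    @classmethod
    @abstractmethod
    def get_cls(cls):
        """Return the model class this package builds."""

    @classmethod
    @abstractmethod
    def get_datamodule_cls(cls):
        """Return the data module class this package builds."""

    def build(self, data):
        # Data module first, so the model can be sized from its metadata.
        datamodule = self.get_datamodule_cls()(data, **self.datamodule_cfg)
        model = self.get_cls()(metadata=datamodule.metadata, **self.model_cfg)
        return datamodule, model


class ToyDataModule:
    def __init__(self, data, context_length, prediction_length):
        self.data = data
        self.metadata = {
            "context_length": context_length,
            "prediction_length": prediction_length,
        }


class ToyModel:
    def __init__(self, metadata, d_model):
        self.metadata = metadata
        self.d_model = d_model


class ToyPkg(PkgSketch):
    @classmethod
    def get_cls(cls):
        return ToyModel

    @classmethod
    def get_datamodule_cls(cls):
        return ToyDataModule


pkg = ToyPkg(
    model_cfg={"d_model": 64},
    datamodule_cfg={"context_length": 96, "prediction_length": 24},
)
datamodule, model = pkg.build(data=[1.0, 2.0, 3.0])
```

Because the base class only asks for the two classes, a new model package reduces to a few lines, while the fit/predict plumbing stays in one place.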

Inverse Scaling:

For models that operate on normalized data, predictions are transformed back to the original scale:

$$\hat{y}_{\text{original}} = \hat{y}_{\text{normalized}} \times \sigma + \mu$$

where $\sigma$ is the scale factor and $\mu$ is the center (mean).
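As a worked example of the inverse-scaling formula, the sketch below normalizes a short series and maps it back. It uses plain Python lists rather than tensors. `inverse_scale` and `normalize` are hypothetical helper names, not the library's `transform_output`, and the center and scale values are chosen by hand for illustration.

```python
from typing import List


def inverse_scale(y_normalized: List[float], center: float, scale: float) -> List[float]:
    """Map normalized predictions back to the original scale: y * scale + center."""
    return [y * scale + center for y in y_normalized]


def normalize(y: List[float], center: float, scale: float) -> List[float]:
    """Forward transform, included only to show the round trip."""
    return [(v - center) / scale for v in y]


original = [10.0, 12.0, 14.0]
center, scale = 12.0, 2.0
normalized = normalize(original, center, scale)       # [-1.0, 0.0, 1.0]
recovered = inverse_scale(normalized, center, scale)  # [10.0, 12.0, 14.0]
```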

