Implementation:Sktime Pytorch forecasting Trainer Fit


Knowledge Sources
Domains: Deep_Learning, Training, Optimization
Last Updated: 2026-02-08 07:00 GMT

Overview

Wrapper documentation for the PyTorch Lightning Trainer.fit method as used in pytorch-forecasting workflows to execute model training.

Description

The Trainer.fit method runs the full training loop for a pytorch-forecasting model. It iterates over epochs, running training steps (via BaseModel.training_step at L751-808) and validation steps. For each training batch, the model runs the forward pass and computes the loss (for example QuantileLoss, NormalDistributionLoss, or MASE, depending on the model); Lightning then backpropagates, applies gradient clipping if gradient_clip_val is set on the Trainer, and steps the optimizer. Callbacks such as EarlyStopping and LearningRateMonitor run at their respective hooks. The method belongs to Lightning rather than to pytorch-forecasting, but it relies heavily on the training infrastructure in the library's BaseModel.
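The order of operations described above can be sketched in plain Python. This is a toy stand-in under stated assumptions, not Lightning's implementation: the names toy_fit, clip_by_norm, and mse_and_grads are hypothetical, the "model" is a two-parameter line with hand-written gradients, the loss is mean squared error rather than any pytorch-forecasting metric, and early stopping is reduced to a patience counter on validation loss.

```python
import math


def clip_by_norm(grads, max_norm):
    """Rescale grads so that their L2 norm does not exceed max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > max_norm:
        return [g * max_norm / norm for g in grads]
    return list(grads)


def mse_and_grads(params, batch):
    """Forward pass, loss, and analytic gradients for y_hat = w * x + b."""
    w, b = params
    n = len(batch)
    loss = grad_w = grad_b = 0.0
    for x, y in batch:
        err = (w * x + b) - y
        loss += err * err / n
        grad_w += 2.0 * err * x / n
        grad_b += 2.0 * err / n
    return loss, [grad_w, grad_b]


def toy_fit(train_batches, val_batch, max_epochs=500, lr=0.05,
            clip=1.0, patience=5, min_delta=1e-8):
    """Hypothetical miniature of the loop that Trainer.fit orchestrates."""
    params = [0.0, 0.0]
    best_val, stale_epochs = float("inf"), 0
    val_history = []
    for _ in range(max_epochs):
        # Training: forward + loss + gradients, then clip, then update.
        for batch in train_batches:
            _, grads = mse_and_grads(params, batch)
            grads = clip_by_norm(grads, clip)
            params = [p - lr * g for p, g in zip(params, grads)]
        # Validation: loss only, no parameter update.
        val_loss, _ = mse_and_grads(params, val_batch)
        val_history.append(val_loss)
        # Early-stopping hook: stop after `patience` epochs without improvement.
        if val_loss < best_val - min_delta:
            best_val, stale_epochs = val_loss, 0
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                break
    return params, val_history


# Synthetic, noise-free data drawn from y = 2x + 1 (illustrative only).
data = [(i / 10, 2.0 * (i / 10) + 1.0) for i in range(-10, 11)]
train_batches = [data[i:i + 7] for i in range(0, len(data), 7)]
val_batch = [(x, 2.0 * x + 1.0) for x in (-0.75, -0.25, 0.25, 0.75)]

fitted, val_history = toy_fit(train_batches, val_batch)
```

In the real workflow none of this is written by hand: training_step supplies the loss, and the Trainer handles backpropagation, clipping, optimizer steps, validation, and callback hooks.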

Usage

Call Trainer.fit once all setup is complete: the Trainer is configured, the model is instantiated, the DataLoaders are created, and, optionally, a learning rate has been found. It is the final training step in all four workflows.

Code Reference

Source Location

  • Repository: External — pytorch-lightning
  • File: lightning/pytorch/trainer/trainer.py
  • Related internal code: pytorch_forecasting/models/base/_base_model.py L751-808 (training_step)

Signature

class Trainer:
    def fit(
        self,
        model: LightningModule,
        train_dataloaders: DataLoader | LightningDataModule | None = None,
        val_dataloaders: DataLoader | None = None,
        datamodule: LightningDataModule | None = None,
        ckpt_path: str | None = None,
    ) -> None:
        """
        Runs the full optimization routine.

        Args:
            model: Model to fit.
            train_dataloaders: Training DataLoader(s).
            val_dataloaders: Validation DataLoader(s).
            datamodule: Optional LightningDataModule.
            ckpt_path: Path to checkpoint to resume from.
        """

Import

import lightning.pytorch as pl
# trainer = pl.Trainer(...)
# trainer.fit(model, train_dataloaders, val_dataloaders)

I/O Contract

Inputs

Name | Type | Required | Description
model | LightningModule | Yes | pytorch-forecasting model (TFT, DeepAR, NBeats)
train_dataloaders | DataLoader | Yes | Training DataLoader from TimeSeriesDataSet.to_dataloader()
val_dataloaders | DataLoader | No | Validation DataLoader
ckpt_path | str | No | Path to resume training from a checkpoint
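To illustrate what resuming through ckpt_path means, here is a minimal sketch in plain Python. It is an assumption-laden toy: toy_fit_resumable, save_checkpoint, and load_checkpoint are hypothetical helpers, the checkpoint is a JSON file holding only an epoch index and a parameter list, and one "epoch" just adds 1.0 to the parameter. A real Lightning checkpoint stores considerably more (for example optimizer and callback state).

```python
import json
import os
import tempfile


def save_checkpoint(path, epoch, params):
    """Write the last completed epoch and the parameters to disk."""
    with open(path, "w") as f:
        json.dump({"epoch": epoch, "params": params}, f)


def load_checkpoint(path):
    """Read back the epoch index and parameters saved by save_checkpoint."""
    with open(path) as f:
        state = json.load(f)
    return state["epoch"], state["params"]


def toy_fit_resumable(max_epochs, ckpt_dir, ckpt_path=None):
    """Train up to max_epochs in total, optionally continuing from ckpt_path."""
    start_epoch, params = 0, [0.0]
    if ckpt_path is not None:
        last_epoch, params = load_checkpoint(ckpt_path)
        start_epoch = last_epoch + 1  # continue after the saved epoch
    epochs_run, last_path = [], None
    for epoch in range(start_epoch, max_epochs):
        params = [p + 1.0 for p in params]  # stand-in for one epoch of training
        last_path = os.path.join(ckpt_dir, f"epoch_{epoch}.json")
        save_checkpoint(last_path, epoch, params)
        epochs_run.append(epoch)
    return params, epochs_run, last_path


with tempfile.TemporaryDirectory() as tmp:
    # First run stops after three epochs and leaves a checkpoint behind.
    _, first_run, ckpt = toy_fit_resumable(max_epochs=3, ckpt_dir=tmp)
    # Second run resumes from that checkpoint instead of starting at epoch 0.
    resumed_params, second_run, _ = toy_fit_resumable(
        max_epochs=5, ckpt_dir=tmp, ckpt_path=ckpt
    )
```

The point of the sketch is that max_epochs counts epochs in total, so the resumed run only executes the epochs that remain after the checkpointed one.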

Outputs

Name | Type | Description
(side effect) | None | Model weights are updated in-place; checkpoints saved to disk; logs written

Usage Examples

Standard Model Training

# After trainer, model, and dataloaders are configured:
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

# Path of the best checkpoint tracked by the Trainer's checkpoint callback
# (checkpoint_callback is None when checkpointing is disabled)
best_model_path = trainer.checkpoint_callback.best_model_path
print(f"Best model saved at: {best_model_path}")
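As a rough illustration of what a "best" checkpoint path refers to, the sketch below tracks the lowest validation loss seen so far and remembers which epoch's file it came from. It is a simplified assumption, not Lightning's ModelCheckpoint: the function name track_best, the directory name, and the file naming scheme are hypothetical, and it assumes a metric where lower is better.

```python
def track_best(val_losses, dirpath="checkpoints"):
    """Return the (hypothetical) path and score of the best epoch so far."""
    best_score, best_path = float("inf"), ""
    for epoch, loss in enumerate(val_losses):
        if loss < best_score:  # lower validation loss counts as better
            best_score = loss
            best_path = f"{dirpath}/epoch_{epoch}.ckpt"
    return best_path, best_score


best_path, best_score = track_best([0.9, 0.5, 0.7, 0.6])
```

In Lightning this comparison only happens when the checkpoint callback is given a metric to monitor; with the default callback and no monitored metric, best_model_path is expected to point at the most recently saved checkpoint instead.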
