Implementation:Sktime Pytorch forecasting Trainer Fit
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Training, Optimization |
| Last Updated | 2026-02-08 07:00 GMT |
Overview
Wrapper documentation for the PyTorch Lightning Trainer.fit method as used in pytorch-forecasting workflows to execute model training.
Description
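The Trainer.fit method executes the full training loop for a pytorch-forecasting model. It iterates over epochs. In each epoch it runs the training steps (implemented by BaseModel.training_step at L751-808) and then the validation steps. For each training batch, the model runs the forward pass and computes the loss configured on the model, for example QuantileLoss (the TemporalFusionTransformer default), NormalDistributionLoss (the DeepAR default), or MASE (the NBeats default). Lightning then backpropagates, clips gradients if gradient_clip_val is set on the Trainer, and lets the optimizer update the parameters. Callbacks run at their respective Lightning hooks: EarlyStopping checks a monitored metric such as val_loss after validation, and LearningRateMonitor logs the learning rate during training. The method belongs to Lightning rather than pytorch-forecasting, but it integrates closely with the library's BaseModel training infrastructure.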
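The loop described above can be sketched in plain Python. This is a conceptual sketch under stated assumptions, not Lightning's implementation. It uses a toy linear model y = w * x + b, squared error as a stand-in for the model's configured loss, and plain SGD as the optimizer. It also includes simplified versions of global-norm gradient clipping and early stopping (mode "min"). All names here (toy_fit, clip_by_global_norm, SimpleEarlyStopping) are hypothetical.

```python
import math
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple


@dataclass
class SimpleEarlyStopping:
    """Hypothetical, simplified stand-in for Lightning's EarlyStopping (mode="min")."""
    patience: int = 3
    min_delta: float = 0.0
    best: float = math.inf
    wait: int = 0

    def should_stop(self, val_loss: float) -> bool:
        # An epoch counts as an improvement only if it beats the best value by more than min_delta.
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    """Rescale gradients so that their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * coef for g in grads]


def toy_fit(
    train: Sequence[Tuple[float, float]],
    val: Sequence[Tuple[float, float]],
    max_epochs: int = 50,
    lr: float = 0.05,
    gradient_clip_val: Optional[float] = 1.0,
    early_stopping: Optional[SimpleEarlyStopping] = None,
) -> Tuple[Tuple[float, float], List[float]]:
    """Hypothetical toy training loop for y = w * x + b with squared-error loss."""
    w, b = 0.0, 0.0
    history: List[float] = []
    for _epoch in range(max_epochs):
        # Training steps: forward pass, loss gradient, optional clipping, parameter update.
        for x, y in train:
            err = (w * x + b) - y
            grads = [2.0 * err * x, 2.0 * err]  # d(err**2)/dw, d(err**2)/db
            if gradient_clip_val is not None:
                grads = clip_by_global_norm(grads, gradient_clip_val)
            w -= lr * grads[0]
            b -= lr * grads[1]
        # Validation steps: evaluate the monitored metric without updating parameters.
        val_loss = sum(((w * x + b) - y) ** 2 for x, y in val) / len(val)
        history.append(val_loss)
        # Callback hook after validation: stop once the metric stops improving.
        if early_stopping is not None and early_stopping.should_stop(val_loss):
            break
    return (w, b), history


train = [(i / 10, 2 * (i / 10) + 1) for i in range(-10, 11)]
val = [(x, 2 * x + 1) for x in (-0.55, 0.05, 0.45)]
(w, b), history = toy_fit(train, val, early_stopping=SimpleEarlyStopping(patience=3, min_delta=1e-4))
print(f"w={w:.3f}, b={b:.3f}, epochs run={len(history)}")
```

In the real Trainer, these pieces correspond to Trainer arguments and callbacks, such as gradient_clip_val and an EarlyStopping callback, rather than to code you write yourself.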
Usage
Call after all setup is complete: the Trainer is configured, the model is instantiated, the DataLoaders are created, and, optionally, a suitable learning rate has been found. This is the final training step in all four workflows.
Code Reference
Source Location
- Repository: External — pytorch-lightning
- File: lightning/pytorch/trainer/trainer.py
- Related internal code: pytorch_forecasting/models/base/_base_model.py L751-808 (training_step)
Signature
```python
class Trainer:
    def fit(
        self,
        model: LightningModule,
        train_dataloaders: DataLoader | LightningDataModule | None = None,
        val_dataloaders: DataLoader | None = None,
        datamodule: LightningDataModule | None = None,
        ckpt_path: str | None = None,
    ) -> None:
        """
        Runs the full optimization routine.

        Args:
            model: Model to fit.
            train_dataloaders: Training DataLoader(s).
            val_dataloaders: Validation DataLoader(s).
            datamodule: Optional LightningDataModule.
            ckpt_path: Path to checkpoint to resume from.
        """
```
Import
```python
import lightning.pytorch as pl

# trainer = pl.Trainer(...)
# trainer.fit(model, train_dataloaders, val_dataloaders)
```
External Reference
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
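| model | LightningModule | Yes | pytorch-forecasting model (e.g., TemporalFusionTransformer, DeepAR, NBeats) |
| train_dataloaders | DataLoader | Yes, unless datamodule is given | Training DataLoader, typically from TimeSeriesDataSet.to_dataloader(train=True) |
| val_dataloaders | DataLoader | No | Validation DataLoader, typically from TimeSeriesDataSet.to_dataloader(train=False); needed for callbacks that monitor val_loss |
| datamodule | LightningDataModule | No | Alternative to passing dataloaders directly; cannot be combined with them |
| ckpt_path | str | No | Path to a checkpoint from which to resume training |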
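The rules for combining the three data arguments can be summarized as a small validation step. The sketch below is a hypothetical helper (resolve_fit_inputs, with FakeDataModule as a placeholder for LightningDataModule). To our understanding, it mirrors how Lightning treats these arguments: a datamodule passed in the train_dataloaders position is reinterpreted as the datamodule, and explicit dataloaders cannot be combined with a datamodule. Lightning raises its own MisconfigurationException in that case, not the ValueError used here.

```python
from typing import Any, Optional, Tuple


class FakeDataModule:
    """Hypothetical placeholder for a LightningDataModule (illustration only)."""


def resolve_fit_inputs(
    train_dataloaders: Any = None,
    val_dataloaders: Any = None,
    datamodule: Optional[FakeDataModule] = None,
) -> Tuple[Any, Any, Optional[FakeDataModule]]:
    """Hypothetical helper mirroring how fit() interprets its data arguments."""
    # A datamodule passed in the train_dataloaders slot is treated as the datamodule.
    if isinstance(train_dataloaders, FakeDataModule):
        datamodule, train_dataloaders = train_dataloaders, None
    # Explicit dataloaders and a datamodule are mutually exclusive.
    if datamodule is not None and (train_dataloaders is not None or val_dataloaders is not None):
        raise ValueError("pass dataloaders or a datamodule to fit(), not both")
    return train_dataloaders, val_dataloaders, datamodule


print(resolve_fit_inputs("train_dl", "val_dl"))
```

In pytorch-forecasting workflows the usual choice is the first form: two DataLoaders built with TimeSeriesDataSet.to_dataloader().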
Outputs
| Name | Type | Description |
|---|---|---|
| (side effect) | None | Model weights are updated in place; checkpoints are saved to disk by the ModelCheckpoint callback (enabled by default); metrics are written by the configured logger |
Usage Examples
Standard Model Training
```python
# After trainer, model, and dataloaders are configured:
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

# Access the best model path after training
# (requires a ModelCheckpoint callback; Lightning adds one by default)
best_model_path = trainer.checkpoint_callback.best_model_path
print(f"Best model saved at: {best_model_path}")
```
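The best_model_path value comes from Lightning's ModelCheckpoint callback. When that callback monitors a metric, for example ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1), it keeps the checkpoints with the lowest metric values and reports the best one. When no metric is monitored, Lightning's default checkpoint callback instead keeps the most recent checkpoint. The sketch below mimics only the top-k bookkeeping in plain Python. The class name TopKCheckpointTracker and the file-name pattern are hypothetical, and no files are written.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class TopKCheckpointTracker:
    """Hypothetical sketch of ModelCheckpoint's top-k bookkeeping (mode="min")."""
    save_top_k: int = 1
    kept: List[Tuple[float, str]] = field(default_factory=list)

    def on_validation_end(self, epoch: int, val_loss: float) -> None:
        # Hypothetical file name; no checkpoint file is actually written here.
        path = f"epoch={epoch}-val_loss={val_loss:.4f}.ckpt"
        self.kept.append((val_loss, path))
        # Keep only the save_top_k entries with the lowest monitored value.
        self.kept.sort(key=lambda item: item[0])
        del self.kept[self.save_top_k:]

    @property
    def best_model_path(self) -> Optional[str]:
        return self.kept[0][1] if self.kept else None


tracker = TopKCheckpointTracker(save_top_k=2)
for epoch, val_loss in enumerate([0.50, 0.30, 0.40, 0.35]):
    tracker.on_validation_end(epoch, val_loss)
print(tracker.best_model_path)  # epoch=1-val_loss=0.3000.ckpt
```

Because only the best entries survive, best_model_path can point to an earlier epoch than the last one trained, as it does here.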