Implementation:Sktime Pytorch forecasting Tuner Lr Find
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Optimization, Hyperparameter_Tuning |
| Last Updated | 2026-02-08 07:00 GMT |
Overview
Wrapper for running learning rate range tests with the pytorch-forecasting Tuner, a subclass of Lightning's Tuner that adds a checkpoint-loading compatibility fix.
Description
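The Tuner class in pytorch-forecasting subclasses Lightning's Tuner and wraps its lr_find method with one compatibility fix. It monkey-patches checkpoint loading so that checkpoints are loaded with weights_only=False, because the default changed in Lightning >= 2.6.

The fix matters because lr_find saves a checkpoint before the range test and restores it afterwards. That restore can fail when the checkpoint contains pickled Python objects that weights-only loading rejects.

During the test, lr_find trains the model briefly while increasing the learning rate (exponentially by default, or linearly). It returns a result object with two methods: suggestion() proposes a learning rate, and plot() draws the loss against the learning rate.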
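The monkey-patching pattern behind this fix can be illustrated without Lightning. The sketch below is hypothetical and framework-free, and it rests on three assumptions:

- FakeStrategy, force_weights_only_false and run_with_patched_loader are invented names.
- FakeStrategy.load_checkpoint only echoes its weights_only flag instead of calling torch.load.
- The sketch does not reproduce how the pytorch-forecasting Tuner actually applies or removes its patch.

```python
import functools
from typing import Any, Callable


class FakeStrategy:
    """Hypothetical stand-in for a training strategy; not Lightning's API."""

    def load_checkpoint(self, checkpoint_path: str, weights_only: bool = True) -> dict:
        # A real implementation would call torch.load; this just echoes the flag.
        return {"path": checkpoint_path, "weights_only": weights_only}


def force_weights_only_false(load_fn: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap a checkpoint loader so every call receives weights_only=False."""

    @functools.wraps(load_fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        kwargs["weights_only"] = False  # overrides any value the caller passed
        return load_fn(*args, **kwargs)

    return wrapper


def run_with_patched_loader(strategy: FakeStrategy, action: Callable[[], Any]) -> Any:
    """Patch strategy.load_checkpoint, run `action`, then restore the original."""
    original = strategy.load_checkpoint
    strategy.load_checkpoint = force_weights_only_false(original)
    try:
        return action()
    finally:
        strategy.load_checkpoint = original  # undo the patch even if action raises


strategy = FakeStrategy()
patched = run_with_patched_loader(strategy, lambda: strategy.load_checkpoint("lr_find.ckpt"))
print(patched["weights_only"])  # False
print(strategy.load_checkpoint("lr_find.ckpt")["weights_only"])  # True
```

Restoring the original method in a finally clause keeps the patch scoped to one operation, even if that operation raises.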
Usage
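Call lr_find after creating the Trainer and the model, and before Trainer.fit(). Pass both the training and validation DataLoaders.

Assign the suggested learning rate to the model before training. In recent Lightning versions, lr_find can also write the suggestion to the model itself through its update_attr argument, provided the model exposes an lr or learning_rate attribute.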
Code Reference
Source Location
- Repository: pytorch-forecasting
- File: pytorch_forecasting/tuning/tuner.py
- Lines: L9-25
Signature
```python
class Tuner(lightning.pytorch.tuner.Tuner):
    def lr_find(self, *args, **kwargs):
        """
        Wrapper around Lightning Tuner.lr_find with checkpoint loading fix.

        Args:
            model: LightningModule to tune
            train_dataloaders: training DataLoader
            val_dataloaders: validation DataLoader
            min_lr: minimum learning rate (default: 1e-8)
            max_lr: maximum learning rate (default: 1)
            num_training: number of training steps (default: 100)
            mode: "exponential" or "linear" (default: "exponential")
            early_stop_threshold: stop if loss > threshold * best_loss
                (default: 4.0; None disables early stopping)

        Returns:
            Lightning's _LRFinder result object with .suggestion() and .plot() methods
        """
```
Import
```python
from pytorch_forecasting.tuning import Tuner
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| trainer | pl.Trainer | Yes (via constructor) | Configured Lightning Trainer |
| model | LightningModule | Yes | Model to find learning rate for |
| train_dataloaders | DataLoader | Yes | Training data |
| val_dataloaders | DataLoader | Yes | Validation data |
| min_lr | float | No | Minimum LR to test (default: 1e-8) |
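| max_lr | float | No | Maximum LR to test (default: 1) |
| num_training | int | No | Number of learning rates (training steps) to test (default: 100) |
| mode | str | No | "exponential" or "linear" spacing of the tested rates (default: "exponential") |
| early_stop_threshold | float or None | No | Stop once the loss exceeds this multiple of the best loss (default: 4.0; None disables) |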
Outputs
| Name | Type | Description |
|---|---|---|
| return | _LRFinder (Lightning) | Result object: .suggestion() returns the suggested learning rate as a float (or None if no suggestion could be computed); .plot() draws the loss vs. learning rate curve |
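suggestion() does not simply return the rate with the lowest loss. Lightning derives it from the loss curve by looking for the steepest downward slope; a common rationale is that the minimum usually sits close to the point where training diverges.

A minimal sketch of that idea over recorded (lr, loss) lists follows, assuming:

- steepest_descent_lr is a hypothetical name.
- It uses a plain per-step difference, whereas Lightning computes a gradient of the loss curve and its details vary between versions.

```python
from typing import List, Optional


def steepest_descent_lr(lrs: List[float], losses: List[float],
                        skip_begin: int = 0, skip_end: int = 0) -> Optional[float]:
    """Return the rate where the loss drops fastest, or None if too few points."""
    end = len(losses) - skip_end
    if end - skip_begin < 2:
        return None  # need at least two points to measure a slope
    best_idx = skip_begin
    best_slope = losses[skip_begin + 1] - losses[skip_begin]
    for i in range(skip_begin + 1, end - 1):
        slope = losses[i + 1] - losses[i]  # change in loss over one step
        if slope < best_slope:
            best_idx, best_slope = i, slope
    return lrs[best_idx]


lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0]
losses = [5.0, 4.8, 3.0, 1.5, 1.2, 3.0]
print(steepest_descent_lr(lrs, losses))  # 0.0001
```

Lightning's own suggestion() also trims points from the start and end of the curve before searching. The skip_begin and skip_end arguments here imitate that, and their zero defaults are an assumption of this sketch.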
Usage Examples
TFT Learning Rate Finding
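```python
from pytorch_forecasting.tuning import Tuner

# Assumes `trainer` (a Lightning Trainer), `tft` (a TemporalFusionTransformer)
# and both DataLoaders were created in earlier steps.
# Create a Tuner from the trainer and run the range test
res = Tuner(trainer).lr_find(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    min_lr=1e-5,
    max_lr=1e2,
)

# Get the suggested learning rate (None if no suggestion could be computed)
optimal_lr = res.suggestion()
if optimal_lr is not None:
    print(f"Suggested learning rate: {optimal_lr:.2e}")

# Plot loss vs. learning rate, marking the suggested value
fig = res.plot(show=True, suggest=True)
```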