Implementation:Sktime Pytorch forecasting GridUpdateCallback
| Knowledge Sources | |
|---|---|
| Domains | Time_Series, Forecasting, Deep_Learning |
| Last Updated | 2026-02-08 08:00 GMT |
Overview
GridUpdateCallback is a PyTorch Lightning Callback that updates the grids of KAN (Kolmogorov-Arnold Network) layers at a fixed interval of training steps.
Description
GridUpdateCallback hooks into the Lightning training loop through the on_train_batch_end hook. Whenever the training step count is a multiple of the configured update_interval, it calls the model's update_kan_grid() method to refresh the KAN layer grids. This helps keep the KAN spline grids aligned with the data distribution as it evolves during training.
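The interval check described above can be sketched without Lightning installed. This is a minimal illustration under stated assumptions, not the library's implementation: GridUpdateSketch, FakeKANModule, and simulate are hypothetical names, the trainer is a plain namespace, and the choice of trainer.global_step as the counter (and whether the real check is offset by one) is an assumption.

```python
from types import SimpleNamespace


class GridUpdateSketch:
    """Stand-in for the callback; it only mirrors the Lightning hook signature."""

    def __init__(self, update_interval):
        if update_interval < 1:
            raise ValueError("update_interval must be a positive integer")
        self.update_interval = update_interval

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Assumption: the counter is trainer.global_step, already advanced
        # for the batch that just finished.
        if trainer.global_step % self.update_interval == 0:
            pl_module.update_kan_grid()


class FakeKANModule:
    """Hypothetical module that records the steps at which its grid was refreshed."""

    def __init__(self):
        self.step = 0
        self.refreshed_at = []

    def update_kan_grid(self):
        self.refreshed_at.append(self.step)


def simulate(num_steps, update_interval):
    """Run a fake training loop and return the steps that triggered an update."""
    callback = GridUpdateSketch(update_interval)
    module = FakeKANModule()
    trainer = SimpleNamespace(global_step=0)
    for batch_idx in range(num_steps):
        trainer.global_step += 1  # one optimizer step per batch in this sketch
        module.step = trainer.global_step
        callback.on_train_batch_end(trainer, module, None, None, batch_idx)
    return module.refreshed_at
```

Under these assumptions, simulate(10, 3) triggers updates at steps 3, 6, and 9, and a run shorter than the interval triggers none.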
Usage
Use GridUpdateCallback when training N-BEATS or other models that incorporate KAN layers. Pass it in the callbacks list of the PyTorch Lightning Trainer so that the KAN grids are refreshed every update_interval training steps. The model must implement update_kan_grid(), because that is the method the callback calls.
Code Reference
Source Location
- Repository: Sktime_Pytorch_forecasting
- File: pytorch_forecasting/models/nbeats/_grid_callback.py
- Lines: 1-46
Signature
class GridUpdateCallback(Callback):
    def __init__(self, update_interval):
Import
from pytorch_forecasting.models.nbeats._grid_callback import GridUpdateCallback
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| update_interval | int | Yes | Number of training steps between consecutive KAN grid updates |
Outputs
| Name | Type | Description |
|---|---|---|
| (side effect) | None | Calls pl_module.update_kan_grid() at each interval; no return value |
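The contract above depends on the module exposing update_kan_grid(). A module without that method would presumably fail only once the first interval is reached, so a pre-flight check can surface the problem earlier. The helper below is a hypothetical sketch, not part of the library, and WithGrid and WithoutGrid are illustrative classes.

```python
def supports_kan_grid_update(module):
    """Return True if `module` exposes a callable update_kan_grid attribute."""
    return callable(getattr(module, "update_kan_grid", None))


class WithGrid:
    """Illustrative module that satisfies the callback's expectation."""

    def update_kan_grid(self):
        pass


class WithoutGrid:
    """Illustrative module that lacks the method the callback calls."""
```

Calling supports_kan_grid_update(model) before constructing the Trainer is one way to decide whether attaching the callback makes sense for a given model.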
Usage Examples
from lightning.pytorch import Trainer
from pytorch_forecasting.models.nbeats._grid_callback import GridUpdateCallback
# Create the callback that updates KAN grids every 50 training steps
grid_callback = GridUpdateCallback(update_interval=50)
# Pass it to the Lightning Trainer
trainer = Trainer(
    max_epochs=100,
    callbacks=[grid_callback],
)
# Train your KAN-based model (model and train_dataloader are assumed to be defined elsewhere)
trainer.fit(model, train_dataloaders=train_dataloader)