

Implementation:Sktime Pytorch forecasting GridUpdateCallback

From Leeroopedia


Domains: Time_Series, Forecasting, Deep_Learning
Last Updated: 2026-02-08 08:00 GMT

Overview

GridUpdateCallback is a PyTorch Lightning Callback that updates the grids of KAN (Kolmogorov-Arnold Network) layers at a fixed interval of training steps.

Description

GridUpdateCallback hooks into the Lightning training loop through the on_train_batch_end event. Whenever the current training step is a multiple of the configured update_interval, it calls the model's update_kan_grid() method to refresh the KAN layer grids. Refreshing the grids keeps the spline knots of each KAN layer aligned with the inputs that layer actually receives, whose distribution shifts as the model's weights change during training.
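The scheduling logic described above can be sketched without Lightning installed. This is a dependency-free illustration, not the library's source: it assumes the step count is read from trainer.global_step and that an update fires whenever that count is a positive multiple of update_interval. The real class subclasses Lightning's Callback, and its exact step-counting convention may differ. GridUpdateCallbackSketch, FakeTrainer, FakeKANModel and simulate are hypothetical names.

```python
from dataclasses import dataclass, field


class GridUpdateCallbackSketch:
    """Hypothetical stand-in mirroring the described behaviour of GridUpdateCallback."""

    def __init__(self, update_interval: int):
        if update_interval < 1:
            raise ValueError("update_interval must be a positive integer")
        self.update_interval = update_interval

    # Same argument list as Lightning's Callback.on_train_batch_end hook.
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        step = trainer.global_step  # assumption: the counter the real class checks
        if step > 0 and step % self.update_interval == 0:
            pl_module.update_kan_grid()


@dataclass
class FakeTrainer:
    """Hypothetical trainer exposing only the attribute the sketch reads."""
    global_step: int = 0


@dataclass
class FakeKANModel:
    """Hypothetical model that records the steps at which its grid was refreshed."""
    trainer: FakeTrainer
    updated_at: list = field(default_factory=list)

    def update_kan_grid(self):
        self.updated_at.append(self.trainer.global_step)


def simulate(num_steps: int, update_interval: int) -> list:
    """Run a fake training loop and return the steps that triggered a grid update."""
    trainer = FakeTrainer()
    model = FakeKANModel(trainer)
    callback = GridUpdateCallbackSketch(update_interval)
    for batch_idx in range(num_steps):
        trainer.global_step += 1  # assumption: one optimizer step per batch
        callback.on_train_batch_end(trainer, model, None, None, batch_idx)
    return model.updated_at


print(simulate(num_steps=200, update_interval=50))  # [50, 100, 150, 200]
```

Keeping the check inside on_train_batch_end means the callback adds essentially no overhead on the steps where no update is due; only the modulo test runs.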

Usage

Use GridUpdateCallback when training N-BEATS variants or other models that incorporate KAN layers. The LightningModule being trained must implement an update_kan_grid() method, because that is the method the callback calls. Pass the callback to the PyTorch Lightning Trainer so that the KAN grids are refreshed periodically during training, which is critical for keeping the KAN layers accurate.
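The callback only relies on the trained module exposing update_kan_grid(). The toy class below is a hypothetical, plain-Python illustration of what such a method might do for a single 1-D layer: it buffers recently observed inputs and, on update, rebuilds its knot grid by blending a uniform grid over the observed range with a quantile-based (adaptive) grid, weighted by grid_eps. This blending follows the general grid-update approach used by KAN implementations such as pykan; it is not pytorch-forecasting's implementation, and ToyKANLayer, observe, num_intervals and grid_eps are assumed names.

```python
class ToyKANLayer:
    """Hypothetical 1-D layer whose spline knots can be refreshed from recent inputs."""

    def __init__(self, num_intervals: int = 4, grid_eps: float = 0.5):
        self.num_intervals = num_intervals
        self.grid_eps = grid_eps  # 1.0 -> purely uniform grid, 0.0 -> purely adaptive
        # Initial knots: uniform on [-1, 1].
        self.grid = [-1.0 + 2.0 * i / num_intervals for i in range(num_intervals + 1)]
        self._seen = []

    def observe(self, xs):
        """Record inputs seen during the forward pass (stand-in for cached activations)."""
        self._seen.extend(xs)

    def update_kan_grid(self):
        """Rebuild the knot grid from the inputs observed since the last update."""
        if len(self._seen) < 2:
            return  # not enough data to place knots; keep the current grid
        xs = sorted(self._seen)
        n, k = len(xs), self.num_intervals
        lo, hi = xs[0], xs[-1]
        # Uniform knots spanning the observed range.
        uniform = [lo + (hi - lo) * i / k for i in range(k + 1)]
        # Adaptive knots at evenly spaced order statistics (nearest-rank quantiles).
        adaptive = [xs[round(i * (n - 1) / k)] for i in range(k + 1)]
        self.grid = [
            self.grid_eps * u + (1.0 - self.grid_eps) * a
            for u, a in zip(uniform, adaptive)
        ]
        self._seen.clear()


layer = ToyKANLayer(num_intervals=4, grid_eps=0.5)
layer.observe([0, 1, 2, 3, 4, 5, 6, 7, 100])
layer.update_kan_grid()
print(layer.grid)  # [0.0, 13.5, 27.0, 40.5, 100.0]
```

Lowering grid_eps moves knots toward regions where inputs are dense: with grid_eps=0.0 the same inputs give [0.0, 2.0, 4.0, 6.0, 100.0]. In a real model, update_kan_grid() on the LightningModule would typically iterate over all of its KAN layers; how pytorch-forecasting does this is not shown here.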

Code Reference

Source Location

Signature

class GridUpdateCallback(Callback):
    def __init__(self, update_interval):

Import

from pytorch_forecasting.models.nbeats._grid_callback import GridUpdateCallback

I/O Contract

Inputs

Name            | Type | Required | Description
--------------- | ---- | -------- | -----------
update_interval | int  | Yes      | Number of training steps between successive KAN grid updates
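Because update_interval is measured in training steps rather than epochs, it can help to translate it into the number of grid refreshes a run will perform. The helper below is a back-of-the-envelope sketch that assumes one optimizer step per training batch (no gradient accumulation) and an update on every step that is a positive multiple of update_interval; expected_grid_updates is a hypothetical name.

```python
def expected_grid_updates(max_epochs: int, batches_per_epoch: int, update_interval: int) -> int:
    """Estimate how many times update_kan_grid() is called over a training run."""
    if update_interval < 1:
        raise ValueError("update_interval must be a positive integer")
    total_steps = max_epochs * batches_per_epoch  # assumes one optimizer step per batch
    return total_steps // update_interval


# 100 epochs x 32 batches = 3200 steps; 3200 // 50 = 64 grid updates
print(expected_grid_updates(max_epochs=100, batches_per_epoch=32, update_interval=50))  # 64
```

If update_interval exceeds the number of batches per epoch, the grid is refreshed less than once per epoch; for example, 100 epochs of 20 batches with update_interval=50 gives 40 updates, one every 2.5 epochs.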

Outputs

Name          | Type | Description
------------- | ---- | -----------
(side effect) | None | Calls pl_module.update_kan_grid() at each interval; returns nothing

Usage Examples

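from lightning.pytorch import Trainer
from pytorch_forecasting.models.nbeats._grid_callback import GridUpdateCallback

# Assumed to exist already (not defined on this page):
#   model            - a KAN-based LightningModule that implements update_kan_grid()
#   train_dataloader - the DataLoader used for training

# Create the callback that updates KAN grids every 50 training steps
grid_callback = GridUpdateCallback(update_interval=50)

# Pass it to the Lightning Trainer
trainer = Trainer(
    max_epochs=100,
    callbacks=[grid_callback],
)

# Train the KAN-based model; the callback is checked at the end of every training batch
trainer.fit(model, train_dataloaders=train_dataloader)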
