

Implementation:Sktime Pytorch forecasting NBeatsKAN

From Leeroopedia


Knowledge Sources
Domains: Time_Series, Forecasting, Deep_Learning
Last Updated: 2026-02-08 08:00 GMT

Overview

NBeatsKAN is an N-BEATS variant that replaces the standard MLP layers in its blocks with Kolmogorov-Arnold Network (KAN) layers, with the aim of improving accuracy and interpretability.

Description

NBeatsKAN extends NBeatsAdapter and implements the N-BEATS architecture with blocks built from KAN layers, which are motivated by the Kolmogorov-Arnold representation theorem. Where an MLP layer applies learnable linear weights on its edges followed by a fixed activation at each node, a KAN layer places a learnable univariate function (a spline plus a residual base function) on every edge. This is intended to let the network capture complex temporal patterns with better parameter efficiency. The model supports three stack types, "generic", "trend", and "seasonality", built from NBEATSGenericBlockKAN, NBEATSTrendBlockKAN, and NBEATSSeasonalBlockKAN respectively. It also provides an update_kan_grid method for updating the KAN spline grids during training.
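A minimal pure-Python sketch of a single KAN edge activation may make the "learnable function on each edge" idea concrete. It assumes the pykan-style formulation phi(x) = scale_base * b(x) + scale_sp * sum_i c_i * B_i(x), where b is the SiLU residual function and the B_i are order-k B-splines on num uniform intervals of grid_range, padded with k extra knots per side. The function names are hypothetical and this is not pytorch-forecasting's implementation; in the real layers the coefficients c_i and the scales are trainable tensors.

```python
import math


def silu(x: float) -> float:
    """Residual base function b(x) = x * sigmoid(x), written to avoid exp overflow."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def extended_grid(num: int, k: int, grid_range=(-1.0, 1.0)) -> list:
    """Uniform knots: `num` intervals over grid_range plus k extra knots on each side."""
    lo, hi = grid_range
    h = (hi - lo) / num
    return [lo + i * h for i in range(-k, num + k + 1)]


def bspline_basis(x: float, knots: list, k: int) -> list:
    """Cox-de Boor recursion; returns len(knots) - k - 1 basis values of order k at x."""
    # Order 0: indicator of the knot interval containing x.
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(len(knots) - 1)]
    for p in range(1, k + 1):
        basis = [
            (x - knots[i]) / (knots[i + p] - knots[i]) * basis[i]
            + (knots[i + p + 1] - x) / (knots[i + p + 1] - knots[i + 1]) * basis[i + 1]
            for i in range(len(basis) - 1)
        ]
    return basis


def kan_edge(x: float, coeffs: list, knots: list, k: int,
             scale_base: float = 1.0, scale_sp: float = 1.0) -> float:
    """One learnable edge function: scaled residual term plus scaled spline term."""
    basis = bspline_basis(x, knots, k)
    if len(coeffs) != len(basis):
        raise ValueError("need one coefficient per basis function (num + k)")
    spline = sum(c * b for c, b in zip(coeffs, basis))
    return scale_base * silu(x) + scale_sp * spline


knots = extended_grid(num=5, k=3)         # 5 intervals on [-1, 1], padded by 3 knots per side
coeffs = [0.1 * i for i in range(5 + 3)]  # num + k = 8 spline coefficients
print(kan_edge(0.3, coeffs, knots, k=3))
```

The count of spline coefficients per edge is num + k, so raising num (G) or k adds parameters to every edge of every KAN layer.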

Usage

Use NBeatsKAN when you want the interpretable structure of N-BEATS (trend/seasonality decomposition) combined with the enhanced expressiveness of KAN layers. This model is particularly useful for univariate time series forecasting where interpretability and parameter efficiency are important, or for zero-shot cross-domain transfer learning scenarios.
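The trend/seasonality decomposition comes from the fixed basis expansions of N-BEATS: a trend block outputs coefficients for a low-degree polynomial in time, and a seasonality block outputs coefficients for a Fourier basis, so each stack's forecast can be read as a trend or seasonal component. The sketch below assumes the formulation of the original N-BEATS paper on a normalized time axis t = i / length; pytorch-forecasting's blocks build these bases as tensors with their own normalization, and the helper names here are hypothetical. As we understand it, the KAN variant changes the layers that compute the coefficients, not the bases themselves.

```python
import math


def time_axis(length: int) -> list:
    """Normalized time steps t = 0, 1/length, ..., (length - 1)/length."""
    return [i / length for i in range(length)]


def trend_component(thetas: list, length: int) -> list:
    """Polynomial basis: y(t) = sum_i thetas[i] * t**i (degree len(thetas) - 1)."""
    return [sum(th * t**i for i, th in enumerate(thetas)) for t in time_axis(length)]


def seasonal_component(cos_coeffs: list, sin_coeffs: list, length: int) -> list:
    """Fourier basis: y(t) = sum_j a_j*cos(2*pi*j*t) + sum_j b_j*sin(2*pi*j*t)."""
    return [
        sum(a * math.cos(2 * math.pi * j * t) for j, a in enumerate(cos_coeffs))
        + sum(b * math.sin(2 * math.pi * j * t) for j, b in enumerate(sin_coeffs))
        for t in time_axis(length)
    ]


# A 4-step forecast as the sum of a linear trend and one seasonal harmonic.
trend = trend_component([1.0, 2.0], 4)
season = seasonal_component([0.0, 1.0], [], 4)
forecast = [a + b for a, b in zip(trend, season)]
```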

Code Reference

Source Location

Signature

class NBeatsKAN(NBeatsAdapter):
    def __init__(
        self,
        stack_types: list[str] | None = None,
        num_blocks: list[int] | None = None,
        num_block_layers: list[int] | None = None,
        widths: list[int] | None = None,
        sharing: list[bool] | None = None,
        expansion_coefficient_lengths: list[int] | None = None,
        prediction_length: int = 1,
        context_length: int = 1,
        dropout: float = 0.1,
        learning_rate: float = 1e-2,
        log_interval: int = -1,
        log_gradient_flow: bool = False,
        log_val_interval: int = None,
        weight_decay: float = 1e-3,
        loss: MultiHorizonMetric = None,
        reduce_on_plateau_patience: int = 1000,
        backcast_loss_ratio: float = 0.0,
        logging_metrics: nn.ModuleList = None,
        num: int = 5,
        k: int = 3,
        noise_scale: float = 0.5,
        scale_base_mu: float = 0.0,
        scale_base_sigma: float = 1.0,
        scale_sp: float = 1.0,
        base_fun: callable = None,
        grid_eps: float = 0.02,
        grid_range: list[int] = None,
        sp_trainable: bool = True,
        sb_trainable: bool = True,
        sparse_init: bool = False,
        **kwargs,
    ):

Import

from pytorch_forecasting.models.nbeats import NBeatsKAN

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| stack_types | list[str] or None | No | Type of each stack: "generic", "trend", or "seasonality". Defaults to ["trend", "seasonality"]. |
| num_blocks | list[int] or None | No | Number of blocks per stack. Defaults to [3, 3]. |
| num_block_layers | list[int] or None | No | Number of layers in each block, per stack. Defaults to [3, 3]. |
| widths | list[int] or None | No | Layer width per stack. Defaults to [32, 512]. |
| sharing | list[bool] or None | No | Whether weights are shared across the blocks of each stack. Defaults to [True, True]. |
| expansion_coefficient_lengths | list[int] or None | No | Per stack: expansion coefficient length for generic stacks, polynomial degree for trend stacks, minimum period for seasonality stacks. Defaults to [3, 7]. |
| prediction_length | int | No | Forecast horizon in time steps. Defaults to 1. |
| context_length | int | No | Lookback window in time steps. Defaults to 1. |
| dropout | float | No | Dropout rate. Defaults to 0.1. |
| loss | MultiHorizonMetric | No | Loss function. Defaults to MASE(). |
| backcast_loss_ratio | float | No | Weight of the backcast loss relative to the forecast loss. Defaults to 0.0. |
| num | int | No | Number of KAN grid intervals (G). Defaults to 5. |
| k | int | No | Order of the KAN piecewise polynomials (spline order). Defaults to 3. |
| noise_scale | float | No | Scale of the noise injected into the KAN splines at initialization. Defaults to 0.5. |
| scale_base_mu | float | No | Mean of the normal distribution from which the residual-function scale is initialized. Defaults to 0.0. |
| scale_base_sigma | float | No | Standard deviation of that distribution. Defaults to 1.0. |
| scale_sp | float | No | Initial scale of the spline function. Defaults to 1.0. |
| base_fun | callable | No | KAN residual (base) function. Defaults to torch.nn.SiLU(). |
| grid_eps | float | No | Interpolates grid spacing between percentile-based (0) and uniform (1). Defaults to 0.02. |
| grid_range | list[int] | No | Initial KAN grid range. Defaults to [-1, 1]. |
| sp_trainable | bool | No | Whether scale_sp is trainable. Defaults to True. |
| sb_trainable | bool | No | Whether scale_base is trainable. Defaults to True. |
| sparse_init | bool | No | Whether to use sparse initialization. Defaults to False. |
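grid_eps matters when a spline grid is re-fitted to the values actually flowing through a layer, which is presumably what update_kan_grid does. The following sketch follows the rule in the pykan reference implementation for a single input dimension: an adaptive grid placed at evenly spaced ranks of the sorted samples (percentile-style) is blended with a uniform grid over the same span, grid = grid_eps * uniform + (1 - grid_eps) * adaptive. Whether pytorch-forecasting's KAN layers do exactly this is an assumption, the function name is hypothetical, and the k-knot padding at each end is omitted.

```python
def blended_grid(samples: list, num: int, grid_eps: float) -> list:
    """Return num + 1 knots interpolating between percentile (0) and uniform (1) spacing."""
    if len(samples) < 2 or num < 1:
        raise ValueError("need at least two samples and one interval")
    xs = sorted(samples)
    n = len(xs)
    # Adaptive knots: sample values at evenly spaced ranks, always ending at the maximum.
    ranks = [int(n / num * i) for i in range(num)] + [n - 1]
    adaptive = [xs[r] for r in ranks]
    # Uniform knots spanning the same range as the adaptive ones.
    lo, hi = adaptive[0], adaptive[-1]
    h = (hi - lo) / num
    uniform = [lo + h * i for i in range(num + 1)]
    return [grid_eps * u + (1 - grid_eps) * a for u, a in zip(uniform, adaptive)]


samples = [0, 1, 2, 3, 4, 5, 6, 7, 8, 100]          # one outlier at 100
print(blended_grid(samples, num=5, grid_eps=0.0))   # knots follow the data
print(blended_grid(samples, num=5, grid_eps=1.0))   # evenly spaced knots
print(blended_grid(samples, num=5, grid_eps=0.02))  # the default: mostly data-driven
```

With the outlier, the uniform grid wastes four of five intervals on the empty range between 8 and 100, which is why a small grid_eps that favors data-driven knots is a sensible default.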

Outputs

| Name | Type | Description |
|---|---|---|
| prediction | torch.Tensor | Forecast in the target space. |
| backcast | torch.Tensor | Reconstruction of the context window; it enters the training loss only when backcast_loss_ratio > 0. |
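When backcast_loss_ratio > 0, the backcast and forecast losses are combined into one training loss. As far as we can tell from the N-BEATS base in pytorch-forecasting, the ratio is first rescaled by prediction_length / context_length and then normalized so the two weights sum to 1. The sketch reproduces that arithmetic with plain floats; the function names are hypothetical, and the exact formula should be checked against the installed library version.

```python
def loss_weights(backcast_loss_ratio: float, prediction_length: int, context_length: int):
    """Return (backcast_weight, forecast_weight), which sum to 1."""
    if backcast_loss_ratio <= 0:
        return 0.0, 1.0  # forecast loss only
    w = backcast_loss_ratio * prediction_length / context_length
    w = w / (w + 1)  # normalize into [0, 1)
    return w, 1.0 - w


def combined_loss(backcast_loss: float, forecast_loss: float, backcast_loss_ratio: float,
                  prediction_length: int, context_length: int) -> float:
    """Weighted sum of the two losses using the weights above."""
    bw, fw = loss_weights(backcast_loss_ratio, prediction_length, context_length)
    return bw * backcast_loss + fw * forecast_loss


# With 24 steps ahead, 96 steps back and a ratio of 1.0, the raw backcast weight
# is 1.0 * 24 / 96 = 0.25, which normalizes to 0.25 / 1.25 = 0.2.
print(loss_weights(1.0, 24, 96))
```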

Usage Examples

from pytorch_forecasting.models.nbeats import NBeatsKAN

# Interpretable mode with trend and seasonality stacks
model = NBeatsKAN(
    stack_types=["trend", "seasonality"],
    num_blocks=[3, 3],
    num_block_layers=[3, 3],
    widths=[32, 512],
    expansion_coefficient_lengths=[3, 7],
    prediction_length=24,
    context_length=96,
    dropout=0.1,
    num=5,    # KAN grid intervals
    k=3,      # KAN polynomial order
)
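With stack_types=["trend", "seasonality"] and num_blocks=[3, 3], this configuration chains six blocks. In N-BEATS they are wired by doubly residual stacking: each block sees the input minus the backcasts of all earlier blocks, and the final forecast is the sum of every block's forecast, which is what lets the trend and seasonality stacks be read as separate components. The sketch shows only that wiring, using plain lists and a hypothetical toy block in place of the KAN blocks; it is not pytorch-forecasting's forward pass.

```python
from typing import Callable, List, Tuple

# A block maps its input window to (backcast, forecast).
Block = Callable[[List[float]], Tuple[List[float], List[float]]]


def doubly_residual_forward(x: List[float], blocks: List[Block], prediction_length: int):
    """Chain blocks N-BEATS style; return (forecast, final residual)."""
    residual = list(x)
    forecast = [0.0] * prediction_length
    for block in blocks:
        backcast, block_forecast = block(residual)
        # Remove what this block explained before passing the window on.
        residual = [r - b for r, b in zip(residual, backcast)]
        # Accumulate the partial forecasts of all blocks.
        forecast = [f + bf for f, bf in zip(forecast, block_forecast)]
    return forecast, residual


def make_level_block(prediction_length: int) -> Block:
    """Hypothetical toy block: explains the input by its mean and forecasts that mean."""
    def block(window: List[float]):
        level = sum(window) / len(window)
        return [level] * len(window), [level] * prediction_length
    return block


forecast, residual = doubly_residual_forward([1.0, 2.0, 3.0, 4.0], [make_level_block(2)] * 2, 2)
```

The second toy block contributes nothing here because the first one already removed the level; in a trained model, later blocks pick up what earlier ones left in the residual.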

# Generic mode
model = NBeatsKAN(
    stack_types=["generic"],
    num_blocks=[1],
    num_block_layers=[4],
    widths=[512],
    expansion_coefficient_lengths=[32],
    prediction_length=24,
    context_length=96,
)

# Update KAN grids during training
model.update_kan_grid()
