Implementation:Sktime Pytorch forecasting NBeatsKAN
| Knowledge Sources | |
|---|---|
| Domains | Time_Series, Forecasting, Deep_Learning |
| Last Updated | 2026-02-08 08:00 GMT |
Overview
NBeatsKAN is an N-BEATS variant that replaces the standard MLP layers in each block with Kolmogorov-Arnold Network (KAN) layers, with the aim of improving accuracy and interpretability.
Description
NBeatsKAN extends NBeatsAdapter and implements the N-BEATS architecture with KAN-augmented blocks, motivated by the Kolmogorov-Arnold representation theorem. Where an MLP applies fixed activation functions to learned linear combinations, a KAN layer places a learnable univariate spline function on each edge, which is intended to capture complex temporal patterns with fewer parameters. The model supports three stack types, "generic", "trend", and "seasonality", implemented by NBEATSGenericBlockKAN, NBEATSTrendBlockKAN, and NBEATSSeasonalBlockKAN respectively. It also provides an update_kan_grid method for updating the KAN spline grids during training.
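To make the "learnable univariate spline on each edge" idea concrete, here is a minimal pure-Python sketch of a single KAN-style edge function: a SiLU base term plus a weighted sum of B-spline basis functions on a uniform grid. The helper names (`make_knots`, `bspline_basis`, `kan_edge`) are hypothetical and are not part of pytorch_forecasting; the parameters `num`, `k`, `scale_sp`, and `grid_range` only mirror the constructor arguments by name, and `scale_base` is assumed here to be a single scalar. The library's actual layers are tensor-based and may differ in detail.

```python
import math


def silu(x: float) -> float:
    """SiLU base function: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def make_knots(grid_range: tuple[float, float], num: int, k: int) -> list[float]:
    """Uniform knot vector: `num` intervals on grid_range, extended by k knots per side."""
    lo, hi = grid_range
    h = (hi - lo) / num
    return [lo + (i - k) * h for i in range(num + 2 * k + 1)]


def bspline_basis(x: float, knots: list[float], k: int) -> list[float]:
    """Evaluate all order-k B-spline basis functions at x (Cox-de Boor recursion)."""
    # Order 0: indicator of each half-open knot interval.
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(len(knots) - 1)]
    for d in range(1, k + 1):
        basis = [
            (x - knots[i]) / (knots[i + d] - knots[i]) * basis[i]
            + (knots[i + d + 1] - x) / (knots[i + d + 1] - knots[i + 1]) * basis[i + 1]
            for i in range(len(knots) - d - 1)
        ]
    return basis  # num + k values


def kan_edge(
    x: float,
    coef: list[float],
    knots: list[float],
    k: int,
    scale_base: float = 1.0,
    scale_sp: float = 1.0,
) -> float:
    """Edge function: scale_base * silu(x) + scale_sp * sum_i coef[i] * B_i(x)."""
    spline = sum(c * b for c, b in zip(coef, bspline_basis(x, knots, k)))
    return scale_base * silu(x) + scale_sp * spline


# Illustrative use with the documented defaults num=5, k=3, grid_range=[-1, 1].
knots = make_knots((-1.0, 1.0), num=5, k=3)
coef = [0.1 * i for i in range(5 + 3)]  # these would be the learnable parameters
value = kan_edge(0.3, coef, knots, k=3)
```

In this sketch the spline coefficients play the role that scalar weights play in an MLP: training adjusts the shape of each edge function rather than a single multiplier.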
Usage
Use NBeatsKAN when you want the interpretable structure of N-BEATS (trend/seasonality decomposition) combined with the enhanced expressiveness of KAN layers. This model is particularly useful for univariate time series forecasting where interpretability and parameter efficiency are important, or for zero-shot cross-domain transfer learning scenarios.
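The trend/seasonality decomposition mentioned above follows the original N-BEATS formulation, where a trend block projects its coefficients onto a low-degree polynomial basis and a seasonality block projects onto a Fourier basis. The sketch below illustrates those two bases in plain Python. The function names are hypothetical, and the exact basis construction inside NBEATSTrendBlockKAN and NBEATSSeasonalBlockKAN is not shown on this page, so treat this as an illustration of the idea rather than the library's code.

```python
import math


def trend_forecast(theta: list[float], horizon: int) -> list[float]:
    """Polynomial trend: sum_i theta[i] * t**i on t = 0, 1/H, ..., (H-1)/H."""
    return [
        sum(c * (step / horizon) ** i for i, c in enumerate(theta))
        for step in range(horizon)
    ]


def seasonality_forecast(
    cos_coef: list[float], sin_coef: list[float], horizon: int
) -> list[float]:
    """Fourier seasonality: sum_i a_i*cos(2*pi*i*t) + b_i*sin(2*pi*i*t)."""
    out = []
    for step in range(horizon):
        t = step / horizon
        value = sum(a * math.cos(2 * math.pi * i * t) for i, a in enumerate(cos_coef))
        value += sum(b * math.sin(2 * math.pi * i * t) for i, b in enumerate(sin_coef))
        out.append(value)
    return out


# A linear trend (intercept 1, slope 2) over a 4-step horizon.
trend = trend_forecast([1.0, 2.0], horizon=4)
# One full cosine cycle over the same horizon.
season = seasonality_forecast([0.0, 1.0], [0.0, 0.0], horizon=4)
```

Because each stack's output is a combination of a few named basis functions, the contribution of the trend and seasonality stacks can be inspected separately, which is where the interpretability comes from.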
Code Reference
Source Location
- Repository: Sktime_Pytorch_forecasting
- File: pytorch_forecasting/models/nbeats/_nbeatskan.py
- Lines: 1-287
Signature
class NBeatsKAN(NBeatsAdapter):
    def __init__(
        self,
        stack_types: list[str] | None = None,
        num_blocks: list[int] | None = None,
        num_block_layers: list[int] | None = None,
        widths: list[int] | None = None,
        sharing: list[bool] | None = None,
        expansion_coefficient_lengths: list[int] | None = None,
        prediction_length: int = 1,
        context_length: int = 1,
        dropout: float = 0.1,
        learning_rate: float = 1e-2,
        log_interval: int = -1,
        log_gradient_flow: bool = False,
        log_val_interval: int = None,
        weight_decay: float = 1e-3,
        loss: MultiHorizonMetric = None,
        reduce_on_plateau_patience: int = 1000,
        backcast_loss_ratio: float = 0.0,
        logging_metrics: nn.ModuleList = None,
        num: int = 5,
        k: int = 3,
        noise_scale: float = 0.5,
        scale_base_mu: float = 0.0,
        scale_base_sigma: float = 1.0,
        scale_sp: float = 1.0,
        base_fun: callable = None,
        grid_eps: float = 0.02,
        grid_range: list[int] = None,
        sp_trainable: bool = True,
        sb_trainable: bool = True,
        sparse_init: bool = False,
        **kwargs,
    ):
Import
from pytorch_forecasting.models.nbeats import NBeatsKAN
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| stack_types | list[str] or None | No | Stack types, one entry per stack: "generic", "trend", or "seasonality". Defaults to ["trend", "seasonality"]. |
| num_blocks | list[int] or None | No | Number of blocks per stack. Defaults to [3, 3]. |
| num_block_layers | list[int] or None | No | Number of fully connected (FC) layers per block, one entry per stack. Defaults to [3, 3]. |
| widths | list[int] or None | No | FC layer widths per stack. Defaults to [32, 512]. |
| sharing | list[bool] or None | No | Whether weights are shared across blocks, per stack. Defaults to [True, True]. |
| expansion_coefficient_lengths | list[int] or None | No | Expansion coefficient length or polynomial degree per stack. Defaults to [3, 7]. |
| prediction_length | int | No | Forecast horizon. Defaults to 1. |
| context_length | int | No | Lookback period. Defaults to 1. |
| dropout | float | No | Dropout rate. Defaults to 0.1. |
| loss | MultiHorizonMetric | No | Loss function. Defaults to MASE(). |
| backcast_loss_ratio | float | No | Weight of backcast loss. Defaults to 0.0. |
| num | int | No | KAN grid intervals (G). Defaults to 5. |
| k | int | No | KAN piecewise polynomial order. Defaults to 3. |
| noise_scale | float | No | KAN initialization noise scale. Defaults to 0.5. |
| scale_base_mu | float | No | KAN residual function mean. Defaults to 0.0. |
| scale_base_sigma | float | No | KAN residual function sigma. Defaults to 1.0. |
| scale_sp | float | No | KAN spline base function scale. Defaults to 1.0. |
| base_fun | callable | No | KAN residual function. Defaults to torch.nn.SiLU(). |
| grid_eps | float | No | KAN grid interpolation parameter (0=percentile, 1=uniform). Defaults to 0.02. |
| grid_range | list[int] | No | KAN grid range. Defaults to [-1, 1]. |
| sp_trainable | bool | No | Whether scale_sp is trainable. Defaults to True. |
| sb_trainable | bool | No | Whether scale_base is trainable. Defaults to True. |
| sparse_init | bool | No | Whether to use sparse initialization. Defaults to False. |
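The grid_eps entry describes an interpolation between a percentile-based grid (0) and a uniform grid (1). One common way to realize this, assumed here and not confirmed by this page, is a linear blend of the two grids computed from a batch of observed activations. The function names below are hypothetical, and the sketch works on plain Python lists rather than tensors.

```python
def uniform_grid(samples: list[float], num: int) -> list[float]:
    """num + 1 equally spaced grid points between the sample minimum and maximum."""
    lo, hi = min(samples), max(samples)
    return [lo + (hi - lo) * i / num for i in range(num + 1)]


def percentile_grid(samples: list[float], num: int) -> list[float]:
    """num + 1 grid points taken at evenly spaced ranks of the sorted samples."""
    ordered = sorted(samples)
    last = len(ordered) - 1
    return [ordered[(last * i) // num] for i in range(num + 1)]


def blended_grid(samples: list[float], num: int, grid_eps: float) -> list[float]:
    """Assumed blend: grid_eps * uniform + (1 - grid_eps) * percentile."""
    uniform = uniform_grid(samples, num)
    adaptive = percentile_grid(samples, num)
    return [grid_eps * u + (1.0 - grid_eps) * a for u, a in zip(uniform, adaptive)]


# Skewed activations: most mass near zero, one large value.
samples = [0.0, 0.1, 0.2, 0.3, 10.0]
mostly_adaptive = blended_grid(samples, num=2, grid_eps=0.02)
```

With the documented default of 0.02 the grid follows the data distribution closely, so more grid points land where activations are dense; a small uniform component keeps grid points from collapsing onto each other.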
Outputs
| Name | Type | Description |
|---|---|---|
| prediction | torch.Tensor | Forecast output in target space. |
| backcast | torch.Tensor | Backcast reconstruction (when backcast_loss_ratio > 0). |
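When backcast_loss_ratio is greater than 0, the backcast reconstruction contributes to the training loss alongside the forecast. The sketch below shows one plausible way to turn the ratio into normalized weights, scaling by the relative lengths of the forecast and lookback windows. This weighting scheme and the function names are assumptions made for illustration; this page does not state the exact formula, so check the library source before relying on it.

```python
def loss_weights(
    backcast_loss_ratio: float, prediction_length: int, context_length: int
) -> tuple[float, float]:
    """Return (backcast_weight, forecast_weight); the two weights sum to 1."""
    # Assumed scaling: account for the differing lengths of the two windows.
    raw = backcast_loss_ratio * prediction_length / context_length
    backcast_weight = raw / (raw + 1.0)
    return backcast_weight, 1.0 - backcast_weight


def combined_loss(
    forecast_loss: float,
    backcast_loss: float,
    backcast_loss_ratio: float,
    prediction_length: int,
    context_length: int,
) -> float:
    """Weighted sum of forecast and backcast losses under the assumed scheme."""
    w_back, w_fore = loss_weights(backcast_loss_ratio, prediction_length, context_length)
    return w_back * backcast_loss + w_fore * forecast_loss


# With the default ratio of 0.0 only the forecast loss counts.
default_weights = loss_weights(0.0, prediction_length=24, context_length=96)
```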
Usage Examples
from pytorch_forecasting.models.nbeats import NBeatsKAN
# Interpretable mode with trend and seasonality stacks
model = NBeatsKAN(
    stack_types=["trend", "seasonality"],
    num_blocks=[3, 3],
    num_block_layers=[3, 3],
    widths=[32, 512],
    expansion_coefficient_lengths=[3, 7],
    prediction_length=24,
    context_length=96,
    dropout=0.1,
    num=5,  # KAN grid intervals
    k=3,  # KAN polynomial order
)
# Generic mode
model = NBeatsKAN(
    stack_types=["generic"],
    num_blocks=[1],
    num_block_layers=[4],
    widths=[512],
    expansion_coefficient_lengths=[32],
    prediction_length=24,
    context_length=96,
)
# Update KAN grids during training
model.update_kan_grid()
Related Pages
- Principle:Sktime_Pytorch_forecasting_KAN_Architecture
- Sktime_Pytorch_forecasting_NHiTS - Related hierarchical interpolation model, reported by its authors to outperform standard N-BEATS on long-horizon forecasting