Implementation:Sktime Pytorch forecasting KANLayer

From Leeroopedia


Knowledge Sources
Domains: Time_Series, Forecasting, Deep_Learning
Last Updated: 2026-02-08 08:00 GMT

Overview

KANLayer implements a single layer of a Kolmogorov-Arnold Network (KAN) using learnable B-spline activation functions combined with a residual base function.

Description

The KANLayer class replaces traditional fixed activation functions with learnable univariate spline functions on each edge of the network. For each input-output pair, it computes a B-spline curve (via coef2curve) scaled by a trainable scale_sp parameter, adds a residual component computed by a base function (default: SiLU) scaled by scale_base, and applies a sparse connection mask. The spline grid can be adaptively updated from input samples using update_grid_from_samples, which interpolates between uniform and percentile-based grid placement controlled by grid_eps. This implementation is inspired by the pykan library.
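The per-edge computation described above can be sketched in plain Python for a single scalar input. This is an illustrative sketch, not the library's code: the helper names `bspline_basis`, `silu`, and `edge_activation` are hypothetical, the knot vector assumes the pykan convention of extending a uniform grid by k knots on each side, and the real layer works on batched tensors.

```python
import math

def bspline_basis(x, knots, k):
    """Evaluate all order-k B-spline basis functions at scalar x (Cox-de Boor)."""
    # Order 0: indicator of each half-open knot interval [t_i, t_{i+1}).
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0
             for i in range(len(knots) - 1)]
    # Raise the order one step at a time; each step yields one fewer function.
    for d in range(1, k + 1):
        nxt = []
        for i in range(len(knots) - d - 1):
            left = (x - knots[i]) / (knots[i + d] - knots[i]) * basis[i]
            right = ((knots[i + d + 1] - x)
                     / (knots[i + d + 1] - knots[i + 1]) * basis[i + 1])
            nxt.append(left + right)
        basis = nxt
    return basis

def silu(x):
    """SiLU residual function b(x) = x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def edge_activation(x, knots, coef, k, scale_base, scale_sp):
    """One edge: scale_base * b(x) + scale_sp * spline(x)."""
    spline = sum(c * b for c, b in zip(coef, bspline_basis(x, knots, k)))
    return scale_base * silu(x) + scale_sp * spline

# Hypothetical setup: num=5 intervals on [-1, 1], cubic splines (k=3).
num, k = 5, 3
lo, hi = -1.0, 1.0
h = (hi - lo) / num
# Assumed convention: extend the uniform grid by k knots on each side.
knots = [lo + h * (i - k) for i in range(num + 2 * k + 1)]
coef = [1.0] * (num + k)  # one coefficient per basis function

# With all coefficients equal to 1, spline(x) is approximately 1.0 inside the grid.
value = edge_activation(0.3, knots, coef, k, scale_base=0.0, scale_sp=1.0)
```

Under these assumptions each edge carries num + k spline coefficients (8 here), while scale_base and scale_sp act as per-edge multipliers on the residual and spline terms.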

Usage

Use KANLayer when building KAN-based architectures for time series forecasting where learnable activation functions are desired over standard fixed activations (ReLU, GELU, etc.). It is particularly useful in scenarios where the function to be learned has complex, non-standard nonlinear structure that benefits from adaptive B-spline representations.

Code Reference

Source Location

Signature

class KANLayer(nn.Module):
    def __init__(
        self,
        in_dim=3,
        out_dim=2,
        num=5,
        k=3,
        noise_scale=0.5,
        scale_base_mu=0.0,
        scale_base_sigma=1.0,
        scale_sp=1.0,
        base_fun=None,
        grid_eps=0.02,
        grid_range=None,
        sp_trainable=True,
        sb_trainable=True,
        sparse_init=False,
    ):
        ...

    def forward(self, x):
        ...

    def update_grid_from_samples(self, x):
        ...

Import

from pytorch_forecasting.layers import KANLayer

I/O Contract

Inputs

__init__ Parameters

| Name | Type | Required | Description |
|---|---|---|---|
| in_dim | int | No | Input dimension. Default: 3. |
| out_dim | int | No | Output dimension. Default: 2. |
| num | int | No | Number of grid intervals (G). Default: 5. |
| k | int | No | Order of the piecewise polynomial spline. Default: 3. |
| noise_scale | float | No | Scale of noise injected at initialization. Default: 0.5. |
| scale_base_mu | float | No | Mean of the initialization distribution for the residual function scale. Default: 0.0. |
| scale_base_sigma | float | No | Standard deviation of the initialization distribution for the residual function scale. Default: 1.0. |
| scale_sp | float | No | Scale applied to the spline function. Default: 1.0. |
| base_fun | callable | No | Residual function b(x). Default: None, which resolves to torch.nn.SiLU(). |
| grid_eps | float | No | Interpolation factor between uniform (1.0) and percentile-based (0.0) grid placement. Default: 0.02. |
| grid_range | list | No | Range of the grid as [min, max]. Default: None, which resolves to [-1, 1]. |
| sp_trainable | bool | No | If True, scale_sp is trainable. Default: True. |
| sb_trainable | bool | No | If True, scale_base is trainable. Default: True. |
| sparse_init | bool | No | If True, sparse initialization is applied via a connection mask. Default: False. |

forward Parameters

| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Input tensor of shape (batch_size, in_dim). |

update_grid_from_samples Parameters

| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Input samples of shape (num_samples, in_dim) used to adaptively update the grid. |

Outputs

forward

| Name | Type | Description |
|---|---|---|
| y | torch.Tensor | Output tensor of shape (batch_size, out_dim), computed as the sum of spline and residual contributions across input dimensions. |
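Written out, the sum over input dimensions plausibly takes the following form. This assumes the layer follows the pykan formulation that the Description refers to. The symbols $m_{ij}$, $s^{\text{base}}_{ij}$, and $s^{\text{sp}}_{ij}$ are notation introduced here for the connection mask, scale_base, and scale_sp on the edge from input $i$ to output $j$.

```latex
y_j \;=\; \sum_{i=1}^{\text{in\_dim}} m_{ij}
      \left( s^{\text{base}}_{ij}\, b(x_i) \;+\; s^{\text{sp}}_{ij}\, \mathrm{spline}_{ij}(x_i) \right),
\qquad j = 1, \dots, \text{out\_dim}
```

Here $b$ is the residual base function (SiLU by default) and $\mathrm{spline}_{ij}$ is the learnable B-spline on that edge.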

update_grid_from_samples

| Name | Type | Description |
|---|---|---|
| None | None | Updates grid and coefficients in-place; no return value. |
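The role of grid_eps in the grid update can be illustrated with a small stdlib-only sketch for a single input dimension. It mirrors the blending rule used in pykan (grid_eps * uniform + (1 - grid_eps) * percentile-based) and is only an approximation of what update_grid_from_samples does: the function name `blended_grid` is hypothetical, and the extension of the grid by k knots and the refitting of spline coefficients are omitted.

```python
def blended_grid(samples, num, grid_eps):
    """Blend a uniform grid with a percentile-based grid for one input dimension."""
    xs = sorted(samples)
    n = len(xs)
    # Percentile-based knots: every (n / num)-th sorted sample, plus the maximum.
    adaptive = [xs[int(n / num * i)] for i in range(num)] + [xs[-1]]
    # Uniform knots spanning the same [min, max] range.
    step = (adaptive[-1] - adaptive[0]) / num
    uniform = [adaptive[0] + step * i for i in range(num + 1)]
    # grid_eps = 1.0 gives the uniform grid, grid_eps = 0.0 the percentile grid.
    return [grid_eps * u + (1.0 - grid_eps) * a
            for u, a in zip(uniform, adaptive)]

# Hypothetical samples clustered around zero (cubes of evenly spaced points).
samples = [((i - 50) / 50) ** 3 for i in range(101)]

g_uniform = blended_grid(samples, num=5, grid_eps=1.0)
g_adaptive = blended_grid(samples, num=5, grid_eps=0.0)
g_default = blended_grid(samples, num=5, grid_eps=0.02)
```

With the default grid_eps of 0.02, the knots sit close to the sample percentiles, so regions with many samples receive finer spline resolution; the small uniform component keeps knots from collapsing onto each other when samples are heavily concentrated.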

Usage Examples

import torch
from pytorch_forecasting.layers import KANLayer

# Create a KAN layer: 3 inputs -> 5 outputs, 5 grid intervals, order-3 splines
kan_layer = KANLayer(in_dim=3, out_dim=5, num=5, k=3)

# Forward pass with a batch of 100 samples
x = torch.normal(0, 1, size=(100, 3))
y = kan_layer(x)
print(y.shape)  # torch.Size([100, 5])

# Adaptively update the grid based on training data
train_data = torch.linspace(-3, 3, steps=500).unsqueeze(1).expand(-1, 3)
kan_layer.update_grid_from_samples(train_data)

# Forward pass after grid update
y_updated = kan_layer(x)
print(y_updated.shape)  # torch.Size([100, 5])
