Overview
KANLayer implements a single layer of a Kolmogorov-Arnold Network (KAN) using learnable B-spline activation functions combined with a residual base function.
Description
The KANLayer class replaces fixed activation functions with learnable univariate spline functions, one on each edge of the network. For each input-output pair, it evaluates a B-spline curve (via coef2curve) scaled by the trainable scale_sp parameter, adds a residual term computed by a base function (default: SiLU) scaled by scale_base, and multiplies the result by a connection mask (all ones unless sparse_init=True). Each output is the sum of these edge terms over all input dimensions. The spline grid can be adapted to the data with update_grid_from_samples, which places grid points by interpolating between uniform and percentile-based positions according to grid_eps. This implementation is inspired by the pykan library.
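Written out for a single sample x = (x_1, ..., x_in_dim), the computation described above can be summarized as follows. This is a sketch of the pykan-style formulation, not a transcription of the pytorch_forecasting source: m, s^b, s^sp and c are notation introduced here for the mask, scale_base, scale_sp and the spline coefficients; B^(i)_l are the order-k B-spline basis functions on the grid of input i; G = num; and the count of G + k basis functions per edge assumes pykan's extension of the grid by k knots on each side.

$$
y_j = \sum_{i=1}^{n_{\mathrm{in}}} m_{ij} \Big( s^{\mathrm{b}}_{ij}\, b(x_i) + s^{\mathrm{sp}}_{ij} \sum_{l=1}^{G+k} c_{ijl}\, B^{(i)}_{l}(x_i) \Big), \qquad j = 1, \dots, n_{\mathrm{out}}
$$

Here n_in = in_dim, n_out = out_dim, and b is base_fun.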
Usage
Use KANLayer when building KAN-based architectures for time series forecasting in which learnable activation functions are preferred over standard fixed activations (ReLU, GELU, etc.). It is most likely to help when the function to be learned has a nonstandard nonlinear shape that a fixed activation approximates poorly but an adaptive B-spline basis can represent directly.
Code Reference
Source Location
Signature
class KANLayer(nn.Module):
    def __init__(
        self,
        in_dim=3,
        out_dim=2,
        num=5,
        k=3,
        noise_scale=0.5,
        scale_base_mu=0.0,
        scale_base_sigma=1.0,
        scale_sp=1.0,
        base_fun=None,
        grid_eps=0.02,
        grid_range=None,
        sp_trainable=True,
        sb_trainable=True,
        sparse_init=False,
    ):
        ...

    def forward(self, x):
        ...

    def update_grid_from_samples(self, x):
        ...
Import
from pytorch_forecasting.layers import KANLayer
I/O Contract
Inputs
__init__ Parameters
| Name | Type | Required | Description |
| in_dim | int | No | Input dimension. Default: 3. |
| out_dim | int | No | Output dimension. Default: 2. |
| num | int | No | Number of grid intervals (G). Default: 5. |
| k | int | No | Piecewise polynomial order of the B-spline (k=3 gives cubic splines). Default: 3. |
| noise_scale | float | No | Scale of the random noise used to initialize the spline functions. Default: 0.5. |
| scale_base_mu | float | No | Mean of the initialization distribution for the residual function scale (scale_base). Default: 0.0. |
| scale_base_sigma | float | No | Standard deviation of the initialization distribution for the residual function scale (scale_base). Default: 1.0. |
| scale_sp | float | No | Initial scale of the spline term spline(x). Default: 1.0. |
| base_fun | callable or None | No | Residual function b(x). If None, torch.nn.SiLU() is used. |
| grid_eps | float | No | Interpolation factor between uniform (1.0) and percentile-based (0.0) grid placement. Default: 0.02. |
| grid_range | list or None | No | Range of the grid as [min, max]. If None, [-1, 1] is used. |
| sp_trainable | bool | No | If True, scale_sp is trainable. Default: True. |
| sb_trainable | bool | No | If True, scale_base is trainable. Default: True. |
| sparse_init | bool | No | If True, a sparse connection mask is applied at initialization. Default: False. |
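To make the roles of num, k and grid_range concrete, the sketch below computes the knot layout under an assumed pykan-style convention: num + 1 evenly spaced points over grid_range, extended by k extra knots beyond each end, which yields num + k B-spline basis functions (and coefficients) per edge. grid_layout is a hypothetical helper for illustration, not part of pytorch_forecasting.

```python
def grid_layout(num=5, k=3, grid_range=(-1.0, 1.0)):
    """Knot positions and basis count for a uniform grid extended by k knots per side.

    Assumes the pykan-style layout; this helper is hypothetical.
    """
    lo, hi = grid_range
    h = (hi - lo) / num  # width of one grid interval
    # num + 1 interior points plus k extra knots on each side
    knots = [lo + (i - k) * h for i in range(num + 2 * k + 1)]
    n_basis = len(knots) - 1 - k  # equals num + k coefficients per edge
    return knots, n_basis


knots, n_basis = grid_layout(num=5, k=3)
print(len(knots), n_basis)  # 12 8
print([round(t, 2) for t in knots])
# [-2.2, -1.8, -1.4, -1.0, -0.6, -0.2, 0.2, 0.6, 1.0, 1.4, 1.8, 2.2]
```

With the defaults, each edge would therefore carry 8 spline coefficients, and the basis sums to one only on the original grid_range [-1, 1]; inputs far outside that range are one motivation for calling update_grid_from_samples.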
forward Parameters
| Name | Type | Required | Description |
| x | torch.Tensor | Yes | Input tensor of shape (batch_size, in_dim). |
update_grid_from_samples Parameters
| Name | Type | Required | Description |
| x | torch.Tensor | Yes | Input samples of shape (num_samples, in_dim) used to adaptively update the grid. |
Outputs
forward
| Name | Type | Description |
| y | torch.Tensor | Output tensor of shape (batch_size, out_dim), computed as the sum of spline and residual contributions across input dimensions. |
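The forward computation can be sketched in plain PyTorch. This is a minimal re-implementation for illustration, not the pytorch_forecasting source: make_grid, bspline_basis and kan_forward are hypothetical helpers, the basis is evaluated with the Cox-de Boor recursion, and the grid layout (k extra knots beyond each end of grid_range, num + k coefficients per edge) is assumed to follow the pykan convention.

```python
import torch
import torch.nn.functional as F


def make_grid(in_dim, num, k, grid_range=(-1.0, 1.0)):
    """Uniform knots over grid_range, extended by k extra knots on each side (assumed layout)."""
    lo, hi = grid_range
    h = (hi - lo) / num
    knots = torch.linspace(lo - k * h, hi + k * h, steps=num + 2 * k + 1)
    return knots.unsqueeze(0).repeat(in_dim, 1)  # (in_dim, num + 2k + 1)


def bspline_basis(x, grid, k):
    """Cox-de Boor recursion; returns (batch, in_dim, n_knots - 1 - k) basis values."""
    x = x.unsqueeze(-1)    # (batch, in_dim, 1)
    g = grid.unsqueeze(0)  # (1, in_dim, n_knots)
    # Order 0: indicator of the half-open knot interval [g_j, g_{j+1}).
    basis = ((x >= g[..., :-1]) & (x < g[..., 1:])).to(x.dtype)
    for p in range(1, k + 1):
        left = (x - g[..., : -(p + 1)]) / (g[..., p:-1] - g[..., : -(p + 1)])
        right = (g[..., p + 1 :] - x) / (g[..., p + 1 :] - g[..., 1:-p])
        basis = left * basis[..., :-1] + right * basis[..., 1:]
    return basis


def kan_forward(x, grid, coef, scale_base, scale_sp, mask, k, base_fun=F.silu):
    """Sum over inputs of mask * (scale_base * b(x) + scale_sp * spline(x))."""
    basis = bspline_basis(x, grid, k)                   # (batch, in_dim, n_basis)
    spline = torch.einsum("bin,ion->bio", basis, coef)  # (batch, in_dim, out_dim)
    base = base_fun(x).unsqueeze(-1)                    # (batch, in_dim, 1)
    y = mask * (scale_base * base + scale_sp * spline)  # (batch, in_dim, out_dim)
    return y.sum(dim=1)                                 # (batch, out_dim)


torch.manual_seed(0)
in_dim, out_dim, num, k = 3, 5, 5, 3
grid = make_grid(in_dim, num, k)
coef = 0.1 * torch.randn(in_dim, out_dim, num + k)  # one coefficient per basis function
scale_base = torch.ones(in_dim, out_dim)
scale_sp = torch.ones(in_dim, out_dim)
mask = torch.ones(in_dim, out_dim)  # dense connections, as without sparse_init

x = torch.rand(100, in_dim) * 1.8 - 0.9  # samples well inside grid_range [-1, 1]
y = kan_forward(x, grid, coef, scale_base, scale_sp, mask, k)
print(y.shape)  # torch.Size([100, 5])
```

Setting scale_sp to zeros reduces every output to the sum of the scaled residual terms, which is a quick way to check the wiring. In KANLayer itself these tensors are module state configured through the __init__ arguments rather than function inputs.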
update_grid_from_samples
| Name | Type | Description |
| (none) | None | Updates the grid and coefficients in place; no return value. |
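The grid_eps blend used when updating the grid can be illustrated with a short sketch modeled on pykan's grid update: percentile-based knots are read off the sorted samples at evenly spaced ranks, uniform knots span the same [min, max] range, and the two are mixed as grid_eps * uniform + (1 - grid_eps) * percentile. blended_grid is a hypothetical helper; it covers knot placement only and omits the extension by k knots and the coefficient refit that update_grid_from_samples also performs.

```python
import torch


def blended_grid(x, num, grid_eps):
    """Mix uniform and percentile-based knots (illustrative, pykan-style sketch).

    x: (num_samples, in_dim) samples; returns (in_dim, num + 1) knots per input.
    """
    x_sorted = torch.sort(x, dim=0).values
    n = x_sorted.shape[0]
    # Percentile-based knots: sample values at evenly spaced ranks.
    ids = [int(n / num * i) for i in range(num)] + [n - 1]
    grid_adaptive = x_sorted[ids, :].T                     # (in_dim, num + 1)
    # Uniform knots spanning the same [min, max] range per input.
    lo, hi = grid_adaptive[:, :1], grid_adaptive[:, -1:]
    steps = torch.arange(num + 1, dtype=x.dtype) / num
    grid_uniform = lo + (hi - lo) * steps                  # (in_dim, num + 1)
    # grid_eps = 1.0 -> uniform, grid_eps = 0.0 -> percentile-based.
    return grid_eps * grid_uniform + (1 - grid_eps) * grid_adaptive


torch.manual_seed(0)
samples = torch.randn(1000, 2) ** 3  # heavy-tailed: most samples lie near 0
uniform = blended_grid(samples, num=5, grid_eps=1.0)
adaptive = blended_grid(samples, num=5, grid_eps=0.0)
default = blended_grid(samples, num=5, grid_eps=0.02)
print(uniform)
print(adaptive)
```

With the default grid_eps=0.02 the blend is 98% percentile-based, so knots concentrate where the samples are dense; larger values pull the knots toward even spacing.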
Usage Examples
import torch
from pytorch_forecasting.layers import KANLayer
# Create a KAN layer: 3 inputs -> 5 outputs, 5 grid intervals, order-3 splines
kan_layer = KANLayer(in_dim=3, out_dim=5, num=5, k=3)
# Forward pass with a batch of 100 samples
x = torch.normal(0, 1, size=(100, 3))
y = kan_layer(x)
print(y.shape) # torch.Size([100, 5])
# Adaptively update the grid based on training data
train_data = torch.linspace(-3, 3, steps=500).unsqueeze(1).expand(-1, 3)
kan_layer.update_grid_from_samples(train_data)
# Forward pass after grid update
y_updated = kan_layer(x)
print(y_updated.shape) # torch.Size([100, 5])