
Implementation:Sktime Pytorch forecasting KAN Utils

From Leeroopedia


Knowledge Sources
Domains: Time_Series, Forecasting, Deep_Learning
Last Updated: 2026-02-08 08:00 GMT

Overview

KAN_Utils provides a collection of B-spline utility functions used by KANLayer for evaluating spline bases, converting between coefficients and curves, extending grids, and generating sparse connection masks.

Description

This module contains five core utility functions that support the Kolmogorov-Arnold Network (KAN) layer implementation. b_batch recursively evaluates B-spline basis functions of arbitrary order using the Cox-de Boor recursion formula. coef2curve converts B-spline coefficients to curve values by an Einstein summation of the coefficients against the basis values. curve2coef estimates spline coefficients from sample data using batched least squares. extend_grid pads a grid tensor with equally spaced points on both ends, which order-k B-splines need near the boundaries of the original grid. sparse_mask generates a binary connection mask between input and output units based on nearest-neighbor spatial proximity.
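For reference, these are the textbook Cox-de Boor recursion and the coefficient-to-curve sum, written in our own notation: t_0 < t_1 < ... < t_{n-1} is one row of grid (the knots for input j), c holds the coefficients, and terms with a zero denominator (repeated knots) are taken as 0. The module's exact indexing may differ; this only shows the standard math the description refers to.

```latex
B_{i,0}(x) =
\begin{cases}
1, & t_i \le x < t_{i+1} \\
0, & \text{otherwise}
\end{cases}

B_{i,k}(x) = \frac{x - t_i}{t_{i+k} - t_i}\, B_{i,k-1}(x)
           + \frac{t_{i+k+1} - x}{t_{i+k+1} - t_{i+1}}\, B_{i+1,k-1}(x)

y_{b,j,l} = \sum_{m=0}^{n-k-2} c_{j,l,m}\, B_{m,k}\!\left(x_{b,j}\right)
```

An order-k basis on n knots has n - k - 1 functions, so a grid with G intervals extended by k points on each side (n = G + 2k + 1) yields the G + k coefficients per (input, output) pair that appear in the tables below.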

Usage

These utility functions are used internally by KANLayer. Call them directly only when implementing a custom KAN architecture or when you need to manipulate B-spline grids and coefficients by hand; high-level model construction does not normally require them.

Code Reference

Source Location

Signatures

b_batch

def b_batch(x, grid, k=0):
    ...

coef2curve

def coef2curve(x_eval, grid, coef, k):
    ...

curve2coef

def curve2coef(x_eval, y_eval, grid, k):
    ...

extend_grid

def extend_grid(grid, k_extend=0):
    ...

sparse_mask

def sparse_mask(in_dim, out_dim):
    ...

Import

from pytorch_forecasting.layers._kan._utils import (
    b_batch,
    coef2curve,
    curve2coef,
    extend_grid,
    sparse_mask,
)

I/O Contract

b_batch

Inputs

Name Type Required Description
x torch.Tensor Yes 2D tensor of inputs, shape (batch, in_dim).
grid torch.Tensor Yes 2D tensor of grids, shape (in_dim, number of grid points).
k int No Piecewise polynomial order of splines. Default: 0.

Outputs

Name Type Description
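value torch.Tensor 3D tensor of B-spline basis values, shape (batch, in_dim, G+k), where G is the number of grid intervals and k the spline order (G+k equals the number of grid points minus k minus 1).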
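The shape contract above can be illustrated with a self-contained PyTorch sketch of the Cox-de Boor recursion. This is not the library's code: b_batch_sketch is a hypothetical name, it assumes x is (batch, in_dim) and each grid row is sorted ascending, and it builds the bases bottom-up in a loop instead of recursing.

```python
import torch


def b_batch_sketch(x: torch.Tensor, grid: torch.Tensor, k: int = 0) -> torch.Tensor:
    """Hypothetical Cox-de Boor evaluation (not pytorch_forecasting's b_batch).

    x:    (batch, in_dim) evaluation points
    grid: (in_dim, n_points) ascending knots per input dimension
    returns (batch, in_dim, n_points - k - 1) basis values
    """
    x = x.unsqueeze(2)     # (batch, in_dim, 1)
    t = grid.unsqueeze(0)  # (1, in_dim, n_points)
    # Order 0: indicator of the half-open interval [t_i, t_{i+1})
    value = ((x >= t[:, :, :-1]) & (x < t[:, :, 1:])).to(x.dtype)
    for p in range(1, k + 1):
        # value holds order p-1 bases; each step yields one fewer basis
        left = (x - t[:, :, : -(p + 1)]) / (t[:, :, p:-1] - t[:, :, : -(p + 1)])
        right = (t[:, :, p + 1 :] - x) / (t[:, :, p + 1 :] - t[:, :, 1:-p])
        # Clean each term separately: a 0/0 from a repeated knot becomes 0
        # without discarding the valid neighboring term
        value = torch.nan_to_num(left * value[:, :, :-1]) + torch.nan_to_num(
            right * value[:, :, 1:]
        )
    return value


x = torch.rand(100, 2)
grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11)
print(b_batch_sketch(x, grid, k=3).shape)  # torch.Size([100, 2, 7])
```

On the knot interval [t_k, t_{n-k-1}) the order-k bases sum to 1 (partition of unity), which is a quick sanity check for any implementation.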

coef2curve

Inputs

Name Type Required Description
x_eval torch.Tensor Yes 2D tensor of shape (batch, in_dim).
grid torch.Tensor Yes 2D tensor of shape (in_dim, G+2k+1): G intervals extended by k points on each side.
coef torch.Tensor Yes 3D tensor of shape (in_dim, out_dim, G+k).
k int Yes Piecewise polynomial order of splines.

Outputs

Name Type Description
y_eval torch.Tensor 3D tensor of shape (batch, in_dim, out_dim).
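coef2curve is described as an Einstein summation over the basis. Given basis values of shape (batch, in_dim, G+k), for example from b_batch, the contraction can be sketched as below. coef2curve_from_basis is a hypothetical helper that takes precomputed basis values so the block runs on its own; it is not the library's function.

```python
import torch


def coef2curve_from_basis(basis: torch.Tensor, coef: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: combine precomputed basis values with coefficients.

    basis: (batch, in_dim, n_coef)   B-spline basis values, n_coef = G + k
    coef:  (in_dim, out_dim, n_coef) one coefficient vector per (input, output) pair
    returns (batch, in_dim, out_dim)
    """
    # y[b, i, o] = sum over m of basis[b, i, m] * coef[i, o, m]
    return torch.einsum("ijk,jlk->ijl", basis, coef)


basis = torch.rand(8, 2, 7)
coef = torch.randn(2, 3, 7)
print(coef2curve_from_basis(basis, coef).shape)  # torch.Size([8, 2, 3])
```

The einsum avoids materializing the (batch, in_dim, out_dim, G+k) product that an explicit broadcast-and-sum would create.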

curve2coef

Inputs

Name Type Required Description
x_eval torch.Tensor Yes 2D tensor of shape (batch, in_dim).
y_eval torch.Tensor Yes 3D tensor of shape (batch, in_dim, out_dim).
grid torch.Tensor Yes 2D tensor of shape (in_dim, G+2k+1): G intervals extended by k points on each side.
k int Yes Piecewise polynomial order of splines.

Outputs

Name Type Description
coef torch.Tensor 3D tensor of shape (in_dim, out_dim, G+k).
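curve2coef solves the inverse problem: for every (input, output) pair it finds the coefficients whose spline best fits the samples in the least-squares sense. A minimal sketch with torch.linalg.lstsq, again starting from precomputed basis values; curve2coef_from_basis is a hypothetical name and the layout of the batched solve is our assumption, not the library's code.

```python
import torch


def curve2coef_from_basis(basis: torch.Tensor, y_eval: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: least-squares fit of spline coefficients.

    basis:  (batch, in_dim, n_coef)  basis values at the sample points
    y_eval: (batch, in_dim, out_dim) target curve values
    returns (in_dim, out_dim, n_coef)
    """
    batch, in_dim, n_coef = basis.shape
    out_dim = y_eval.shape[2]
    # One design matrix per input dim, shared by all out_dim targets:
    # (in_dim, out_dim, batch, n_coef)
    A = basis.permute(1, 0, 2)[:, None].expand(in_dim, out_dim, batch, n_coef)
    # Right-hand sides as column vectors: (in_dim, out_dim, batch, 1)
    b = y_eval.permute(1, 2, 0).unsqueeze(-1)
    # Batched least squares over the leading (in_dim, out_dim) dimensions
    solution = torch.linalg.lstsq(A.contiguous(), b).solution
    return solution[..., 0]


basis = torch.randn(50, 2, 7, dtype=torch.float64)
coef = torch.randn(2, 3, 7, dtype=torch.float64)
y = torch.einsum("ijk,jlk->ijl", basis, coef)
print(torch.allclose(curve2coef_from_basis(basis, y), coef))  # True
```

The fit is only well posed when each input dimension has at least G+k samples that cover the support of every basis function; otherwise the design matrix is rank deficient and the coefficients are not unique.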

extend_grid

Inputs

Name Type Required Description
grid torch.Tensor Yes Grid tensor of shape (in_dim, grid_points).
k_extend int No Number of points to extend on both ends. Default: 0.

Outputs

Name Type Description
grid torch.Tensor Extended grid of shape (in_dim, grid_points + 2 * k_extend).
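A vectorized sketch of the padding described above. It assumes the step h is the average spacing of each row, (last - first) / (grid_points - 1), which matches "equally spaced" for a uniform grid; extend_grid_sketch is a hypothetical name, not the library's function.

```python
import torch


def extend_grid_sketch(grid: torch.Tensor, k_extend: int = 0) -> torch.Tensor:
    """Hypothetical: pad each row with k_extend equally spaced points per side.

    grid: (in_dim, grid_points), ascending
    returns (in_dim, grid_points + 2 * k_extend)
    """
    # Assumed spacing: average step of each row, shape (in_dim, 1)
    h = (grid[:, -1:] - grid[:, :1]) / (grid.shape[1] - 1)
    steps = torch.arange(1, k_extend + 1, dtype=grid.dtype, device=grid.device)
    # Left padding in ascending order: first - k*h, ..., first - h
    left = grid[:, :1] - h * steps.flip(0)
    # Right padding: last + h, ..., last + k*h
    right = grid[:, -1:] + h * steps
    return torch.cat([left, grid, right], dim=1)


grid_small = torch.linspace(0, 1, steps=6)[None, :].expand(2, 6)
print(extend_grid_sketch(grid_small, k_extend=3).shape)  # torch.Size([2, 12])
```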

sparse_mask

Inputs

Name Type Required Description
in_dim int Yes Number of input units.
out_dim int Yes Number of output units.

Outputs

Name Type Description
mask torch.Tensor Binary (0/1) connection mask of shape (in_dim, out_dim); an entry of 1 marks a connected input-output pair.
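The nearest-neighbor rule can be sketched as follows. The exact layout is an assumption, not taken from the library: inputs and outputs are placed at the centers of equal-width cells on [0, 1], each input is connected to its nearest output, and each output to its nearest input, so no unit is left unconnected. sparse_mask_sketch is a hypothetical name.

```python
import torch


def sparse_mask_sketch(in_dim: int, out_dim: int) -> torch.Tensor:
    """Hypothetical nearest-neighbor connection mask of shape (in_dim, out_dim)."""
    # Centers of in_dim / out_dim equal-width cells on [0, 1]
    in_coord = (torch.arange(in_dim) + 0.5) / in_dim
    out_coord = (torch.arange(out_dim) + 0.5) / out_dim
    dist = torch.abs(in_coord[:, None] - out_coord[None, :])  # (in_dim, out_dim)
    mask = torch.zeros(in_dim, out_dim)
    # Each input connects to its nearest output
    mask[torch.arange(in_dim), dist.argmin(dim=1)] = 1.0
    # Each output connects to its nearest input
    mask[dist.argmin(dim=0), torch.arange(out_dim)] = 1.0
    return mask


print(sparse_mask_sketch(5, 3))
```

With equal input and output sizes the centers coincide and this rule reduces to the identity mask.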

Usage Examples

import torch
from pytorch_forecasting.layers._kan._utils import (
    b_batch,
    coef2curve,
    curve2coef,
    extend_grid,
    sparse_mask,
)

# Evaluate order-3 B-spline bases; 11 knots per input give 11 - 3 - 1 = 7 basis functions
x = torch.rand(100, 2)  # (batch, in_dim)
grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11)  # (in_dim, grid points)
basis = b_batch(x, grid, k=3)
print(basis.shape)  # torch.Size([100, 2, 7])

# Extend a grid by 3 points on each end
grid_small = torch.linspace(0, 1, steps=6)[None, :].expand(2, 6)
grid_extended = extend_grid(grid_small, k_extend=3)
print(grid_extended.shape)  # torch.Size([2, 12])

# Generate a sparse connection mask
mask = sparse_mask(in_dim=5, out_dim=3)
print(mask.shape)  # torch.Size([5, 3])
print(mask)
