Overview
KAN_Utils provides a collection of B-spline utility functions used by KANLayer for evaluating spline bases, converting between coefficients and curves, extending grids, and generating sparse connection masks.
Description
This module contains five core utility functions that support the Kolmogorov-Arnold Network (KAN) layer implementation. b_batch evaluates B-spline basis functions of arbitrary order using the Cox-de Boor recursion. coef2curve converts B-spline coefficients into curve values by contracting the basis values with the coefficients via Einstein summation. curve2coef estimates spline coefficients from sample data using batched least squares. extend_grid pads a grid tensor with equally spaced points on both ends, so that order-k basis functions are fully defined near the boundaries of the original grid. sparse_mask generates a binary connection mask between input and output dimensions by connecting units that are nearest neighbors in space.
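For reference, the Cox-de Boor recursion and the coefficient contraction described above can be written out explicitly. Here t_0 < t_1 < ... < t_{n-1} are the grid points of one input dimension i, B^{(i)}_{m,k} is the m-th order-k basis function built on that grid (the superscript is dropped in the recursion), and c_{i,j,m} indexes the coefficient tensor of shape (in_dim, out_dim, G+k). This notation is ours rather than the library's, and terms with a zero denominator are taken to be 0, as is conventional.

```latex
B_{m,0}(x) =
\begin{cases}
1 & \text{if } t_m \le x < t_{m+1} \\
0 & \text{otherwise}
\end{cases}

B_{m,k}(x) = \frac{x - t_m}{t_{m+k} - t_m}\, B_{m,k-1}(x)
           + \frac{t_{m+k+1} - x}{t_{m+k+1} - t_{m+1}}\, B_{m+1,k-1}(x)

y_{b,i,j} = \sum_{m=0}^{G+k-1} c_{i,j,m}\, B^{(i)}_{m,k}(x_{b,i})
```

Each step of the recursion reduces the number of basis functions by one, so a grid of n points yields n - k - 1 basis functions of order k.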
Usage
These utility functions are used internally by KANLayer. They are useful when implementing custom KAN architectures or when B-spline grids and coefficients need to be manipulated directly, but they are not typically called in high-level model construction.
Code Reference
Source Location
Signatures
b_batch
def b_batch(x, grid, k=0):
...
coef2curve
def coef2curve(x_eval, grid, coef, k):
...
curve2coef
def curve2coef(x_eval, y_eval, grid, k):
...
extend_grid
def extend_grid(grid, k_extend=0):
...
sparse_mask
def sparse_mask(in_dim, out_dim):
...
Import
from pytorch_forecasting.layers._kan._utils import (
b_batch,
coef2curve,
curve2coef,
extend_grid,
sparse_mask,
)
I/O Contract
b_batch
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | 2D tensor of inputs, shape (batch, in_dim), i.e. (number of samples, number of splines). |
| grid | torch.Tensor | Yes | 2D tensor of grid points, shape (in_dim, number of grid points). |
| k | int | No | Piecewise polynomial order of the splines. Default: 0. |
Outputs
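| Name | Type | Description |
|---|---|---|
| value | torch.Tensor | 3D tensor of B-spline basis values, shape (batch, in_dim, G+k), where G is the number of grid intervals before extension and k is the spline order; equivalently (batch, in_dim, number of grid points - k - 1). |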
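To make the recursion concrete, here is a minimal standalone sketch of batched Cox-de Boor evaluation in PyTorch. It is written independently of the library: the function name bspline_basis is hypothetical, and it evaluates the recursion bottom-up in a loop rather than recursively. It assumes x has shape (batch, in_dim) and grid has shape (in_dim, number of grid points).

```python
import torch


def bspline_basis(x: torch.Tensor, grid: torch.Tensor, k: int) -> torch.Tensor:
    """Evaluate order-k B-spline bases with the Cox-de Boor recursion (hypothetical helper).

    x:    (batch, in_dim) input values
    grid: (in_dim, n_points) increasing grid points for each input dimension
    returns (batch, in_dim, n_points - k - 1) basis values
    """
    x = x.unsqueeze(2)     # (batch, in_dim, 1)
    t = grid.unsqueeze(0)  # (1, in_dim, n_points)
    # Order 0: indicator of the half-open interval [t_m, t_{m+1})
    basis = ((x >= t[:, :, :-1]) & (x < t[:, :, 1:])).to(x.dtype)
    for p in range(1, k + 1):
        # Cox-de Boor weights for raising the order from p - 1 to p
        left = (x - t[:, :, : -(p + 1)]) / (t[:, :, p:-1] - t[:, :, : -(p + 1)])
        right = (t[:, :, p + 1 :] - x) / (t[:, :, p + 1 :] - t[:, :, 1:-p])
        basis = left * basis[:, :, :-1] + right * basis[:, :, 1:]
        # Repeated grid points produce 0/0; treat those terms as 0
        basis = torch.nan_to_num(basis)
    return basis


# Usage: 4 intervals on [-1, 1] (step 0.5), extended by k = 3 points per side -> 11 points
torch.manual_seed(0)
k = 3
grid = torch.linspace(-2.5, 2.5, steps=11, dtype=torch.float64).unsqueeze(0).expand(2, -1)
x = torch.rand(100, 2, dtype=torch.float64) * 2 - 1  # samples in [-1, 1)
basis = bspline_basis(x, grid, k)
print(basis.shape)  # torch.Size([100, 2, 7])
print(torch.allclose(basis.sum(dim=2), torch.ones(100, 2, dtype=torch.float64)))  # True
```

Inside the original grid range (here [-1, 1)), the basis values sum to 1 for every sample. This partition-of-unity property of B-splines is a convenient sanity check when building grids by hand.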
coef2curve
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x_eval | torch.Tensor | Yes | 2D tensor of evaluation points, shape (batch, in_dim). |
| grid | torch.Tensor | Yes | 2D tensor of shape (in_dim, G+2k+1): G grid intervals extended by k points on each side. |
| coef | torch.Tensor | Yes | 3D tensor of shape (in_dim, out_dim, G+k). |
| k | int | Yes | Piecewise polynomial order of the splines. |
Outputs
| Name | Type | Description |
|---|---|---|
| y_eval | torch.Tensor | 3D tensor of shape (batch, in_dim, out_dim). |
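The Einstein summation that coef2curve is described as performing can be illustrated on its own. In this sketch the basis values are random stand-ins for a B-spline evaluation of shape (batch, in_dim, G+k), and the einsum subscript labels are our own choice. The point is only that each output y_eval[b, i, j] is a weighted sum of the basis values of input i, with the weights taken from coef[i, j].

```python
import torch

torch.manual_seed(0)
batch, in_dim, out_dim, n_coef = 8, 3, 4, 7  # n_coef plays the role of G + k

# Random stand-ins: basis values as returned by a B-spline evaluation, and one
# coefficient vector per (input, output) pair
basis = torch.rand(batch, in_dim, n_coef, dtype=torch.float64)
coef = torch.randn(in_dim, out_dim, n_coef, dtype=torch.float64)

# y[b, i, j] = sum over m of basis[b, i, m] * coef[i, j, m]
y_eval = torch.einsum("bim,ijm->bij", basis, coef)

# The same contraction as explicit loops, for comparison
y_loop = torch.zeros(batch, in_dim, out_dim, dtype=torch.float64)
for b in range(batch):
    for i in range(in_dim):
        for j in range(out_dim):
            y_loop[b, i, j] = (basis[b, i] * coef[i, j]).sum()

print(y_eval.shape)  # torch.Size([8, 3, 4])
print(torch.allclose(y_eval, y_loop))  # True
```

The result keeps the in_dim axis: each (input, output) pair gets its own curve value, and any aggregation over inputs is left to the caller.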
curve2coef
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x_eval | torch.Tensor | Yes | 2D tensor of sample points, shape (batch, in_dim). |
| y_eval | torch.Tensor | Yes | 3D tensor of sample values, shape (batch, in_dim, out_dim). |
| grid | torch.Tensor | Yes | 2D tensor of shape (in_dim, G+2k+1): G grid intervals extended by k points on each side. |
| k | int | Yes | Piecewise polynomial order of the splines. |
Outputs
| Name | Type | Description |
|---|---|---|
| coef | torch.Tensor | 3D tensor of shape (in_dim, out_dim, G+k). |
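A minimal sketch of the batched least-squares fit follows. It assumes y_eval has shape (batch, in_dim, out_dim) and uses a random stand-in for the basis matrix so that the example stays self-contained. The permute/expand layout builds one small regression problem per (input, output) pair and solves them all with a single torch.linalg.lstsq call. It is our illustration of the technique, not a line-by-line copy of the library.

```python
import torch

torch.manual_seed(0)
batch, in_dim, out_dim, n_coef = 50, 3, 4, 7  # n_coef plays the role of G + k

# Random stand-in for the basis values of x_eval: (batch, in_dim, n_coef)
basis = torch.rand(batch, in_dim, n_coef, dtype=torch.float64)
coef_true = torch.randn(in_dim, out_dim, n_coef, dtype=torch.float64)
y_eval = torch.einsum("bim,ijm->bij", basis, coef_true)  # (batch, in_dim, out_dim)

# One design matrix per input dimension, shared by every output:
# (batch, in_dim, n_coef) -> (in_dim, out_dim, batch, n_coef)
mat = basis.permute(1, 0, 2).unsqueeze(1).expand(in_dim, out_dim, batch, n_coef)
# Targets as column vectors: (batch, in_dim, out_dim) -> (in_dim, out_dim, batch, 1)
rhs = y_eval.permute(1, 2, 0).unsqueeze(3)

# torch.linalg.lstsq solves all in_dim * out_dim problems in one batched call
coef_fit = torch.linalg.lstsq(mat, rhs).solution[..., 0]  # (in_dim, out_dim, n_coef)

print(coef_fit.shape)  # torch.Size([3, 4, 7])
print(torch.allclose(coef_fit, coef_true))  # True
```

With real B-spline bases, the fit is well posed only if every basis function is nonzero at some sample. Sparse or clustered x_eval can therefore make the least-squares problem rank deficient.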
extend_grid
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| grid | torch.Tensor | Yes | Grid tensor of shape (in_dim, grid_points). |
| k_extend | int | No | Number of points to add on each end. Default: 0. |
Outputs
| Name | Type | Description |
|---|---|---|
| grid | torch.Tensor | Extended grid of shape (in_dim, grid_points + 2 * k_extend). |
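The padding performed by extend_grid can be sketched as follows. The sketch assumes each row of the grid is equally spaced, so that the step is (last - first) / (grid_points - 1). The helper extend_grid_sketch is hypothetical and builds both pads in a single vectorized step.

```python
import torch


def extend_grid_sketch(grid: torch.Tensor, k_extend: int = 0) -> torch.Tensor:
    """Pad each row of `grid` with k_extend equally spaced points per side (hypothetical helper)."""
    # Step size per row, assuming the existing points are equally spaced: (in_dim, 1)
    h = (grid[:, -1:] - grid[:, :1]) / (grid.shape[1] - 1)
    # Offsets h, 2h, ..., k_extend * h for each row: (in_dim, k_extend)
    offsets = torch.arange(1, k_extend + 1, dtype=grid.dtype, device=grid.device) * h
    left = grid[:, :1] - offsets.flip(1)  # ascending: first - k_extend*h, ..., first - h
    right = grid[:, -1:] + offsets        # ascending: last + h, ..., last + k_extend*h
    return torch.cat([left, grid, right], dim=1)


grid_small = torch.linspace(0, 1, steps=6).unsqueeze(0).expand(2, 6)
grid_extended = extend_grid_sketch(grid_small, k_extend=3)
print(grid_extended.shape)  # torch.Size([2, 12])
print(grid_extended[0])
```

With a step of 0.2, the first row runs from -0.6 to 1.6. With k_extend=0, the grid is returned unchanged.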
sparse_mask
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| in_dim | int | Yes | Number of input units. |
| out_dim | int | Yes | Number of output units. |
Outputs
| Name | Type | Description |
|---|---|---|
| mask | torch.Tensor | Sparse binary mask of shape (in_dim, out_dim). |
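This page describes sparse_mask only as a nearest-neighbor connection mask. One plausible construction is sketched below. It assumes that input and output units are placed at evenly spaced positions in [0, 1] and that every input and every output must receive at least one connection. The helper nearest_neighbor_mask is hypothetical, and the library's exact placement and tie-breaking may differ.

```python
import torch


def nearest_neighbor_mask(in_dim: int, out_dim: int) -> torch.Tensor:
    """Binary (in_dim, out_dim) mask linking each unit to its nearest counterpart (hypothetical helper)."""
    # Place units at the midpoints of equal-width bins on [0, 1]
    in_coord = (torch.arange(in_dim) + 0.5) / in_dim
    out_coord = (torch.arange(out_dim) + 0.5) / out_dim
    dist = (in_coord[:, None] - out_coord[None, :]).abs()  # (in_dim, out_dim)

    mask = torch.zeros(in_dim, out_dim)
    # Each input connects to its nearest output...
    mask[torch.arange(in_dim), dist.argmin(dim=1)] = 1.0
    # ...and each output connects to its nearest input
    mask[dist.argmin(dim=0), torch.arange(out_dim)] = 1.0
    return mask


mask = nearest_neighbor_mask(5, 3)
print(mask)
```

For in_dim=5 and out_dim=3, this sketch links inputs 0 and 1 to output 0, input 2 to output 1, and inputs 3 and 4 to output 2.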
Usage Examples
import torch
from pytorch_forecasting.layers._kan._utils import (
b_batch,
coef2curve,
curve2coef,
extend_grid,
sparse_mask,
)
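# Evaluate cubic (k=3) B-spline basis functions for 100 samples and 2 inputs
x = torch.rand(100, 2)
# 11 grid points per input yield 11 - k - 1 = 7 basis functions
grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11)
basis = b_batch(x, grid, k=3)
print(basis.shape) # torch.Size([100, 2, 7])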
# Extend a grid by 3 points on each end
grid_small = torch.linspace(0, 1, steps=6)[None, :].expand(2, 6)
grid_extended = extend_grid(grid_small, k_extend=3)
print(grid_extended.shape) # torch.Size([2, 12])
# Generate a sparse connection mask
mask = sparse_mask(in_dim=5, out_dim=3)
print(mask.shape) # torch.Size([5, 3])
print(mask)
Related Pages