Implementation:Sktime Pytorch forecasting Samformer

From Leeroopedia


Knowledge Sources
Domains: Time_Series, Forecasting, Deep_Learning
Last Updated: 2026-02-08 08:00 GMT

Overview

Samformer is a lightweight Transformer model for multivariate time series forecasting, designed to be trained with Sharpness-Aware Minimization (SAM). It extends BaseModel with scaled dot-product attention applied across channels and optional Reversible Instance Normalization (RevIN).

Description

Samformer implements the architecture from "Samformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention." The model computes queries, keys, and values via linear projections over the encoder sequence length, applies scaled dot-product attention across channels, and uses a linear forecaster to project from the encoder length to the prediction length. It optionally applies RevIN and supports quantile predictions. Because attention operates over the channel dimension, each channel's encoder history acts as a single token, and the same linear forecaster is then applied to every channel.
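The data flow described above can be illustrated with a minimal PyTorch sketch. It is not the library's implementation: the class name ChannelAttentionForecaster, the layer names, the residual connection, and the choice to keep the value projection at encoder length are all assumptions made for illustration. It only shows how attention over the channel axis and a linear projection over the time axis fit together.

```python
import math

import torch
from torch import nn


class ChannelAttentionForecaster(nn.Module):
    """Hypothetical stand-in for the block described above, not the library class."""

    def __init__(self, encoder_length: int, prediction_length: int, hidden_size: int):
        super().__init__()
        # Queries and keys project each channel's history down to hidden_size;
        # values keep the encoder length so they can be added back to the input.
        self.compute_queries = nn.Linear(encoder_length, hidden_size)
        self.compute_keys = nn.Linear(encoder_length, hidden_size)
        self.compute_values = nn.Linear(encoder_length, encoder_length)
        self.linear_forecaster = nn.Linear(encoder_length, prediction_length)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, encoder_length, channels) -> (batch, channels, encoder_length)
        x = x.transpose(1, 2)
        q = self.compute_queries(x)
        k = self.compute_keys(x)
        v = self.compute_values(x)
        # Scaled dot-product attention; the score matrix is channels x channels.
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        attended = torch.softmax(scores, dim=-1) @ v
        out = x + attended  # residual connection (an assumption of this sketch)
        # Project the time axis from encoder_length to prediction_length.
        out = self.linear_forecaster(out)
        return out.transpose(1, 2)  # (batch, prediction_length, channels)


torch.manual_seed(0)
block = ChannelAttentionForecaster(encoder_length=96, prediction_length=24, hidden_size=16)
forecast = block(torch.randn(4, 96, 6))
print(forecast.shape)  # torch.Size([4, 24, 6])
```

Since each channel is one token, the attention matrix grows with the number of channels rather than with the sequence length, which is one reason the model stays small for long encoder windows.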

Usage

Use Samformer when you need a lightweight Transformer-based forecasting model that benefits from Sharpness-Aware Minimization (SAM). It is well suited to multivariate time series in which dependencies between channels are captured through channel-wise attention, and to settings where RevIN helps address distribution shift between training and inference.
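This page does not state how SAM is wired into training, and the constructor's default optimizer is "adam". As a rough illustration of what a SAM update involves, the sketch below applies the generic two-pass SAM procedure to a toy linear model. The function sam_step, the toy model, and the value rho=0.05 are hypothetical and are not part of pytorch_forecasting.

```python
import torch
from torch import nn


def sam_step(model, loss_fn, x, y, base_optimizer, rho=0.05):
    """One SAM update: ascend to a nearby worst-case point, then descend from it."""
    # 1) Gradient at the current weights.
    base_optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()

    params = [p for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack([p.grad.norm(2) for p in params]), 2)

    # 2) Perturb the weights along the normalized gradient, scaled by rho.
    perturbations = []
    with torch.no_grad():
        for p in params:
            e = rho * p.grad / (grad_norm + 1e-12)
            p.add_(e)
            perturbations.append(e)

    # 3) Gradient at the perturbed weights.
    base_optimizer.zero_grad()
    loss_fn(model(x), y).backward()

    # 4) Restore the original weights, then step with the perturbed-point gradient.
    with torch.no_grad():
        for p, e in zip(params, perturbations):
            p.sub_(e)
    base_optimizer.step()
    return loss.item()


torch.manual_seed(0)
toy_model = nn.Linear(8, 1)  # stand-in model, not Samformer
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
x, y = torch.randn(32, 8), torch.randn(32, 1)
losses = [sam_step(toy_model, nn.MSELoss(), x, y, optimizer) for _ in range(20)]
```

Each update costs two forward and backward passes, so SAM roughly doubles the training cost per step compared with the base optimizer alone.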

Code Reference

Source Location

Signature

class Samformer(BaseModel):
    def __init__(
        self,
        loss: nn.Module,
        # specific params
        hidden_size: int,
        use_revin: bool,
        # out_channels has to be 1, due to lack of MultiLoss support in v2.
        out_channels: int | list[int] | None = 1,
        persistence_weight: float = 0.0,
        logging_metrics: list[nn.Module] | None = None,
        optimizer: Optimizer | str | None = "adam",
        optimizer_params: dict | None = None,
        lr_scheduler: str | None = None,
        lr_scheduler_params: dict | None = None,
        metadata: dict | None = None,
        **kwargs,
    ):

Import

from pytorch_forecasting.models.samformer import Samformer

I/O Contract

Inputs

Name | Type | Required | Description
loss | nn.Module | Yes | Loss function module used for training; may contain a quantiles attribute for quantile regression
hidden_size | int | Yes | First embedding size of the model (referred to as 'r' in the paper)
use_revin | bool | Yes | Whether to apply Reversible Instance Normalization (RevIN) to the input
out_channels | int or list[int] or None | No | Number of output variables to predict; must be 1 due to lack of MultiLoss support in v2; defaults to 1
persistence_weight | float | No | Weight for persistence baseline; defaults to 0.0
logging_metrics | list[nn.Module] or None | No | List of metrics to log during training; defaults to None
optimizer | Optimizer or str or None | No | Optimizer class or string identifier; defaults to "adam"
optimizer_params | dict or None | No | Additional parameters for the optimizer; defaults to None
lr_scheduler | str or None | No | Learning rate scheduler string identifier; defaults to None
lr_scheduler_params | dict or None | No | Additional parameters for the LR scheduler; defaults to None
metadata | dict or None | No | Dictionary containing max_encoder_length, max_prediction_length, and encoder_cont keys
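The use_revin flag in the table above refers to reversible instance normalization. A minimal sketch of the idea follows, assuming inputs shaped (batch, time, channels). The helper names revin_normalize and revin_denormalize are hypothetical, and the learnable affine parameters of the original RevIN formulation are omitted, so this is not the layer the library uses.

```python
import torch


def revin_normalize(x: torch.Tensor, eps: float = 1e-5):
    """Normalize each series by its own statistics over the time axis.

    x: (batch, time, channels). Returns the normalized tensor plus the
    statistics needed to undo the transform later.
    """
    mean = x.mean(dim=1, keepdim=True)
    std = torch.sqrt(x.var(dim=1, keepdim=True, unbiased=False) + eps)
    return (x - mean) / std, mean, std


def revin_denormalize(y: torch.Tensor, mean: torch.Tensor, std: torch.Tensor):
    """Map model outputs back to the original scale of each series."""
    return y * std + mean


torch.manual_seed(0)
series = torch.randn(2, 96, 3) * 5.0 + 10.0  # shifted and scaled toy input
normalized, mean, std = revin_normalize(series)
restored = revin_denormalize(normalized, mean, std)
```

Because the statistics are computed per input window, a window drawn from a shifted distribution at inference time is still mapped to roughly zero mean and unit variance before it reaches the model.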

Outputs

Name | Type | Description
forward(x) | dict[str, torch.Tensor] | Dictionary with key "prediction" containing a tensor of shape (batch_size, max_prediction_length, n_quantiles)
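As a sketch of how the documented output might be consumed, the snippet below slices a dictionary of the stated shape into a median forecast and an interval. The quantile levels and the random stand-in tensor are assumptions; in practice the levels come from the configured loss and the tensor from a call to the model.

```python
import torch

# Hypothetical quantile levels; the real ones come from the configured loss.
quantiles = [0.1, 0.5, 0.9]

# Stand-in for the model's return value: a dict with a "prediction" tensor of
# shape (batch_size, max_prediction_length, n_quantiles). Sorting the last axis
# only makes the fake quantiles monotonic for this demonstration.
torch.manual_seed(0)
output = {"prediction": torch.randn(4, 24, len(quantiles)).sort(dim=-1).values}

prediction = output["prediction"]
median = prediction[..., quantiles.index(0.5)]          # shape (4, 24)
lower, upper = prediction[..., 0], prediction[..., -1]  # 0.1 and 0.9 levels
interval_width = upper - lower                          # shape (4, 24)
```

With a point loss there is a single value on the last axis, so the same indexing with index 0 yields the point forecast.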

Usage Examples

from pytorch_forecasting.models.samformer import Samformer
from torch import nn

# Define metadata from dataset
metadata = {
    "max_encoder_length": 96,
    "max_prediction_length": 24,
    "encoder_cont": 6,
}

# Instantiate the Samformer model
model = Samformer(
    loss=nn.MSELoss(),
    hidden_size=512,
    use_revin=True,
    out_channels=1,
    persistence_weight=0.0,
    optimizer="adam",
    optimizer_params={"lr": 1e-3},
    metadata=metadata,
)
