Implementation:Sktime Pytorch forecasting Samformer
| Knowledge Sources | |
|---|---|
| Domains | Time_Series, Forecasting, Deep_Learning |
| Last Updated | 2026-02-08 08:00 GMT |
Overview
Samformer is a lightweight Transformer model for multivariate time series forecasting. It extends BaseModel with channel-wise scaled dot-product attention and optional Reverse Instance Normalization (RevIN). The underlying architecture was proposed for training with Sharpness-Aware Minimization (SAM).
Description
Samformer implements the architecture from "SAMformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention." Each channel of the encoder window is treated as a token. Linear projections map that channel's full sequence of length max_encoder_length to queries, keys, and values, so scaled dot-product attention weighs channels against each other rather than time steps. A linear forecaster then projects from the encoder length to the prediction length. The model optionally applies RevIN normalization and supports quantile predictions. The projections and the forecaster are shared across channels, so every channel is embedded and forecast with the same weights, and the attention step is the only place where information is mixed between channels.
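The computation described above can be illustrated with a minimal NumPy sketch of a single forward pass. This is not the code in `_samformer_v2.py`. The function name `samformer_forward_sketch` is hypothetical, the weights are random, and RevIN is omitted. The value projection of shape (L, L) and the residual connection before the forecaster follow the original SAMformer formulation and are assumptions with respect to this implementation.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax along the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def samformer_forward_sketch(x, w_q, w_k, w_v, w_f):
    """Channel-wise attention forecast (hypothetical sketch).

    x:   (batch, L, D) encoder window, L = encoder length, D = channels
    w_q, w_k: (L, r) query/key projections, r = hidden_size
    w_v: (L, L) value projection (assumed shape)
    w_f: (L, H) linear forecaster, H = prediction length
    Returns forecast (batch, H, D) and attention weights (batch, D, D).
    """
    xt = np.transpose(x, (0, 2, 1))  # (batch, D, L): one token per channel
    q = xt @ w_q                      # (batch, D, r)
    k = xt @ w_k                      # (batch, D, r)
    v = xt @ w_v                      # (batch, D, L)
    r = w_q.shape[1]
    # Attention between channels: a D x D matrix per sample.
    attn = softmax(q @ np.transpose(k, (0, 2, 1)) / np.sqrt(r), axis=-1)
    out = xt + attn @ v               # residual connection (assumed)
    forecast = out @ w_f              # shared forecaster applied per channel
    return np.transpose(forecast, (0, 2, 1)), attn


rng = np.random.default_rng(0)
batch, L, D, r, H = 4, 96, 6, 16, 24
x = rng.standard_normal((batch, L, D))
params = [rng.standard_normal(s) * 0.1 for s in [(L, r), (L, r), (L, L), (L, H)]]
forecast, attn = samformer_forward_sketch(x, *params)
print(forecast.shape, attn.shape)  # (4, 24, 6) (4, 6, 6)
```

Because the attention matrix is D x D, its cost grows with the number of channels rather than with the encoder length.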
Usage
Use Samformer when you need a lightweight Transformer-based forecasting model for multivariate time series. It suits cases where attention between channels is preferable to attention over time steps, and cases where RevIN normalization helps address distribution shift between training and inference. The architecture was proposed together with Sharpness-Aware Minimization (SAM) training. Note that the `optimizer` argument defaults to `"adam"`, so SAM is not applied by default.
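Sharpness-Aware Minimization replaces a plain gradient step with two gradient evaluations. The first finds a nearby worst-case perturbation of the weights within radius `rho`. The second evaluates the gradient at that perturbed point, and this gradient is then applied to the original weights. The NumPy sketch below shows this update on a toy quadratic loss. `sam_step`, `lr`, and `rho` are illustrative names, and nothing here is a pytorch_forecasting API.

```python
import numpy as np


def sam_step(w: np.ndarray, grad_fn, lr: float = 0.1, rho: float = 0.05) -> np.ndarray:
    """One Sharpness-Aware Minimization update; grad_fn(w) returns dLoss/dw."""
    g = grad_fn(w)
    # Ascent step of length rho towards the (first-order) worst-case neighbour.
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    # Gradient evaluated at the perturbed weights...
    g_adv = grad_fn(w + eps)
    # ...but the descent step starts from the ORIGINAL weights.
    return w - lr * g_adv


# Toy loss 0.5 * w^T A w, sharp along the first axis and flat along the second.
A = np.diag([10.0, 1.0])


def quadratic_grad(w: np.ndarray) -> np.ndarray:
    return A @ w


w = np.array([1.0, 1.0])
for _ in range(50):
    w = sam_step(w, quadratic_grad, lr=0.05, rho=0.05)
print(w)
```

With `rho=0` the update reduces to plain gradient descent. In a deep learning training loop each SAM step needs two forward/backward passes, roughly doubling the cost per batch.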
Code Reference
Source Location
- Repository: Sktime_Pytorch_forecasting
- File: pytorch_forecasting/models/samformer/_samformer_v2.py
- Lines: 1-186
Signature
```python
class Samformer(BaseModel):
    def __init__(
        self,
        loss: nn.Module,
        # specific params
        hidden_size: int,
        use_revin: bool,
        # out_channels has to be 1, due to lack of MultiLoss support in v2.
        out_channels: int | list[int] | None = 1,
        persistence_weight: float = 0.0,
        logging_metrics: list[nn.Module] | None = None,
        optimizer: Optimizer | str | None = "adam",
        optimizer_params: dict | None = None,
        lr_scheduler: str | None = None,
        lr_scheduler_params: dict | None = None,
        metadata: dict | None = None,
        **kwargs,
    ):
```
Import
```python
from pytorch_forecasting.models.samformer import Samformer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| loss | nn.Module | Yes | Loss function module used for training; if it has a `quantiles` attribute, the model predicts one output per quantile (quantile regression) |
| hidden_size | int | Yes | First embedding size of the model (referred to as 'r' in the paper) |
| use_revin | bool | Yes | Whether to apply Reverse Instance Normalization to the input |
| out_channels | int or list[int] or None | No | Number of output variables to predict; must be 1 because v2 lacks MultiLoss support; defaults to 1 |
| persistence_weight | float | No | Weight for persistence baseline; defaults to 0.0 |
| logging_metrics | list[nn.Module] or None | No | List of metrics to log during training; defaults to None |
| optimizer | Optimizer or str or None | No | Optimizer class or string identifier; defaults to "adam" |
| optimizer_params | dict or None | No | Additional parameters for the optimizer; defaults to None |
| lr_scheduler | str or None | No | Learning rate scheduler string identifier; defaults to None |
| lr_scheduler_params | dict or None | No | Additional parameters for the LR scheduler; defaults to None |
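| metadata | dict or None | No | Dictionary with the keys max_encoder_length, max_prediction_length, and encoder_cont; defaults to None, but these values size the projection and forecaster layers, so it should be provided |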
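When `use_revin=True`, the input is normalized per instance before the attention layers and the forecast is mapped back to the original scale afterwards. The class below is a minimal NumPy sketch of reversible instance normalization. The name `RevINSketch` and the `eps` value are assumptions, and the learnable per-channel affine parameters of the original RevIN method are omitted, so it should not be read as the layer used inside `Samformer`.

```python
import numpy as np


class RevINSketch:
    """Reversible instance normalization over the time axis (no learnable affine)."""

    def __init__(self, eps: float = 1e-5):
        self.eps = eps
        self.mean = None
        self.std = None

    def normalize(self, x: np.ndarray) -> np.ndarray:
        # x: (batch, time, channels); statistics per sample and per channel
        self.mean = x.mean(axis=1, keepdims=True)
        self.std = np.sqrt(x.var(axis=1, keepdims=True) + self.eps)
        return (x - self.mean) / self.std

    def denormalize(self, y: np.ndarray) -> np.ndarray:
        # y: (batch, horizon, channels); reuses the encoder-window statistics,
        # so the horizon may differ from the input length
        return y * self.std + self.mean


rng = np.random.default_rng(0)
x = rng.normal(loc=50.0, scale=5.0, size=(2, 96, 6))
revin = RevINSketch()
x_norm = revin.normalize(x)       # zero mean, unit variance per channel
x_back = revin.denormalize(x_norm)
```

Storing the statistics of each input window is what makes the normalization reversible: the model works on a shift-free scale, and the forecast inherits the level and spread of the window it was conditioned on.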
Outputs
| Name | Type | Description |
|---|---|---|
| forward(x) | dict[str, torch.Tensor] | Dictionary with key "prediction" containing a tensor of shape (batch_size, max_prediction_length, n_quantiles) |
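When the loss defines quantiles, the last dimension of `"prediction"` holds one forecast per quantile level. The standard objective for that layout is the pinball (quantile) loss, sketched below in NumPy. pytorch_forecasting ships its own `QuantileLoss` metric; the function `pinball_loss` here is a hypothetical illustration, not that implementation.

```python
import numpy as np


def pinball_loss(y_true: np.ndarray, y_pred: np.ndarray, quantiles) -> float:
    """Mean pinball loss.

    y_true: (batch, horizon)
    y_pred: (batch, horizon, n_quantiles), same layout as the "prediction" output
    quantiles: n_quantiles levels in (0, 1)
    """
    q = np.asarray(quantiles)            # (n_quantiles,)
    errors = y_true[..., None] - y_pred  # broadcast the target over the quantile axis
    # Under-prediction is weighted by q, over-prediction by (1 - q).
    return float(np.mean(np.maximum(q * errors, (q - 1) * errors)))


y_true = np.zeros((1, 1))
y_pred = np.array([[[-1.0, 0.0, 1.0]]])
print(pinball_loss(y_true, y_pred, [0.1, 0.5, 0.9]))  # ≈ 0.0667
```

For a single quantile of 0.5 the loss equals half the mean absolute error, which is why the median forecast is the natural point estimate from a quantile model.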
Usage Examples
```python
from torch import nn

from pytorch_forecasting.models.samformer import Samformer

# Define metadata from dataset
metadata = {
    "max_encoder_length": 96,     # length of the input (encoder) window
    "max_prediction_length": 24,  # forecast horizon
    "encoder_cont": 6,            # continuous encoder features (count)
}

# Instantiate the Samformer model
model = Samformer(
    loss=nn.MSELoss(),
    hidden_size=512,
    use_revin=True,
    out_channels=1,  # must be 1 in v2 (no MultiLoss support)
    persistence_weight=0.0,
    optimizer="adam",
    optimizer_params={"lr": 1e-3},
    metadata=metadata,
)
```
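The metadata dictionary in the example can be derived from the shape of the training data. The helper below is hypothetical (it is not part of pytorch_forecasting). It assumes that `encoder_cont` holds the number of continuous input channels, matching the integer used in the example, and that the data is a single (time, channels) array.

```python
import numpy as np


def make_samformer_metadata(series: np.ndarray, encoder_length: int, prediction_length: int) -> dict:
    """Hypothetical helper: build a Samformer metadata dict from a (time, channels) array."""
    n_time, n_channels = series.shape
    # At least one full encoder window plus one forecast horizon is needed.
    if encoder_length + prediction_length > n_time:
        raise ValueError("series is too short for one encoder + prediction window")
    return {
        "max_encoder_length": encoder_length,
        "max_prediction_length": prediction_length,
        "encoder_cont": n_channels,  # assumed: number of continuous encoder features
    }


series = np.random.default_rng(0).standard_normal((500, 6))
metadata = make_samformer_metadata(series, encoder_length=96, prediction_length=24)
print(metadata)  # {'max_encoder_length': 96, 'max_prediction_length': 24, 'encoder_cont': 6}
```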