Implementation:Sktime Pytorch forecasting TFT V2

From Leeroopedia


Knowledge Sources
Domains: Time_Series, Forecasting, Deep_Learning
Last Updated: 2026-02-08 08:00 GMT

Overview

TFT V2 is an implementation of the Temporal Fusion Transformer (TFT) model built on the new BaseModel V2 data pipeline.

Description

The TFT class extends BaseModel from the V2 base model pipeline and implements a simplified Temporal Fusion Transformer architecture. Encoder features (categorical and continuous), decoder features (categorical and continuous), and static features pass through variable selection gates. The gated sequences are then processed by LSTM encoder and decoder layers, multi-head self-attention over the LSTM outputs captures long-range dependencies, and a static context derived from the static features conditions the network. This is an experimental implementation designed for the new V2 data pipeline.
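To make the data flow described above concrete, here is a minimal, self-contained PyTorch sketch of the same kind of pipeline. It is not the pytorch-forecasting source: `GatedVariableSelection` and `TinyTFTSketch` are hypothetical names, categorical features are assumed to arrive already encoded as numeric columns alongside the continuous ones, dropout is omitted, and adding the static context to the projected inputs is an illustrative choice.

```python
import torch
import torch.nn as nn


class GatedVariableSelection(nn.Module):
    """Per-feature sigmoid gates: a simplified stand-in for TFT variable selection."""

    def __init__(self, n_features: int, hidden_size: int):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(n_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_features),
            nn.Sigmoid(),  # one weight in (0, 1) per input feature
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, n_features); each feature is rescaled by its gate
        return x * self.gate(x)


class TinyTFTSketch(nn.Module):
    """Hypothetical toy model: gating -> LSTM encoder/decoder -> self-attention."""

    def __init__(self, enc_features: int, dec_features: int, static_features: int,
                 hidden_size: int = 16, num_layers: int = 1, num_heads: int = 4,
                 output_size: int = 1):
        super().__init__()
        self.enc_select = GatedVariableSelection(enc_features, hidden_size)
        self.dec_select = GatedVariableSelection(dec_features, hidden_size)
        self.enc_proj = nn.Linear(enc_features, hidden_size)
        self.dec_proj = nn.Linear(dec_features, hidden_size)
        self.static_proj = nn.Linear(static_features, hidden_size)
        self.encoder = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.decoder = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        # hidden_size must be divisible by num_heads
        self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, enc_x: torch.Tensor, dec_x: torch.Tensor,
                static_x: torch.Tensor) -> torch.Tensor:
        # Static context: (batch, static_features) -> (batch, 1, hidden), broadcast over time
        context = self.static_proj(static_x).unsqueeze(1)
        enc_h = self.enc_proj(self.enc_select(enc_x)) + context
        dec_h = self.dec_proj(self.dec_select(dec_x)) + context
        enc_out, state = self.encoder(enc_h)
        # The decoder LSTM starts from the encoder's final (h, c) state
        dec_out, _ = self.decoder(dec_h, state)
        # Self-attention over the full encoder + decoder sequence
        seq = torch.cat([enc_out, dec_out], dim=1)
        attn_out, _ = self.attention(seq, seq, seq)
        # Keep only the forecast positions: (batch, horizon, output_size)
        horizon = dec_x.shape[1]
        return self.output_layer(attn_out[:, -horizon:, :])


# 7 encoder features (5 cont + 2 cat), 4 decoder features, 3 static features
model = TinyTFTSketch(enc_features=7, dec_features=4, static_features=3)
out = model(torch.randn(8, 96, 7), torch.randn(8, 24, 4), torch.randn(8, 3))
print(out.shape)  # torch.Size([8, 24, 1])
```

Handing the encoder's final LSTM state to the decoder LSTM and letting the forecast positions attend over the whole encoder-plus-decoder sequence is one simple way to wire these pieces together; the actual V2 class may connect them differently.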

Usage

Use TFT V2 when you need an attention-based time series forecasting model that can handle heterogeneous inputs including static, encoder, and decoder features. It is suitable for complex forecasting tasks where variable importance and temporal attention patterns matter. Note that this is an experimental implementation on the V2 pipeline.

Code Reference

Source Location

Signature

class TFT(BaseModel):
    def __init__(
        self,
        loss: nn.Module,
        logging_metrics: list[nn.Module] | None = None,
        optimizer: Optimizer | str | None = "adam",
        optimizer_params: dict | None = None,
        lr_scheduler: str | None = None,
        lr_scheduler_params: dict | None = None,
        hidden_size: int = 64,
        num_layers: int = 2,
        attention_head_size: int = 4,
        dropout: float = 0.1,
        metadata: dict | None = None,
        output_size: int = 1,
    ):

Import

from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT

I/O Contract

Inputs

Name Type Required Description
loss nn.Module Yes Loss function for training.
logging_metrics list[nn.Module] | None No Metrics to log during training. Defaults to None.
optimizer Optimizer | str | None No Optimizer for training, given as an Optimizer or by name. Defaults to "adam".
optimizer_params dict | None No Optimizer parameters. Defaults to None.
lr_scheduler str | None No Learning rate scheduler name. Defaults to None.
lr_scheduler_params dict | None No Learning rate scheduler parameters. Defaults to None.
hidden_size int No Size of hidden layers and LSTM hidden state. Defaults to 64.
num_layers int No Number of LSTM layers. Defaults to 2.
attention_head_size int No Number of attention heads (despite the name, a head count rather than a per-head size). Defaults to 4.
dropout float No Dropout probability. Defaults to 0.1.
metadata dict | None No Dataset metadata with the keys max_encoder_length, max_prediction_length, encoder_cont, encoder_cat, decoder_cont, decoder_cat, static_categorical_features, and static_continuous_features (sequence lengths and feature counts). Defaults to None, but must be supplied in practice, since it is the only source of the feature counts that size the input layers.
output_size int No Number of output features. Defaults to 1.
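Because the input widths come from `metadata`, a quick pre-flight check can catch configuration mistakes before training. The helper below is a hypothetical sketch, not part of pytorch-forecasting: treating the static keys as optional is an assumption, the summed widths assume categorical columns enter as already-encoded numeric columns next to the continuous ones, and the divisibility check assumes the attention layer is `torch.nn.MultiheadAttention` with `attention_head_size` as its number of heads.

```python
REQUIRED_KEYS = (
    "max_encoder_length", "max_prediction_length",
    "encoder_cont", "encoder_cat", "decoder_cont", "decoder_cat",
)


def summarize_tft_config(metadata: dict, hidden_size: int = 64,
                         attention_head_size: int = 4) -> dict:
    """Hypothetical sanity check for TFT V2 constructor arguments."""
    missing = [k for k in REQUIRED_KEYS if k not in metadata]
    if missing:
        raise KeyError(f"metadata is missing keys: {missing}")
    # torch.nn.MultiheadAttention requires embed_dim % num_heads == 0
    if hidden_size % attention_head_size != 0:
        raise ValueError(
            f"hidden_size={hidden_size} is not divisible by "
            f"attention_head_size={attention_head_size}"
        )
    # Static feature counts are treated as optional (assumed default 0)
    static_dim = (metadata.get("static_categorical_features", 0)
                  + metadata.get("static_continuous_features", 0))
    return {
        "encoder_input_dim": metadata["encoder_cont"] + metadata["encoder_cat"],
        "decoder_input_dim": metadata["decoder_cont"] + metadata["decoder_cat"],
        "static_input_dim": static_dim,
        "head_dim": hidden_size // attention_head_size,
    }


metadata = {
    "max_encoder_length": 96, "max_prediction_length": 24,
    "encoder_cont": 5, "encoder_cat": 2, "decoder_cont": 3, "decoder_cat": 1,
    "static_categorical_features": 2, "static_continuous_features": 1,
}
print(summarize_tft_config(metadata))
# {'encoder_input_dim': 7, 'decoder_input_dim': 4, 'static_input_dim': 3, 'head_dim': 16}
```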

Forward Input Dictionary

Name Type Required Description
encoder_cont torch.Tensor Yes Continuous encoder features of shape (batch, encoder_length, n_cont).
encoder_cat torch.Tensor No Categorical encoder features of shape (batch, encoder_length, n_cat).
decoder_cont torch.Tensor No Continuous decoder features of shape (batch, prediction_length, n_cont).
decoder_cat torch.Tensor No Categorical decoder features of shape (batch, prediction_length, n_cat).
static_categorical_features torch.Tensor No Static categorical features.
static_continuous_features torch.Tensor No Static continuous features.
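The shapes listed above can be checked before calling the model. The following is a hypothetical helper, not a library function; it validates only the four sequence tensors, since the table does not give shapes for the static features, and it treats every key except `encoder_cont` as optional, as the Required column indicates.

```python
import torch

# Hypothetical mapping: batch key -> (sequence-length key, feature-count key) in metadata
_SEQUENCE_KEYS = {
    "encoder_cont": ("max_encoder_length", "encoder_cont"),
    "encoder_cat": ("max_encoder_length", "encoder_cat"),
    "decoder_cont": ("max_prediction_length", "decoder_cont"),
    "decoder_cat": ("max_prediction_length", "decoder_cat"),
}


def check_batch_shapes(x: dict, metadata: dict) -> int:
    """Validate the sequence tensors of a TFT V2-style batch; return the batch size."""
    if "encoder_cont" not in x:
        raise KeyError("encoder_cont is required")
    batch = x["encoder_cont"].shape[0]
    for key, (len_key, feat_key) in _SEQUENCE_KEYS.items():
        if key not in x:
            continue  # optional inputs may be omitted
        expected = (batch, metadata[len_key], metadata[feat_key])
        actual = tuple(x[key].shape)
        if actual != expected:
            raise ValueError(f"{key}: expected {expected}, got {actual}")
    return batch


metadata = {"max_encoder_length": 96, "max_prediction_length": 24,
            "encoder_cont": 5, "encoder_cat": 2,
            "decoder_cont": 3, "decoder_cat": 1}
x = {"encoder_cont": torch.randn(8, 96, 5), "decoder_cont": torch.randn(8, 24, 3)}
print(check_batch_shapes(x, metadata))  # 8
```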

Outputs

Name Type Description
prediction torch.Tensor Forecast tensor of shape (batch, prediction_length, output_size), returned under the "prediction" key of the output dictionary.
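The sketch below shows one way a prediction tensor with this shape can be consumed. Random tensors stand in for `output["prediction"]` and for a target window, which is assumed to have the same shape; the per-step error breakdown and the `squeeze` for `output_size == 1` are illustrative, not library features.

```python
import torch
import torch.nn as nn

batch, prediction_length, output_size = 8, 24, 1

# Random stand-ins for output["prediction"] and the matching target window
prediction = torch.randn(batch, prediction_length, output_size)
target = torch.randn(batch, prediction_length, output_size)

# Mean squared error over all elements (nn.MSELoss uses reduction="mean" by default)
loss = nn.MSELoss()(prediction, target)

# Error per forecast step: average over batch and output dims, keep the time axis
per_step_mse = ((prediction - target) ** 2).mean(dim=(0, 2))  # shape: (24,)

# With output_size == 1, drop the trailing axis to get (batch, prediction_length)
point_forecast = prediction.squeeze(-1)  # shape: (8, 24)
```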

Usage Examples

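from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT
import torch
import torch.nn as nn

# Sequence lengths and feature counts used to size the model
metadata = {
    "max_encoder_length": 96,
    "max_prediction_length": 24,
    "encoder_cont": 5,
    "encoder_cat": 2,
    "decoder_cont": 3,
    "decoder_cat": 1,
    "static_categorical_features": 2,
    "static_continuous_features": 1,
}

model = TFT(
    loss=nn.MSELoss(),
    hidden_size=64,
    num_layers=2,
    attention_head_size=4,
    dropout=0.1,
    metadata=metadata,
    output_size=1,
)

# Random stand-in batch; in practice the batch comes from the V2 data pipeline.
# Assumptions: categorical inputs are passed as already-encoded float tensors,
# and static features have shape (batch, 1, n_static).
batch_size = 8
enc_len = metadata["max_encoder_length"]
pred_len = metadata["max_prediction_length"]
x_batch = {
    "encoder_cont": torch.randn(batch_size, enc_len, metadata["encoder_cont"]),
    "encoder_cat": torch.randn(batch_size, enc_len, metadata["encoder_cat"]),
    "decoder_cont": torch.randn(batch_size, pred_len, metadata["decoder_cont"]),
    "decoder_cat": torch.randn(batch_size, pred_len, metadata["decoder_cat"]),
    "static_categorical_features": torch.randn(
        batch_size, 1, metadata["static_categorical_features"]
    ),
    "static_continuous_features": torch.randn(
        batch_size, 1, metadata["static_continuous_features"]
    ),
}

# Forward pass: returns a dict holding the forecast under "prediction"
output = model(x_batch)
predictions = output["prediction"]  # shape: (batch_size, pred_len, output_size)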
