Implementation:Sktime Pytorch forecasting TFT V2
| Knowledge Sources | |
|---|---|
| Domains | Time_Series, Forecasting, Deep_Learning |
| Last Updated | 2026-02-08 08:00 GMT |
Overview
TFT V2 (Temporal Fusion Transformer V2) is an implementation of the Temporal Fusion Transformer model built on the new BaseModel V2 data pipeline.
Description
The TFT class extends BaseModel from the V2 base model pipeline and implements a simplified Temporal Fusion Transformer architecture. It combines variable selection gates for encoder and decoder inputs, LSTM layers for sequential processing, multi-head self-attention for capturing long-range dependencies, and static context integration. Encoder features (categorical and continuous), decoder features (categorical and continuous), and static features first pass through the variable selection gates; LSTM encoder-decoder layers followed by multi-head attention then produce the forecasts. This implementation is experimental and is designed to work with the new V2 data pipeline.
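The data flow described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the code in `_tft_v2.py`: the class name `MiniTFT`, the sigmoid-gated feedforward used as a variable selection gate, and adding the projected static context to the encoder's final hidden state are choices made for this sketch. Categorical inputs are assumed to be already numeric and concatenated with the continuous ones, so `enc_dim` and `dec_dim` are total feature counts.

```python
import torch
import torch.nn as nn


class MiniTFT(nn.Module):
    """Simplified TFT-style forecaster (illustrative sketch, not the library class)."""

    def __init__(self, enc_dim, dec_dim, static_dim, hidden_size=16,
                 num_layers=2, num_heads=4, dropout=0.1, output_size=1):
        super().__init__()
        # Assumed gate form: per-feature weights in (0, 1) produced by a sigmoid.
        self.enc_gate = nn.Sequential(
            nn.Linear(enc_dim, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, enc_dim), nn.Sigmoid())
        self.dec_gate = nn.Sequential(
            nn.Linear(dec_dim, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, dec_dim), nn.Sigmoid())
        # Static features are projected to a context vector of size hidden_size.
        self.static_proj = nn.Linear(static_dim, hidden_size)
        self.encoder = nn.LSTM(enc_dim, hidden_size, num_layers,
                               dropout=dropout, batch_first=True)
        self.decoder = nn.LSTM(dec_dim, hidden_size, num_layers,
                               dropout=dropout, batch_first=True)
        # hidden_size must be divisible by num_heads for nn.MultiheadAttention.
        self.attention = nn.MultiheadAttention(hidden_size, num_heads,
                                               dropout=dropout, batch_first=True)
        self.head = nn.Linear(hidden_size, output_size)

    def forward(self, enc_x, dec_x, static_x):
        # enc_x: (B, T_enc, enc_dim); dec_x: (B, T_dec, dec_dim); static_x: (B, static_dim)
        enc_x = enc_x * self.enc_gate(enc_x)       # re-weight encoder variables
        dec_x = dec_x * self.dec_gate(dec_x)       # re-weight decoder variables
        _, (h, c) = self.encoder(enc_x)            # h, c: (num_layers, B, hidden)
        context = self.static_proj(static_x)       # (B, hidden)
        h = h + context.unsqueeze(0)               # add static context to every layer's state
        dec_out, _ = self.decoder(dec_x, (h, c))   # (B, T_dec, hidden)
        attn_out, _ = self.attention(dec_out, dec_out, dec_out)  # self-attention
        return self.head(attn_out)                 # (B, T_dec, output_size)


torch.manual_seed(0)
model = MiniTFT(enc_dim=7, dec_dim=4, static_dim=3)
out = model(torch.randn(2, 96, 7), torch.randn(2, 24, 4), torch.randn(2, 3))
print(out.shape)  # torch.Size([2, 24, 1])
```

Seeding the decoder LSTM with the encoder state is what ties the two halves together; attention then lets each forecast step look at every other decoder step rather than only the previous one.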
Usage
Use TFT V2 when you need an attention-based time series forecasting model that can handle heterogeneous inputs including static, encoder, and decoder features. It is suitable for complex forecasting tasks where variable importance and temporal attention patterns matter. Note that this is an experimental implementation on the V2 pipeline.
Code Reference
Source Location
- Repository: Sktime_Pytorch_forecasting
- File: pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py
- Lines: 1-261
Signature
```python
class TFT(BaseModel):
    def __init__(
        self,
        loss: nn.Module,
        logging_metrics: list[nn.Module] | None = None,
        optimizer: Optimizer | str | None = "adam",
        optimizer_params: dict | None = None,
        lr_scheduler: str | None = None,
        lr_scheduler_params: dict | None = None,
        hidden_size: int = 64,
        num_layers: int = 2,
        attention_head_size: int = 4,
        dropout: float = 0.1,
        metadata: dict | None = None,
        output_size: int = 1,
    ):
```
Import
```python
from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| loss | nn.Module | Yes | Loss function for training. |
| logging_metrics | list[nn.Module] \| None | No | Metrics to log during training. Defaults to None. |
| optimizer | Optimizer \| str \| None | No | Optimizer for training. Defaults to "adam". |
| optimizer_params | dict \| None | No | Optimizer parameters. Defaults to None. |
| lr_scheduler | str \| None | No | Learning rate scheduler name. Defaults to None. |
| lr_scheduler_params | dict \| None | No | Learning rate scheduler parameters. Defaults to None. |
| hidden_size | int | No | Size of hidden layers and LSTM hidden state. Defaults to 64. |
| num_layers | int | No | Number of LSTM layers. Defaults to 2. |
| attention_head_size | int | No | Number of attention heads. Defaults to 4. |
| dropout | float | No | Dropout probability. Defaults to 0.1. |
| metadata | dict \| None | No | Dataset metadata with the keys max_encoder_length, max_prediction_length, encoder_cont, encoder_cat, decoder_cont, and decoder_cat, plus the static feature dimensions static_categorical_features and static_continuous_features. Defaults to None. |
| output_size | int | No | Number of output features. Defaults to 1. |
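The metadata keys listed above determine how wide each input layer has to be. The helper below is a hypothetical sketch (`derived_dims` is not part of pytorch-forecasting) that assumes encoder and decoder inputs are the concatenation of their continuous and categorical columns, that a missing static key means zero features, and that `attention_head_size` is passed as `num_heads` to `nn.MultiheadAttention` with `embed_dim=hidden_size`, in which case `hidden_size` must be divisible by it.

```python
def derived_dims(metadata: dict, hidden_size: int = 64,
                 attention_head_size: int = 4) -> dict:
    """Hypothetical helper: input widths a TFT-style model would derive from metadata."""
    required = ("max_encoder_length", "max_prediction_length",
                "encoder_cont", "encoder_cat", "decoder_cont", "decoder_cat")
    missing = [key for key in required if key not in metadata]
    if missing:
        raise KeyError(f"metadata is missing keys: {missing}")
    # Assumption: heads are split from hidden_size, as nn.MultiheadAttention requires.
    if hidden_size % attention_head_size != 0:
        raise ValueError("hidden_size must be divisible by attention_head_size")
    # Assumption: absent static keys mean the model has no static features of that kind.
    static = (metadata.get("static_categorical_features", 0)
              + metadata.get("static_continuous_features", 0))
    return {
        "encoder_input_dim": metadata["encoder_cont"] + metadata["encoder_cat"],
        "decoder_input_dim": metadata["decoder_cont"] + metadata["decoder_cat"],
        "static_input_dim": static,
    }


metadata = {
    "max_encoder_length": 96, "max_prediction_length": 24,
    "encoder_cont": 5, "encoder_cat": 2, "decoder_cont": 3, "decoder_cat": 1,
    "static_categorical_features": 2, "static_continuous_features": 1,
}
print(derived_dims(metadata))
# {'encoder_input_dim': 7, 'decoder_input_dim': 4, 'static_input_dim': 3}
```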
Forward Input Dictionary
| Name | Type | Required | Description |
|---|---|---|---|
| encoder_cont | torch.Tensor | Yes | Continuous encoder features of shape (batch, encoder_length, n_cont), where n_cont equals metadata["encoder_cont"]. |
| encoder_cat | torch.Tensor | No | Categorical encoder features of shape (batch, encoder_length, n_cat), where n_cat equals metadata["encoder_cat"]. |
| decoder_cont | torch.Tensor | No | Continuous decoder features of shape (batch, prediction_length, n_cont), where n_cont equals metadata["decoder_cont"]. |
| decoder_cat | torch.Tensor | No | Categorical decoder features of shape (batch, prediction_length, n_cat), where n_cat equals metadata["decoder_cat"]. |
| static_categorical_features | torch.Tensor | No | Static categorical features. |
| static_continuous_features | torch.Tensor | No | Static continuous features. |
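Before calling the model, it can help to check a batch against these shapes. The sketch below is hypothetical (neither function exists in the library) and works on plain shape tuples, so `tensor.shape` values can be passed directly because `torch.Size` is a tuple subclass. The `(batch, 1, n)` layout for the static features is an assumption, since the table above does not state their shape.

```python
def expected_input_shapes(metadata: dict, batch_size: int) -> dict:
    """Hypothetical helper: expected (batch, time, features) shape per forward-dict key."""
    enc_len = metadata["max_encoder_length"]
    pred_len = metadata["max_prediction_length"]
    return {
        "encoder_cont": (batch_size, enc_len, metadata["encoder_cont"]),
        "encoder_cat": (batch_size, enc_len, metadata["encoder_cat"]),
        "decoder_cont": (batch_size, pred_len, metadata["decoder_cont"]),
        "decoder_cat": (batch_size, pred_len, metadata["decoder_cat"]),
        # Assumption: static features carry a singleton time axis.
        "static_categorical_features":
            (batch_size, 1, metadata.get("static_categorical_features", 0)),
        "static_continuous_features":
            (batch_size, 1, metadata.get("static_continuous_features", 0)),
    }


def check_batch_shapes(batch_shapes: dict, metadata: dict) -> list:
    """Return mismatch messages; an empty list means every given shape agrees."""
    if "encoder_cont" not in batch_shapes:  # the only required key per the table
        return ["encoder_cont is required"]
    batch_size = batch_shapes["encoder_cont"][0]
    expected = expected_input_shapes(metadata, batch_size)
    problems = []
    for key, shape in batch_shapes.items():
        if key in expected and tuple(shape) != expected[key]:
            problems.append(f"{key}: got {tuple(shape)}, expected {expected[key]}")
    return problems


metadata = {
    "max_encoder_length": 96, "max_prediction_length": 24,
    "encoder_cont": 5, "encoder_cat": 2, "decoder_cont": 3, "decoder_cat": 1,
    "static_categorical_features": 2, "static_continuous_features": 1,
}
good = {"encoder_cont": (8, 96, 5), "decoder_cont": (8, 24, 3)}
bad = {"encoder_cont": (8, 96, 5), "decoder_cat": (8, 12, 1)}
print(check_batch_shapes(good, metadata))  # []
print(check_batch_shapes(bad, metadata))
# ['decoder_cat: got (8, 12, 1), expected (8, 24, 1)']
```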
Outputs
| Name | Type | Description |
|---|---|---|
| prediction | torch.Tensor | Forecast tensor, returned under the "prediction" key of the output dictionary, of shape (batch_size, prediction_length, output_size). |
Usage Examples
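```python
import torch
import torch.nn as nn

from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT

metadata = {
    "max_encoder_length": 96,
    "max_prediction_length": 24,
    "encoder_cont": 5,
    "encoder_cat": 2,
    "decoder_cont": 3,
    "decoder_cat": 1,
    "static_categorical_features": 2,
    "static_continuous_features": 1,
}

model = TFT(
    loss=nn.MSELoss(),
    hidden_size=64,
    num_layers=2,
    attention_head_size=4,
    dropout=0.1,
    metadata=metadata,
    output_size=1,
)

# Synthetic batch for illustration; real batches come from the V2 data pipeline.
# Assumptions: categorical features are passed as float tensors, and static
# features carry a singleton time axis of length 1.
batch_size = 8
enc_len = metadata["max_encoder_length"]
pred_len = metadata["max_prediction_length"]
x_batch = {
    "encoder_cont": torch.randn(batch_size, enc_len, metadata["encoder_cont"]),
    "encoder_cat": torch.randn(batch_size, enc_len, metadata["encoder_cat"]),
    "decoder_cont": torch.randn(batch_size, pred_len, metadata["decoder_cont"]),
    "decoder_cat": torch.randn(batch_size, pred_len, metadata["decoder_cat"]),
    "static_categorical_features": torch.randn(batch_size, 1, metadata["static_categorical_features"]),
    "static_continuous_features": torch.randn(batch_size, 1, metadata["static_continuous_features"]),
}

# Forward pass in inference mode (dropout off, no gradient tracking)
model.eval()
with torch.no_grad():
    output = model(x_batch)
predictions = output["prediction"]  # (batch_size, max_prediction_length, output_size)
```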
Related Pages
- Principle:Sktime_Pytorch_forecasting_TFT_V2_Architecture
- Sktime_Pytorch_forecasting_DLinear_V2 - Simpler V2 model without attention
- Sktime_Pytorch_forecasting_TiDE_V2 - Alternative V2 dense encoder-decoder model