Implementation:Sktime Pytorch forecasting TemporalFusionTransformer From Dataset
| Knowledge Sources | |
|---|---|
| Domains | Time_Series, Deep_Learning, Attention_Mechanisms |
| Last Updated | 2026-02-08 07:00 GMT |
Overview
Factory method from the pytorch-forecasting library for instantiating a Temporal Fusion Transformer (TFT) model from a TimeSeriesDataSet.
Description
The TemporalFusionTransformer.from_dataset factory method creates a TFT model instance with its architecture automatically configured from dataset metadata. It extracts variable names, embedding sizes, encoder length, and output dimensions from the training TimeSeriesDataSet. Users provide hyperparameters (hidden_size, attention_head_size, dropout, learning_rate), while architectural parameters (variable lists, embedding dimensions) are inferred. The default loss function is QuantileLoss with 7 quantiles.
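The default QuantileLoss scores each of its 7 quantile outputs with the pinball loss. A minimal pure-Python sketch of that scoring rule follows; the quantile list matches the library default, but the sketch is illustrative only, not the tensor-based library implementation, and the prediction values are hypothetical.

```python
# Illustrative sketch of the pinball (quantile) loss that QuantileLoss is
# built on; not the library implementation, which operates on tensors.
DEFAULT_QUANTILES = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]  # library default

def pinball_loss(y_true: float, y_pred: float, q: float) -> float:
    """Penalize under-prediction by weight q and over-prediction by (1 - q)."""
    error = y_true - y_pred
    return max(q * error, (q - 1) * error)

# One target value scored against one (hypothetical) prediction per quantile:
y_true = 10.0
preds = [6.0, 7.5, 9.0, 10.0, 11.0, 12.5, 14.0]
losses = [pinball_loss(y_true, p, q) for p, q in zip(preds, DEFAULT_QUANTILES)]
```

Because low quantiles are penalized only lightly for under-prediction (and high quantiles for over-prediction), training with this loss pushes the 7 outputs apart into a calibrated forecast interval.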
Usage
Call this method after constructing a training TimeSeriesDataSet and before running Tuner.lr_find() or Trainer.fit(). This is the recommended way to create TFT instances as it ensures data-model consistency.
Code Reference
Source Location
- Repository: pytorch-forecasting
- File: pytorch_forecasting/models/temporal_fusion_transformer/_tft.py
- Lines: L443-474 (from_dataset), L139-442 (__init__)
Signature
class TemporalFusionTransformer(BaseModelWithCovariates):
    @classmethod
    def from_dataset(
        cls,
        dataset: TimeSeriesDataSet,
        allowed_encoder_known_variable_names: list[str] = None,
        **kwargs,
    ):
        """
        Create model from dataset.

        Args:
            dataset: timeseries dataset
            allowed_encoder_known_variable_names: List of known variables
                allowed in encoder, defaults to all
            **kwargs: additional arguments such as hyperparameters
                (see __init__())

        Returns:
            TemporalFusionTransformer
        """

    def __init__(
        self,
        hidden_size: int = 16,
        lstm_layers: int = 1,
        dropout: float = 0.1,
        output_size: int | list[int] = 7,
        loss: MultiHorizonMetric = None,  # defaults to QuantileLoss()
        attention_head_size: int = 4,
        max_encoder_length: int = 10,
        hidden_continuous_size: int = 8,
        learning_rate: float = 1e-3,
        reduce_on_plateau_patience: int = 1000,
        share_single_variable_networks: bool = False,
        causal_attention: bool = True,
        logging_metrics: nn.ModuleList = None,
        **kwargs,
    ):
Import
from pytorch_forecasting import TemporalFusionTransformer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| dataset | TimeSeriesDataSet | Yes | Training dataset to infer architecture from |
| learning_rate | float | No | Learning rate (default: 1e-3) |
| hidden_size | int | No | Main hidden size (default: 16, range: 8-512) |
| attention_head_size | int | No | Number of attention heads (default: 4) |
| dropout | float | No | Dropout rate (default: 0.1) |
| hidden_continuous_size | int | No | Size for continuous variable embeddings (default: 8) |
| output_size | int or list[int] | No | Number of output quantiles (default: 7) |
| loss | MultiHorizonMetric | No | Loss function (default: QuantileLoss()) |
| lstm_layers | int | No | Number of LSTM layers (default: 1) |
Outputs
| Name | Type | Description |
|---|---|---|
| return | TemporalFusionTransformer | Configured TFT model ready for training |
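Among the inferred architectural parameters are the embedding dimensions for categorical variables. A sketch of the commonly used capped fastai rule of thumb follows; the exact formula and cap here are assumptions about how the library's heuristic (pytorch_forecasting.utils.get_embedding_size) behaves, not a verbatim copy.

```python
# Sketch of an embedding-size heuristic of the kind from_dataset relies on
# when inferring categorical embedding dimensions (capped fastai rule).
# The exact formula and cap are assumptions, not the library's verbatim code.
def embedding_size(n_categories: int, max_size: int = 100) -> int:
    """Grow the embedding dimension sublinearly with cardinality, up to a cap."""
    if n_categories > 2:
        return min(round(1.6 * n_categories ** 0.56), max_size)
    return 1  # binary/constant categoricals need only a single dimension

sizes = {n: embedding_size(n) for n in (2, 10, 100, 10_000)}
```

The sublinear growth keeps high-cardinality categoricals (e.g. a store or product ID) from dominating the parameter count.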
Usage Examples
Standard TFT Instantiation
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss
tft = TemporalFusionTransformer.from_dataset(
    training,  # TimeSeriesDataSet
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,
    loss=QuantileLoss(),
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters: {tft.size() / 1e3:.1f}k")
TFT with Custom Learning Rate from Finder
# After lr_find returns optimal_lr
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=optimal_lr,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=QuantileLoss(),
    reduce_on_plateau_patience=4,
)
Related Pages
Implements Principle
Requires Environment
- Environment:Sktime_Pytorch_forecasting_Core_Python_Dependencies
- Environment:Sktime_Pytorch_forecasting_Cpflows_MQF2_Dependencies
- Environment:Sktime_Pytorch_forecasting_Matplotlib_Plotting_Dependencies