

Workflow:Sktime Pytorch forecasting TFT Demand Forecasting

From Leeroopedia


Knowledge Sources
Domains Time_Series, Demand_Forecasting, Deep_Learning
Last Updated 2026-02-08 07:00 GMT

Overview

End-to-end process for multi-horizon demand forecasting using the Temporal Fusion Transformer (TFT) with mixed covariates on tabular time series data.

Description

This workflow covers the complete pipeline for training a Temporal Fusion Transformer model to predict future demand (or any continuous target) across multiple time series. The TFT architecture combines accurate multi-horizon forecasting with interpretability through variable importance and attention weights. The process has five stages:

1. Start from a pandas DataFrame.
2. Wrap it in a TimeSeriesDataSet that records the role of each covariate and the target normalizer.
3. Build the model architecture automatically from the dataset parameters.
4. Select a learning rate with a learning-rate finder.
5. Train the model with early stopping.

The trained model produces quantile predictions that capture forecast uncertainty.

Key capabilities:

  • Handles static, time-varying known, and time-varying unknown covariates
  • Produces multi-horizon probabilistic (quantile) forecasts
  • Built-in interpretability via variable selection and attention mechanisms
  • Automatic architecture inference from dataset metadata
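Concretely, the three covariate types could look like this for the store/SKU demand data described under Usage:

- **Known covariates** (calendar features, planned prices and promotions) must be available for the whole forecast horizon.
- **Unknown covariates**, including the target itself, are only observed in the past.
- **Static covariates** do not change over time for a given series.

The assignment below is hypothetical. Its keys are TimeSeriesDataSet argument names, but the column names are invented, and `check_roles` is an illustrative helper, not part of pytorch-forecasting.

```python
# Hypothetical column roles for a store/SKU demand panel. The keys are
# TimeSeriesDataSet argument names; the column names are made up.
ROLES = {
    "static_categoricals": ["store", "sku"],
    "static_reals": [],
    "time_varying_known_categoricals": ["month", "holiday"],
    "time_varying_known_reals": ["time_idx", "price", "promo"],
    "time_varying_unknown_categoricals": [],
    "time_varying_unknown_reals": ["demand"],
}


def check_roles(roles):
    """Return {column: role}, raising if any column is given two roles."""
    seen = {}
    for role, columns in roles.items():
        for column in columns:
            if column in seen:
                raise ValueError(f"{column!r} is both {seen[column]} and {role}")
            seen[column] = role
    return seen


print(check_roles(ROLES)["price"])  # → time_varying_known_reals
```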

Usage

Execute this workflow when you have a panel (multi-entity) time series dataset in a pandas DataFrame with multiple covariates (e.g., product demand by store/SKU with price, promotions, holidays) and need to produce multi-step-ahead probabilistic forecasts. This is the recommended starting workflow for most real-world forecasting problems where covariate information is available.

Execution Steps

Step 1: Data Loading and Feature Engineering

Load the raw tabular data into a pandas DataFrame and create the features the forecasting pipeline needs. This involves three tasks:

- Construct an integer time index that increases by one per time step.
- Engineer derived features, such as log transforms, rolling aggregates and calendar features.
- Encode categorical variables.

The DataFrame must contain columns for the target variable, the time series group identifiers, the time index, and any covariates.

Key considerations:

  • The time index must be an integer column (not a datetime) that increases by 1 between consecutive time steps. Gaps within a series require allow_missing_timesteps=True.
  • Group identifiers should uniquely identify each time series when combined
  • Categorical variables should be encoded as pandas category dtype or string
  • Derived features like group-level averages help the model learn cross-series patterns
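This step's feature engineering can be sketched in plain Python. The sketch assumes monthly data and a hypothetical row schema (date, agency, sku, volume). With pandas, each operation is usually a one-liner, for example `groupby(...).transform("mean")` for the group average. The sketch shows four features:

- An integer time index built with the `year * 12 + month` trick, which increases by one per month.
- A log transform of the target.
- A calendar categorical stored as a string.
- A per-SKU average computed across all series at each time step.

```python
import math
from collections import defaultdict
from datetime import date


def month_time_idx(dates):
    """Map monthly dates to consecutive integers starting at 0."""
    raw = [d.year * 12 + d.month for d in dates]
    start = min(raw)
    return [r - start for r in raw]


def engineer(rows):
    """rows: dicts with hypothetical keys 'date', 'agency', 'sku', 'volume'."""
    time_idx = month_time_idx([r["date"] for r in rows])
    sums, counts = defaultdict(float), defaultdict(int)
    for t, r in zip(time_idx, rows):
        sums[(t, r["sku"])] += r["volume"]
        counts[(t, r["sku"])] += 1
    out = []
    for t, r in zip(time_idx, rows):
        key = (t, r["sku"])
        out.append({
            **r,
            "time_idx": t,
            "month": str(r["date"].month),                 # calendar categorical as string
            "log_volume": math.log1p(r["volume"]),         # compresses skewed demand
            "avg_volume_by_sku": sums[key] / counts[key],  # cross-series signal
        })
    return out


rows = [
    {"date": date(2020, 12, 1), "agency": "A", "sku": "s1", "volume": 10.0},
    {"date": date(2020, 12, 1), "agency": "B", "sku": "s1", "volume": 30.0},
    {"date": date(2021, 1, 1), "agency": "A", "sku": "s1", "volume": 20.0},
]
print([r["time_idx"] for r in engineer(rows)])  # → [0, 0, 1]
```

December 2020 and January 2021 map to consecutive indices across the year boundary, which the raw month number alone would not do.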

Step 2: TimeSeriesDataSet Construction

Create a TimeSeriesDataSet instance that wraps the pandas DataFrame with metadata describing the forecasting problem. This involves specifying which columns are targets, group identifiers, static categoricals/reals, time-varying known categoricals/reals, and time-varying unknown reals. Configure the encoder and prediction lengths, and set up target normalization (e.g., GroupNormalizer with softplus transformation for non-negative targets).

Key considerations:

  • Build the training set with a temporal cutoff: keep rows with time_idx <= cutoff. The cutoff is typically the last time index minus the prediction length.
  • Choose encoder length (history window) and prediction length (forecast horizon)
  • Select appropriate normalizer: GroupNormalizer for per-group normalization, EncoderNormalizer for per-sample
  • Use variable_groups to bundle related categorical columns (e.g., multiple holiday flags)
  • Enable add_relative_time_idx, add_target_scales, and add_encoder_length for the TFT
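The temporal split and the softplus-based group normalization can be sketched in plain Python. This is a simplified, hypothetical version of the idea behind `GroupNormalizer(transformation="softplus")`:

1. Map the target through the inverse softplus.
2. Standardize it per group.
3. Send predictions back through softplus so they stay positive.

The sketch assumes strictly positive targets. The real normalizer supports more methods and options.

```python
import math
from statistics import mean, pstdev


def split_train(rows, max_prediction_length):
    """Keep rows up to the cutoff; the final horizon is left for validation."""
    cutoff = max(r["time_idx"] for r in rows) - max_prediction_length
    return cutoff, [r for r in rows if r["time_idx"] <= cutoff]


def softplus(x):
    # Numerically stable log(1 + exp(x)); the result is always positive.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def inverse_softplus(y):
    # log(exp(y) - 1) for y > 0, rewritten to avoid overflow for large y.
    return y + math.log(-math.expm1(-y))


def fit_group_scales(targets_by_group):
    """Per-group (center, scale) of the inverse-softplus-transformed target."""
    scales = {}
    for group, ys in targets_by_group.items():
        z = [inverse_softplus(y) for y in ys]
        scales[group] = (mean(z), pstdev(z) or 1.0)  # guard constant series
    return scales


def normalize(y, group, scales):
    center, scale = scales[group]
    return (inverse_softplus(y) - center) / scale


def denormalize(value, group, scales):
    center, scale = scales[group]
    return softplus(value * scale + center)


cutoff, train = split_train([{"time_idx": t} for t in range(10)], max_prediction_length=3)
print(cutoff, len(train))  # → 6 7
```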

Step 3: Validation Dataset and DataLoader Creation

Create a validation TimeSeriesDataSet from the training dataset with the from_dataset class method, which reuses all encoding and normalization parameters fitted on the training data. Then convert both datasets to PyTorch DataLoaders with their to_dataloader method: use train=True for the training set and train=False for validation, and choose batch_size and num_workers to suit the hardware.

Key considerations:

  • Use TimeSeriesDataSet.from_dataset() to ensure consistent preprocessing between train and validation
  • Set stop_randomization=True for the validation set so that encoder and decoder lengths are not randomly sampled, which keeps evaluation deterministic
  • The predict=True flag selects only the last available prediction window per time series
  • Batch size affects memory usage and training dynamics; 64-128 is typical
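Index arithmetic shows how one series turns into training samples and what predict=True changes. This hypothetical illustration assumes fixed encoder and prediction lengths (no length randomization) and contiguous time indices. The library's internal indexing differs in detail.

```python
import math


def windows(series_len, encoder_length, prediction_length):
    """(encoder_start, decoder_start, decoder_end) per sample; decoder_end is exclusive."""
    total = encoder_length + prediction_length
    return [(s, s + encoder_length, s + total) for s in range(series_len - total + 1)]


def predict_windows(series_len, encoder_length, prediction_length):
    # Analogue of predict=True: keep only the window whose forecast ends at the last step.
    return windows(series_len, encoder_length, prediction_length)[-1:]


def num_batches(n_samples, batch_size, drop_last=False):
    return n_samples // batch_size if drop_last else math.ceil(n_samples / batch_size)


# 100 series of length 60, 24-step history, 6-step horizon
per_series = len(windows(60, 24, 6))
print(per_series, num_batches(100 * per_series, 128))  # → 31 25
```

With predict=True the same 100 series yield only 100 validation samples, one per series.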

Step 4: Trainer Configuration

Configure the PyTorch Lightning Trainer with callbacks for early stopping (monitoring validation loss), learning rate monitoring, and TensorBoard logging. Set the accelerator, gradient clipping value, maximum epochs, and optional batch limits for faster iteration during development.

Key considerations:

  • EarlyStopping on val_loss prevents overfitting; patience of 1-10 epochs is typical
  • Gradient clipping (0.01-1.0) stabilizes training for transformer architectures
  • limit_train_batches can be used to reduce epoch length during experimentation
  • TensorBoardLogger enables visual monitoring of training progress
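Two of the settings above can be illustrated in plain Python:

- Lightning's gradient_clip_val clips the global gradient norm by default (gradient_clip_algorithm="norm").
- limit_train_batches accepts either an integer number of batches or a float fraction of the available batches.

The functions below are simplified, hypothetical analogues of those behaviors, not Lightning code. PyTorch's own norm clipping also adds a small epsilon.

```python
import math


def clip_grad_norm(grads, max_norm):
    """Rescale gradients so their global L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for g in grads))
    if total <= max_norm:
        return list(grads), total
    scale = max_norm / total  # same direction, smaller step
    return [g * scale for g in grads], total


def batches_per_epoch(num_batches, limit_train_batches=1.0):
    """int: absolute number of batches; float: fraction of the available batches."""
    if isinstance(limit_train_batches, int):
        return min(num_batches, limit_train_batches)
    return int(num_batches * limit_train_batches)


clipped, norm = clip_grad_norm([3.0, 4.0], max_norm=0.1)
print(norm, batches_per_epoch(400, 50), batches_per_epoch(400, 0.25))  # → 5.0 50 100
```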

Step 5: Model Instantiation

Create the TFT model using the from_dataset class method, which automatically infers the architecture from the dataset metadata (number of embeddings, continuous variables, etc.). Specify hyperparameters like hidden_size, attention_head_size, dropout, and the loss function (QuantileLoss for probabilistic forecasts).

Key considerations:

  • from_dataset() infers embedding dimensions, variable counts, and output transformer from the dataset
  • QuantileLoss produces forecasts at the specified quantiles (default: 0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98)
  • hidden_size controls model capacity (16-256 typical range)
  • output_size must equal the number of quantiles in QuantileLoss (7 by default). from_dataset() deduces it from the loss when it is not passed explicitly.
  • log_interval sets how many batches pass between example predictions logged to TensorBoard
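QuantileLoss is built on the pinball (quantile) loss. The sketch below computes it for a single target value with the seven default quantiles. `pinball` and `quantile_loss` are hypothetical helpers; pytorch-forecasting's QuantileLoss operates on tensors and may scale and reduce the result differently. The length check shows why output_size must match the number of quantiles: the network emits one value per quantile.

```python
DEFAULT_QUANTILES = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]


def pinball(y, y_hat, q):
    """Under-forecasts cost q per unit of error, over-forecasts cost 1 - q."""
    error = y - y_hat
    return max(q * error, (q - 1) * error)


def quantile_loss(y, preds, quantiles=DEFAULT_QUANTILES):
    if len(preds) != len(quantiles):
        # one network output per quantile, hence output_size == len(quantiles)
        raise ValueError(f"expected {len(quantiles)} predictions, got {len(preds)}")
    return sum(pinball(y, p, q) for p, q in zip(preds, quantiles)) / len(quantiles)


print(pinball(10.0, 8.0, 0.9), pinball(10.0, 8.0, 0.1))  # → 1.8 0.2
```

Under-forecasting by 2 units costs nine times more at the 0.9 quantile than at the 0.1 quantile. This asymmetry pushes each output toward its own quantile.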

Step 6: Learning Rate Finding

Use the PyTorch Lightning Tuner to run the learning rate finder. It trains the model for a short period with exponentially increasing learning rates and suggests the rate at which the loss decreases most steeply.

Key considerations:

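  • lr_find returns an LR finder result rather than a single number. Read the value with .suggestion() and always verify it visually with .plot(suggest=True).
  • early_stop_threshold stops the search once the loss exceeds that multiple of the best loss seen so far. A high value (e.g., 1000) lets the sweep cover a wide range, and None disables the check.
  • The suggested learning rate sits on the steepest part of the descending loss curve, typically somewhat before the minimum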
  • If the suggestion seems off, manually select from the plot where loss drops fastest
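The finder's mechanics can be sketched without Lightning in three steps:

1. Generate exponentially spaced learning rates.
2. Cut the sweep once the loss exceeds early_stop_threshold times the best loss.
3. Suggest the rate with the steepest (most negative) loss slope.

This is a simplified, hypothetical illustration that uses central differences, similar to numpy.gradient. Lightning's own suggestion logic may also skip points at the start and end of the curve.

```python
import math


def lr_schedule(min_lr, max_lr, num_steps):
    """Exponentially increasing learning rates from min_lr to max_lr."""
    return [min_lr * (max_lr / min_lr) ** (i / (num_steps - 1)) for i in range(num_steps)]


def truncate_on_divergence(losses, early_stop_threshold):
    """Drop the sweep from the first loss above early_stop_threshold * best loss so far."""
    best = math.inf
    for i, loss in enumerate(losses):
        if loss > early_stop_threshold * best:
            return losses[:i]
        best = min(best, loss)
    return losses


def suggest(lrs, losses):
    """Learning rate at the most negative slope (central differences, one-sided at the ends)."""
    n = len(losses)
    slopes = []
    for i in range(n):
        if i == 0:
            slopes.append(losses[1] - losses[0])
        elif i == n - 1:
            slopes.append(losses[-1] - losses[-2])
        else:
            slopes.append((losses[i + 1] - losses[i - 1]) / 2)
    return lrs[min(range(n), key=slopes.__getitem__)]


lrs = lr_schedule(1e-6, 1.0, 7)
kept = truncate_on_divergence([2.0, 1.9, 1.5, 0.8, 0.6, 0.9, 5.0], early_stop_threshold=4.0)
print(len(kept), suggest(lrs[: len(kept)], kept))  # 6 points kept; suggestion is about 1e-4
```

Because the rates are evenly spaced in log scale, the slope per step is proportional to the slope with respect to log(lr). The suggestion therefore falls where the curve drops fastest, before its minimum.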

Step 7: Model Training

Fit the model on the training data with early stopping monitoring validation loss. The Lightning Trainer handles the training loop, gradient updates, logging, and checkpoint management automatically.

Key considerations:

  • Training automatically logs to TensorBoard; monitor train_loss and val_loss convergence
  • Early stopping terminates training when val_loss stops improving
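  • Lightning saves checkpoints automatically, but its default ModelCheckpoint keeps the most recent epoch. Add a ModelCheckpoint(monitor="val_loss") callback if the saved model should be the one with the best validation loss.
  • After training, load the best checkpoint for prediction from trainer.checkpoint_callback.best_model_path, e.g., with TemporalFusionTransformer.load_from_checkpoint()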
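The Trainer automates three things in this step: the training loop, early stopping on validation loss, and keeping the best checkpoint. All three can be seen in miniature with a one-parameter model fitted by gradient descent. This toy sketch assumes full-batch updates and uses a hypothetical `fit` function. "Checkpoint" here just means remembering the best weight in memory, and the early-stopping rule mirrors a patience/min_delta criterion in mode "min".

```python
import math


def fit(train, val, lr=0.05, max_epochs=100, patience=3, min_delta=1e-4):
    w = 0.0  # single weight of the toy model y_hat = w * x
    best_loss, best_w, best_epoch, wait = math.inf, w, -1, 0
    history = []
    for epoch in range(max_epochs):
        # one full-batch gradient-descent step on mean squared error
        grad = sum(2 * (w * x - y) * x for x, y in train) / len(train)
        w -= lr * grad
        val_loss = sum((w * x - y) ** 2 for x, y in val) / len(val)
        history.append(val_loss)
        if val_loss < best_loss - min_delta:
            # improvement: remember these weights, like saving the best checkpoint
            best_loss, best_w, best_epoch, wait = val_loss, w, epoch, 0
        else:
            wait += 1
            if wait >= patience:  # early stopping
                break
    return best_w, best_epoch, history


train = [(float(x), 2.0 * x) for x in range(1, 6)]
val = [(6.0, 12.0), (7.0, 14.0)]
best_w, best_epoch, history = fit(train, val)
print(best_epoch, len(history))  # → 3 7
```

Training stops three epochs after improvements shrink below min_delta. The returned weight comes from the best epoch rather than the last one, which is why loading the best checkpoint matters.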
