Workflow:Sktime Pytorch forecasting TFT Hyperparameter Optimization

From Leeroopedia
Revision as of 11:05, 16 February 2026 by Admin (talk | contribs) (Auto-imported from workflows/Sktime_Pytorch_forecasting_TFT_Hyperparameter_Optimization.md)

Knowledge Sources
Domains: Time_Series, Hyperparameter_Tuning, AutoML
Last Updated: 2026-02-08 07:00 GMT

Overview

An end-to-end process for automated hyperparameter optimization of the Temporal Fusion Transformer (TFT) using Optuna, with integrated learning rate finding and successive halving pruning.

Description

This workflow extends the standard TFT training pipeline with automated hyperparameter search. It uses Optuna to explore the hyperparameter space (hidden size, hidden continuous size, attention heads, dropout, gradient clipping, learning rate) and can run the PyTorch Lightning learning rate finder within each trial to set the learning rate automatically. Successive halving pruning terminates unpromising trials early, which keeps the search efficient. The process starts from pre-built DataLoaders (assuming the data preparation from the TFT Demand Forecasting workflow is complete), runs multiple training trials with different hyperparameter configurations, and returns the Optuna study containing the results of every trial, including the best one.

Key capabilities:

  • Automated search over TFT architecture and optimization hyperparameters
  • Integrated learning rate finding per trial using smoothed loss curve analysis
  • Successive halving pruning for efficient trial elimination
  • TensorBoard logging per trial for detailed comparison
  • Model checkpointing per trial for best model recovery
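
The per-trial learning rate finding listed above can be pictured with a minimal, standard-library sketch. It assumes a learning rate sweep has already produced paired learning rates and losses, smooths the losses with a simple moving average (a stand-in for the LOWESS smoothing this workflow describes, not the same method), and picks the learning rate where the smoothed loss falls fastest. That selection rule is a common heuristic and an assumption here, not a confirmed detail of the library. The names moving_average and suggest_learning_rate are hypothetical.

```python
import math


def moving_average(values, window=3):
    """Centered moving average; edges average over the neighbours that exist."""
    half = window // 2
    smoothed = []
    for i in range(len(values)):
        lo, hi = max(0, i - half), min(len(values), i + half + 1)
        smoothed.append(sum(values[lo:hi]) / (hi - lo))
    return smoothed


def suggest_learning_rate(lrs, losses, window=3):
    """Return the learning rate at the steepest drop of the smoothed loss."""
    # Discard diverged steps (inf or nan losses) before smoothing.
    pairs = [(lr, loss) for lr, loss in zip(lrs, losses) if math.isfinite(loss)]
    if len(pairs) < 3:
        raise ValueError("need at least three finite loss values")
    finite_lrs = [lr for lr, _ in pairs]
    smoothed = moving_average([loss for _, loss in pairs], window)
    # Change in smoothed loss from one sweep step to the next.
    slopes = [smoothed[i + 1] - smoothed[i] for i in range(len(smoothed) - 1)]
    steepest = min(range(len(slopes)), key=slopes.__getitem__)
    return finite_lrs[steepest]


# Synthetic sweep for illustration only: loss falls, bottoms out, then diverges.
sweep_lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0]
sweep_losses = [1.00, 0.98, 0.80, 0.40, 0.60, float("inf")]
suggested = suggest_learning_rate(sweep_lrs, sweep_losses)
```

Choosing the point of steepest descent rather than the minimum of the curve tends to give a more conservative learning rate, since the minimum often sits just before the loss diverges.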

Usage

Execute this workflow after completing the data preparation steps of the TFT Demand Forecasting workflow, when you want to find the best TFT hyperparameters for your specific dataset. This is recommended before production deployment when default hyperparameters yield unsatisfactory results, or when exploring how the model responds to different architectural choices.

Execution Steps

Step 1: Data Preparation and DataLoader Construction

Prepare the TimeSeriesDataSet and DataLoaders following Steps 1-3 of the TFT Demand Forecasting workflow. The optimize_hyperparameters function expects pre-built PyTorch DataLoaders backed by TimeSeriesDataSet instances.

Key considerations:

  • DataLoaders must be backed by TimeSeriesDataSet (enforced by assertion)
  • Use the same data preparation as your eventual production pipeline
  • Consider using limit_train_batches in trainer_kwargs for faster trial evaluation during search
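
The considerations above can be sketched as two small helpers. This is an illustration under stated assumptions, not part of the workflow itself: build_trainer_kwargs and check_dataloader are hypothetical names, and the value 30 is an arbitrary example. limit_train_batches is the PyTorch Lightning Trainer option (an integer caps the number of batches per epoch, a float is a fraction of the batches). The type check mirrors the assertion described above and imports pytorch_forecasting lazily, so it only needs the library when called.

```python
def build_trainer_kwargs(limit_train_batches=30, **extra):
    """Collect Lightning Trainer options to pass as trainer_kwargs."""
    if isinstance(limit_train_batches, float):
        if not 0.0 < limit_train_batches <= 1.0:
            raise ValueError("a float limit_train_batches must be a fraction in (0, 1]")
    elif limit_train_batches < 1:
        raise ValueError("an integer limit_train_batches must be at least 1")
    kwargs = {"limit_train_batches": limit_train_batches}
    kwargs.update(extra)  # any other Trainer options the caller wants forwarded
    return kwargs


def check_dataloader(dataloader):
    """Fail early if the DataLoader is not backed by a TimeSeriesDataSet."""
    # Lazy import: requires pytorch_forecasting only when the check is run.
    from pytorch_forecasting import TimeSeriesDataSet

    if not isinstance(dataloader.dataset, TimeSeriesDataSet):
        raise TypeError("DataLoader must be built from a TimeSeriesDataSet")
    return dataloader


# Cap each epoch at 30 training batches to speed up every trial.
trainer_kwargs = build_trainer_kwargs(limit_train_batches=30)
```

Capping batches makes each trial noisier, so a configuration that wins under a tight cap is worth re-checking with full epochs before it is adopted.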

Step 2: Search Space Configuration

Define the hyperparameter search ranges for the optimization. The built-in optimize_hyperparameters function provides sensible defaults but allows customization of ranges for hidden_size, hidden_continuous_size, attention_head_size, dropout, gradient_clip_val, and learning_rate.

Key considerations:

  • hidden_size_range: (16, 265) controls model capacity; larger values increase memory and training time
  • attention_head_size_range: (1, 4) affects the multi-head attention mechanism
  • dropout_range: (0.1, 0.3) controls regularization strength
  • gradient_clip_val_range: (0.01, 1.0) is sampled in log-uniform space
  • learning_rate_range: (1e-5, 1.0) is used when learning rate finder is disabled
  • Set n_trials and timeout to control the total search budget
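
As a minimal sketch, the ranges above can be collected into one dictionary of keyword arguments and sanity-checked before a search starts. The range values are the ones quoted in the list; the n_trials and timeout values are illustrative assumptions, and validate_ranges and sample_log_uniform are hypothetical helpers. The sampler only illustrates what "log-uniform" means and is not how Optuna is invoked.

```python
import math
import random

# Ranges quoted in the list above; keys left out fall back to the function's defaults.
search_space = {
    "hidden_size_range": (16, 265),
    "attention_head_size_range": (1, 4),
    "dropout_range": (0.1, 0.3),
    "gradient_clip_val_range": (0.01, 1.0),
    "learning_rate_range": (1e-5, 1.0),
    "n_trials": 50,         # assumption: choose to fit your compute budget
    "timeout": 2 * 3600.0,  # assumption: two hours, expressed in seconds
}


def validate_ranges(space):
    """Check that every *_range entry is an increasing (low, high) pair."""
    for name, value in space.items():
        if name.endswith("_range"):
            low, high = value
            if not low < high:
                raise ValueError(f"{name} must satisfy low < high, got {value}")
    return space


def sample_log_uniform(low, high, rng):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


validate_ranges(search_space)
rng = random.Random(0)  # seeded so the illustration is repeatable
clip_samples = [
    sample_log_uniform(*search_space["gradient_clip_val_range"], rng)
    for _ in range(5)
]
```

Sampling in log space spreads trials evenly across orders of magnitude, so values between 0.01 and 0.1 are drawn about as often as values between 0.1 and 1.0.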

Step 3: Optimization Execution

Call the optimize_hyperparameters function, which creates an Optuna study and runs the objective function for each trial. Each trial creates a TFT model with sampled hyperparameters, optionally runs the learning rate finder to set the optimal learning rate, trains the model, and reports the validation loss.

What happens per trial:

  • Hyperparameters are sampled from the configured ranges
  • A TFT model is instantiated via from_dataset with the sampled parameters
  • If use_learning_rate_finder=True, a short LR sweep determines the optimal learning rate using LOWESS smoothing on the loss curve
  • The model is trained for up to max_epochs with early stopping
  • Validation loss is reported to Optuna for trial comparison
  • Successive halving pruning may terminate the trial early if it underperforms
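
The pruning idea in the last item can be illustrated with a toy, synchronous version of successive halving written in plain Python. This is a simplification for intuition only: Optuna's SuccessiveHalvingPruner uses an asynchronous variant and decides per trial as intermediate values are reported, whereas this sketch evaluates all surviving configurations in rounds. The function successive_halving, the loss model, and all numbers are hypothetical.

```python
def successive_halving(loss_at, n_configs, min_resource=1, reduction_factor=2, max_resource=8):
    """Keep the best 1/reduction_factor of configurations each round.

    loss_at(config, resource) returns the validation loss of configuration
    `config` after `resource` epochs of training.
    """
    survivors = list(range(n_configs))
    resource = min_resource
    history = []
    while True:
        # Rank the surviving configurations at the current training budget.
        ranked = sorted(survivors, key=lambda c: loss_at(c, resource))
        history.append((resource, list(ranked)))
        if len(ranked) == 1 or resource >= max_resource:
            return ranked[0], history
        # Drop the worst performers and give the rest a larger budget.
        survivors = ranked[: max(1, len(ranked) // reduction_factor)]
        resource = min(max_resource, resource * reduction_factor)


# Toy loss curves: each configuration approaches its own floor as epochs grow.
floors = [0.5, 0.2, 0.9, 0.3, 0.7, 0.4, 0.8, 0.6]


def toy_loss(config, epochs):
    return floors[config] + 1.0 / epochs


best_config, rounds = successive_halving(toy_loss, n_configs=len(floors))
```

With eight configurations and a reduction factor of 2, most of the training budget goes to the few configurations that survive the early rounds, which is where the efficiency gain comes from.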

Key considerations:

  • use_learning_rate_finder=True (default) adds overhead per trial but generally improves results
  • The PyTorchLightningPruningCallback integrates Optuna pruning with Lightning training
  • Each trial saves model checkpoints and TensorBoard logs to separate directories
  • Set verbose=0 to suppress per-trial output for large searches
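
A minimal sketch of the call itself follows, assuming pytorch_forecasting (with Optuna and its other tuning dependencies) is installed and that the two DataLoaders from Step 1 exist. The wrapper names build_search_settings and run_tft_search, the "optuna_tft" directory, and the specific values are illustrative assumptions. The import path and keyword names are the ones recalled from the pytorch-forecasting tuning module and may differ between versions.

```python
def build_search_settings(model_path="optuna_tft", **overrides):
    """Assemble keyword arguments for optimize_hyperparameters."""
    settings = {
        "model_path": model_path,          # checkpoints are written under this directory
        "n_trials": 20,                    # assumption: small budget for a first pass
        "max_epochs": 20,
        "use_learning_rate_finder": True,  # run an LR sweep inside each trial
        "trainer_kwargs": {"limit_train_batches": 30},
        "verbose": 0,                      # keep per-trial output quiet
    }
    settings.update(overrides)
    return settings


def run_tft_search(train_dataloader, val_dataloader, **overrides):
    """Run the search and return the resulting Optuna study."""
    # Lazy import so this sketch can be loaded without the dependency installed.
    from pytorch_forecasting.models.temporal_fusion_transformer.tuning import (
        optimize_hyperparameters,
    )

    settings = build_search_settings(**overrides)
    return optimize_hyperparameters(train_dataloader, val_dataloader, **settings)


# Example: disable the LR finder so learning_rate_range is sampled directly.
example_settings = build_search_settings(
    use_learning_rate_finder=False,
    learning_rate_range=(1e-5, 1.0),
)
```

A typical call would be study = run_tft_search(train_dataloader, val_dataloader), where both arguments are the DataLoaders built in Step 1.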

Step 4: Results Analysis and Model Selection

After optimization completes, analyze the Optuna study to identify the best hyperparameters. The study object contains trial results, parameter importance, and optimization history. Save the study for later analysis and use the best parameters to train a final production model.

Key considerations:

  • study.best_trial contains the best hyperparameters and validation loss
  • study.best_params returns a dictionary of optimal hyperparameter values
  • Save the study with pickle for reproducibility and later analysis
  • Optuna provides visualization tools (plot_optimization_history, plot_param_importances)
  • Train a final model with the best parameters on the full dataset (no validation holdout) for production
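
The save-and-inspect steps above can be sketched with the standard library. The helpers save_study, load_study, and summarize_study are hypothetical names, and the demonstration uses a stand-in object with made-up values rather than a real Optuna study, so that it runs without any search having taken place. With a real study, the same helpers read study.best_trial.value and study.best_params.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace


def save_study(study, path):
    """Pickle the study so it can be analysed later."""
    with open(path, "wb") as fout:
        pickle.dump(study, fout)


def load_study(path):
    """Load a previously pickled study."""
    with open(path, "rb") as fin:
        return pickle.load(fin)


def summarize_study(study):
    """Extract the best validation loss and the parameters that produced it."""
    return {
        "best_value": study.best_trial.value,
        "best_params": dict(study.best_params),
    }


# Stand-in for demonstration only; pass the study returned by the search instead.
demo_study = SimpleNamespace(
    best_trial=SimpleNamespace(value=0.42),
    best_params={"hidden_size": 32, "dropout": 0.2},
)
with tempfile.TemporaryDirectory() as tmp_dir:
    study_path = os.path.join(tmp_dir, "study.pkl")
    save_study(demo_study, study_path)
    restored = load_study(study_path)
summary = summarize_study(restored)

# With a real study and plotly installed, the plots mentioned above come from
# optuna.visualization.plot_optimization_history(study) and
# optuna.visualization.plot_param_importances(study).
```

Only unpickle study files you created yourself, since loading a pickle executes code embedded in the file.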
