Implementation:Sktime Pytorch forecasting Get Stallion Data
| Knowledge Sources | |
|---|---|
| Domains | Time_Series, Data_Engineering |
| Last Updated | 2026-02-08 07:00 GMT |
Overview
Concrete tool for loading the Stallion demand dataset provided by the pytorch-forecasting library.
Description
The get_stallion_data function loads a multivariate demand forecasting dataset of approximately 20,000 rows spanning 350 time series. Each series is identified by its agency and sku columns. The target variable is volume (demand), date gives the month of each observation, and the remaining columns are covariates suitable for training models such as the Temporal Fusion Transformer. The data is stored as a Parquet file that is downloaded from GitHub on the first call and read from a local cache on later calls.
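The download-once, read-from-cache behaviour described above follows a common pattern. The sketch below is a minimal, generic illustration and not the library's code: the `load_cached` helper, the `fetch` callable, and the file name are hypothetical. The fake downloader lets the example run offline. The real function reads the cached Parquet file into a DataFrame rather than returning raw bytes.

```python
import tempfile
from pathlib import Path
from typing import Callable, List


def load_cached(path: Path, fetch: Callable[[Path], None]) -> bytes:
    """Return the contents of `path`, calling `fetch` only if the file is missing.

    `fetch` is a hypothetical downloader that must write the file to `path`.
    """
    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)
        fetch(path)  # first call: populate the local cache
    return path.read_bytes()  # every call: read from the local copy


# Usage with a fake downloader that records how often it is invoked.
calls: List[Path] = []


def fake_fetch(target: Path) -> None:
    calls.append(target)
    target.write_bytes(b"parquet-bytes")  # stands in for a real HTTP download


with tempfile.TemporaryDirectory() as tmp:
    cache_file = Path(tmp) / "cache" / "example.parquet"  # hypothetical file name
    first = load_cached(cache_file, fake_fetch)
    second = load_cached(cache_file, fake_fetch)  # served from the cache
```

The second call never invokes `fake_fetch`, which is the property that makes repeated calls to a loader like get_stallion_data cheap.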
Usage
Import this function when you need a realistic multivariate demand dataset with mixed covariates to demonstrate or test forecasting models. It is the canonical dataset used in the TFT Demand Forecasting and TFT Hyperparameter Optimization workflows. After loading, users typically apply pandas feature engineering (log transforms, an integer time index, and group-level average volumes) before constructing a TimeSeriesDataSet.
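The average-volume features in the usage example below are cross-sectional: there is one mean per time step and SKU (or agency). If you want a rolling average over time within each series, the pandas sketch below shows the difference. The toy frame, the column name `rolling_volume_2`, and the 2-step window are illustrative assumptions and are not part of pytorch-forecasting.

```python
import pandas as pd

# Hypothetical toy panel: two series, (A, s1) and (B, s1), over three time steps.
df = pd.DataFrame({
    "agency": ["A", "A", "A", "B", "B", "B"],
    "sku": ["s1", "s1", "s1", "s1", "s1", "s1"],
    "time_idx": [0, 1, 2, 0, 1, 2],
    "volume": [10.0, 20.0, 30.0, 1.0, 2.0, 3.0],
})

# Rolling windows must follow time order within each series.
df = df.sort_values(["agency", "sku", "time_idx"])

# Rolling mean over the current and previous step, computed separately per series.
df["rolling_volume_2"] = (
    df.groupby(["agency", "sku"])["volume"]
    .transform(lambda s: s.rolling(window=2, min_periods=1).mean())
)

# Cross-sectional mean per (time_idx, sku): averages across agencies at each step.
df["avg_volume_by_sku"] = df.groupby(["time_idx", "sku"])["volume"].transform("mean")
```

The rolling column stays within one series (10, 15, 25 for agency A). The cross-sectional column mixes agencies at each time step (5.5 at time 0).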
Code Reference
Source Location
- Repository: pytorch-forecasting
- File: pytorch_forecasting/data/examples.py
- Lines: L36-50
Signature
def get_stallion_data() -> pd.DataFrame:
    """
    Demand data with covariates.
    ~20k samples of 350 timeseries. Important columns:
    - agency and sku identify timeseries
    - volume is the demand
    - date is the month of the demand
    Returns:
        pd.DataFrame: data
    """
Import
from pytorch_forecasting.data.examples import get_stallion_data
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| (none) | — | — | The function takes no arguments; it downloads the data on the first call and reads the cached file afterwards |
Outputs
| Name | Type | Description |
|---|---|---|
| return | pd.DataFrame | DataFrame with columns: agency, sku, volume, date, plus covariates (~20k rows, 350 series) |
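Because the output contract names only four columns, a caller may want to check them and count the series before feature engineering. The helper `check_demand_frame` below is hypothetical and checks only the columns listed above. It counts distinct (agency, sku) pairs that actually occur. On a sparse panel, that count is smaller than the number of agencies multiplied by the number of SKUs.

```python
import pandas as pd

REQUIRED_COLUMNS = {"agency", "sku", "volume", "date"}


def check_demand_frame(df: pd.DataFrame) -> int:
    """Validate the documented columns and return the number of distinct series.

    Hypothetical helper; it does not inspect the covariate columns.
    """
    missing = REQUIRED_COLUMNS - set(df.columns)
    if missing:
        raise ValueError(f"missing columns: {sorted(missing)}")
    # A series is one (agency, sku) pair that actually occurs in the data.
    return len(df[["agency", "sku"]].drop_duplicates())


# Toy frame: 2 agencies x 2 SKUs, but only 3 of the 4 combinations occur.
toy = pd.DataFrame({
    "agency": ["A", "A", "A", "B"],
    "sku": ["s1", "s1", "s2", "s1"],
    "volume": [1.0, 2.0, 3.0, 4.0],
    "date": pd.to_datetime(["2013-01-01", "2013-02-01", "2013-01-01", "2013-01-01"]),
})
n_series = check_demand_frame(toy)
naive_count = toy["agency"].nunique() * toy["sku"].nunique()  # overcounts
```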
Usage Examples
Basic Data Loading
from pytorch_forecasting.data.examples import get_stallion_data
import numpy as np
import pandas as pd
# 1. Load the Stallion demand dataset (downloaded on first call, cached afterwards)
data = get_stallion_data()
# 2. Add an integer month index starting at 0 (TimeSeriesDataSet needs an integer time index)
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min()
# 3. Add log-transformed volume; non-positive volumes are replaced by 1 so they map to log(1) = 0
data["log_volume"] = np.log(data["volume"].where(data["volume"] > 0, 1))
# 4. Add cross-sectional average volume per (time step, SKU) and per (time step, agency)
data["avg_volume_by_sku"] = data.groupby(
    ["time_idx", "sku"], observed=True
)["volume"].transform("mean")
data["avg_volume_by_agency"] = data.groupby(
    ["time_idx", "agency"], observed=True
)["volume"].transform("mean")
# Count distinct (agency, sku) pairs; multiplying the two nunique() values would overcount
n_series = len(data[["agency", "sku"]].drop_duplicates())
print(f"Loaded {len(data)} rows, {n_series} series")
print(f"Columns: {list(data.columns)}")
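You can check the two arithmetic steps in the example above on a few hand-made values. This sketch uses hypothetical dates and volumes, not the Stallion data. It shows that `year * 12 + month` produces consecutive integers across a year boundary. It also shows that replacing non-positive volumes with 1 makes them contribute a log of 0.

```python
import numpy as np
import pandas as pd

# Hypothetical monthly dates that cross a year boundary.
dates = pd.Series(pd.to_datetime(["2013-11-01", "2013-12-01", "2014-01-01", "2014-02-01"]))

# year * 12 + month counts months, so December -> January still increments by exactly 1.
time_idx = dates.dt.year * 12 + dates.dt.month
time_idx = time_idx - time_idx.min()  # shift so the first month is 0

# Log transform with non-positive volumes replaced by 1 (log(1) = 0).
volume = pd.Series([0.0, 1.0, 100.0])
log_volume = np.log(volume.where(volume > 0, 1))
```

Here `time_idx` comes out as 0, 1, 2, 3. Both the zero volume and the volume of exactly 1 map to a log value of 0, so a model cannot tell them apart after the transform.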