Implementation:Sktime Pytorch forecasting Get Stallion Data

From Leeroopedia


Knowledge Sources
Domains: Time_Series, Data_Engineering
Last Updated: 2026-02-08 07:00 GMT

Overview

Concrete tool for loading the Stallion demand dataset provided by the pytorch-forecasting library.

Description

The get_stallion_data function loads a multivariate demand forecasting dataset containing approximately 20,000 samples across 350 time series. Each series is identified by the combination of the agency and sku columns. The target variable is volume (demand), date holds the month of each observation, and the dataset also includes temporal and covariate columns suitable for training models such as the Temporal Fusion Transformer. The data is stored as a Parquet file that is downloaded from GitHub on first use and then cached locally.
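The download-on-first-use behavior described above follows a common caching pattern. The following is a minimal stdlib sketch of that pattern, not pytorch-forecasting's own code: load_cached, the temporary cache directory, and the injected fetch function are hypothetical, and the file name stallion.parquet is an assumption used only for illustration.

```python
import tempfile
from pathlib import Path
from typing import Callable, List


def load_cached(cache_dir: Path, filename: str, fetch: Callable[[], bytes]) -> bytes:
    """Return the cached file's bytes, calling fetch() only if the file is missing.

    Hypothetical helper illustrating the download-once, cache-locally pattern;
    it is not pytorch-forecasting's internal implementation.
    """
    path = cache_dir / filename
    if not path.exists():
        cache_dir.mkdir(parents=True, exist_ok=True)
        # First call: obtain the payload and persist it for later calls.
        path.write_bytes(fetch())
    # Every call (first or later) reads from the local cache.
    return path.read_bytes()


# Usage with a fake fetch function that records how often it is called.
calls: List[int] = []


def fake_fetch() -> bytes:
    calls.append(1)
    return b"parquet-bytes"


with tempfile.TemporaryDirectory() as tmp:
    # "stallion.parquet" is an assumed file name for illustration only.
    first = load_cached(Path(tmp), "stallion.parquet", fake_fetch)
    second = load_cached(Path(tmp), "stallion.parquet", fake_fetch)
```

Injecting fetch keeps the sketch testable offline. A real loader would download from a URL, and would typically write to a temporary file and rename it so that an interrupted download does not leave a truncated cache entry.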

Usage

Import this function when you need a realistic multivariate demand dataset with mixed covariates for demonstrating or testing forecasting models. It is the canonical dataset used in the TFT Demand Forecasting and TFT Hyperparameter Optimization workflows. After loading, users typically apply pandas feature engineering (a log transform of the target, a monthly time index, and per-period group averages) before constructing a TimeSeriesDataSet.

Code Reference

Source Location

Signature

def get_stallion_data() -> pd.DataFrame:
    """
    Demand data with covariates.

    ~20k samples of 350 timeseries. Important columns:
    - agency and sku identify timeseries
    - volume is the demand
    - date is the month of the demand

    Returns:
        pd.DataFrame: data
    """

Import

from pytorch_forecasting.data.examples import get_stallion_data

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| (none) | | | The function takes no arguments; it downloads the data on the first call. |

Outputs

| Name | Type | Description |
|------|------|-------------|
| return | pd.DataFrame | DataFrame with columns agency, sku, volume, and date, plus covariates (~20k rows, 350 series) |
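Before building a TimeSeriesDataSet, it can help to check that a loaded frame matches this output contract. The helper below is a hypothetical sketch, not part of pytorch-forecasting; it checks only the four columns named in the docstring, plus that date is datetime-like, as the feature-engineering example's use of the .dt accessor requires.

```python
import pandas as pd

# Columns named in the get_stallion_data docstring; covariates are not checked.
REQUIRED_COLUMNS = ("agency", "sku", "volume", "date")


def check_stallion_frame(df: pd.DataFrame) -> int:
    """Validate the documented columns and return the number of (agency, sku) series.

    Hypothetical helper for illustration; not a pytorch-forecasting API.
    """
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    # The .dt accessor used for the time index only works on datetime values.
    if not pd.api.types.is_datetime64_any_dtype(df["date"]):
        raise TypeError("column 'date' must be datetime-like")
    # observed=True counts only pairs that occur, even if agency/sku are categoricals.
    return df.groupby(["agency", "sku"], observed=True).ngroups


# Usage on a tiny made-up frame (values are invented, not from the dataset).
sample = pd.DataFrame(
    {
        "agency": ["Agency_01", "Agency_01", "Agency_02"],
        "sku": ["SKU_01", "SKU_01", "SKU_01"],
        "volume": [80.0, 98.0, 12.0],
        "date": pd.to_datetime(["2013-01-01", "2013-02-01", "2013-01-01"]),
    }
)
n_series = check_stallion_frame(sample)
```

Raising early with a named missing column is easier to debug than a KeyError surfacing later inside dataset construction.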

Usage Examples

Basic Data Loading

from pytorch_forecasting.data.examples import get_stallion_data
import numpy as np

# 1. Load the Stallion demand dataset (downloaded on the first call, then read from the local cache)
data = get_stallion_data()

# 2. Add an integer month index (required for TimeSeriesDataSet); consecutive months differ by 1
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min()

# 3. Add log-transformed volume; non-positive volumes are replaced by 1, so their log is 0
data["log_volume"] = np.log(data["volume"].where(data["volume"] > 0, 1))

# 4. Add per-month average volume by SKU and by agency
data["avg_volume_by_sku"] = data.groupby(
    ["time_idx", "sku"], observed=True
)["volume"].transform("mean")
data["avg_volume_by_agency"] = data.groupby(
    ["time_idx", "agency"], observed=True
)["volume"].transform("mean")

# Each series is an (agency, sku) pair that occurs in the data;
# multiplying the distinct counts of the two columns would overcount
n_series = data.groupby(["agency", "sku"], observed=True).ngroups
print(f"Loaded {len(data)} rows, {n_series} series")
print(f"Columns: {list(data.columns)}")

Related Pages

Implements Principle

Requires Environment
