Heuristic: Sktime PyTorch Forecasting Batch Size Selection
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep_Learning, Time_Series |
| Last Updated | 2026-02-08 08:00 GMT |
Overview
Set the batch size between 32 and 128 for pytorch-forecasting models, with 64 as the default for multi-horizon models (TFT, DeepAR) and 128 for univariate models (N-BEATS).
Description
Batch size is a critical hyperparameter that affects training speed, memory usage, and generalization. The stallion tutorial explicitly recommends "set this between 32 to 128." All examples use values within this range. The `TimeBatchSampler` emits a warning when the batch size exceeds the number of available samples for certain prediction times, which can bias gradient estimates.
Usage
Apply this heuristic when creating DataLoaders via `TimeSeriesDataSet.to_dataloader()`. Choose a batch size that fits in GPU/CPU memory while providing stable gradient estimates. If the `TimeBatchSampler` warns about insufficient samples, reduce the batch size.
The Insight (Rule of Thumb)
- Action: Set `batch_size` in `TimeSeriesDataSet.to_dataloader(train=True, batch_size=N)`.
- Value: 64 for TFT and DeepAR, 128 for N-BEATS. Acceptable range: 32-128.
- Trade-off: Smaller batches (32) tend to improve generalization but underutilize the GPU and slow training. Larger batches (128) speed up training but may hurt generalization and cause out-of-memory errors on limited hardware.
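The rule of thumb above can be captured in a small helper. This is a hypothetical sketch, not a pytorch-forecasting API: the function name `default_batch_size` and its model-name keys are assumptions introduced here for illustration.

```python
def default_batch_size(model_name: str) -> int:
    """Return a starting batch size per the 32-128 heuristic.

    Hypothetical helper: the defaults follow the rule of thumb stated
    above (64 for multi-horizon models, 128 for univariate N-BEATS).
    """
    defaults = {
        "tft": 64,      # Temporal Fusion Transformer (multi-horizon)
        "deepar": 64,   # DeepAR (multi-horizon)
        "nbeats": 128,  # N-BEATS (univariate)
    }
    size = defaults.get(model_name.lower(), 64)  # 64 as the general default
    return max(32, min(size, 128))               # clamp to the 32-128 range
```

The clamp at the end encodes the acceptable range, so any custom default stays within the recommended bounds.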
Reasoning
Time series models process sequences of (encoder_length + prediction_length) time steps per sample. With encoder lengths of 100+ and multiple features, each sample consumes significant memory. The 32-128 range provides a sweet spot between gradient noise (too-small batches) and memory constraints (too-large batches).
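The memory argument can be made concrete with back-of-envelope arithmetic. This is illustrative only: the function name and the example shapes (100-step encoder, 24-step horizon, 20 features) are assumptions, and real memory use is a large multiple of this lower bound once activations, gradients, and optimizer state are counted.

```python
def batch_memory_bytes(batch_size: int, encoder_length: int,
                       prediction_length: int, n_features: int,
                       bytes_per_value: int = 4) -> int:
    """Rough float32 memory for one batch of raw input sequences.

    Each sample spans (encoder_length + prediction_length) time steps,
    so memory scales linearly with batch size, sequence length, and
    feature count.
    """
    steps = encoder_length + prediction_length
    return batch_size * steps * n_features * bytes_per_value

# 128 samples x (100 + 24) steps x 20 features x 4 bytes per float32
raw_input_mb = batch_memory_bytes(128, 100, 24, 20) / 1e6
```

Even at ~1.3 MB of raw inputs per batch, the multiples added by attention activations and stored gradients are what push large batches toward out-of-memory errors on limited hardware.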
Tutorial recommendation from the stallion tutorial:
```python
batch_size = 128  # set this between 32 to 128
```
Example values:
- `examples/stallion.py:97` — `batch_size=64`
- `examples/ar.py:53` — `batch_size=64`
- `examples/nbeats.py:51` — `batch_size=128`
- `README.md:114` — `batch_size=128`
Sampler warning from `data/samplers.py:95-100`:
```python
warnings.warn(
    f"Less than {self.batch_size} samples available for "
    f"{len(warns)} prediction times. "
    f"Use batch size smaller than {self.batch_size}. ..."
)
```
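A simple response to this warning is to shrink the batch size until it no longer exceeds the number of samples available at the sparsest prediction time. The sketch below is a hypothetical helper, not part of the library; `shrink_batch_size` and its `samples_per_time` mapping are names introduced here for illustration.

```python
def shrink_batch_size(batch_size: int, samples_per_time: dict) -> int:
    """Halve the batch size until it fits the sparsest prediction time.

    Hypothetical helper mirroring the warning condition quoted above:
    if fewer samples exist at some prediction time than the batch size,
    halve it, stopping at the heuristic's lower bound of 32.
    """
    floor = min(samples_per_time.values())  # sparsest prediction time
    while batch_size > floor and batch_size > 32:
        batch_size //= 2
    return batch_size

# Only 40 samples at the sparsest prediction time: 128 -> 64 -> 32.
adjusted = shrink_batch_size(128, {"t1": 200, "t2": 40})
```

Halving keeps the batch size a power of two, which is the convention the example values (32, 64, 128) already follow.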