Heuristic: CarperAI trlx Batch Size Tuning
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Reinforcement_Learning |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Memory and throughput optimization strategy for PPO training that scales `batch_size`, `num_rollouts`, and `chunk_size` inversely with model size.
Description
In trlx PPO training, three interrelated batch parameters control memory usage and training dynamics: `batch_size` (samples per gradient update), `num_rollouts` (experience collection size per iteration), and `chunk_size` (generation batch size). These must be tuned together based on available GPU memory and model size. Additionally, `minibatch_size` enables gradient accumulation when set smaller than `batch_size`, allowing larger effective batch sizes without proportionally more memory.
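The relationship between these parameters can be sketched with back-of-the-envelope arithmetic. The values below mirror the 6B tier from the table further down; the derived quantities are illustrative and not trlx API code:

```python
# Illustrative arithmetic for the 6B-tier settings (not trlx API code).
batch_size = 4        # samples per gradient update
num_rollouts = 128    # experiences collected per PPO iteration
chunk_size = 16       # prompts generated per forward pass during rollout
minibatch_size = 2    # optional: enables gradient accumulation

# Rollout phase: how many generation calls fill the experience buffer.
generation_calls = num_rollouts // chunk_size      # 128 // 16 = 8

# Update phase: how many gradient updates per pass over the buffer.
optimizer_steps = num_rollouts // batch_size       # 128 // 4 = 32

# Gradient accumulation: forward/backward passes per optimizer step.
accumulation_steps = batch_size // minibatch_size  # 4 // 2 = 2

print(generation_calls, optimizer_steps, accumulation_steps)  # 8 32 2
```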
Usage
Apply this heuristic when configuring PPO training for any model size, or when encountering CUDA OOM errors during training. The batch parameters must be adjusted whenever changing the model architecture, model size, or available GPU hardware.
The Insight (Rule of Thumb)
- Action: Scale `batch_size` inversely with model parameter count. Use `minibatch_size` for gradient accumulation.
- Value:
- Small models (125M-1B): `batch_size=32`, `num_rollouts=128`, `chunk_size=128`
- Medium models (6B): `batch_size=4`, `num_rollouts=128`, `chunk_size=16`
- Large models (20B): `batch_size=1`, `num_rollouts=16`, `chunk_size=4`, `ppo_epochs=2`
- Trade-off: Smaller batch sizes reduce memory but increase gradient noise and slow convergence. Fewer num_rollouts mean less diverse experience per iteration.
- Constraint: `batch_size` must be evenly divisible by `minibatch_size` (enforced by assertion).
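The tiers above can be captured in a small lookup helper. This function is hypothetical (it is not part of trlx); it simply encodes the rule of thumb:

```python
# Hypothetical helper encoding the rule of thumb above; not part of the trlx API.
def batch_params_for(model_params_billions: float) -> dict:
    """Return suggested PPO batch settings given model size in billions of parameters."""
    if model_params_billions <= 1:    # small: 125M-1B
        return {"batch_size": 32, "num_rollouts": 128, "chunk_size": 128}
    if model_params_billions <= 6:    # medium: ~6B
        return {"batch_size": 4, "num_rollouts": 128, "chunk_size": 16}
    # large: ~20B; also reduce ppo_epochs to limit reuse of the small buffer
    return {"batch_size": 1, "num_rollouts": 16, "chunk_size": 4, "ppo_epochs": 2}

print(batch_params_for(6)["batch_size"])  # 4
```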
Reasoning
PPO requires storing rollout experiences (prompt + response tokens, logprobs, values, rewards) alongside the model weights and their activations. Activation and KV-cache memory during generation and training scale with model size, so a 6B model fits far fewer samples per forward pass than a 125M model. The trlx codebase demonstrates this scaling through its example configurations: the HH-RLHF examples explicitly define different batch parameters for each model size tier.
The `minibatch_size` parameter enables gradient accumulation: the batch is split into `batch_size / minibatch_size` forward passes, accumulating gradients before the optimizer step. This maintains the effective batch size while reducing peak memory.
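The accumulation arithmetic can be demonstrated without trlx: averaging per-minibatch mean gradients reproduces the full-batch mean gradient, so peak memory drops while the effective batch size is unchanged. A toy example with scalar "gradients":

```python
# Toy demonstration: accumulating minibatch gradients reproduces the
# full-batch gradient for a mean loss. "Gradients" here are plain floats.
batch = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75, 0.0, 3.0]  # per-sample gradients
batch_size, minibatch_size = len(batch), 2
num_mb = batch_size // minibatch_size                  # mirrors self.num_mb

# Full-batch update: one forward/backward over all samples at once.
full_grad = sum(batch) / batch_size

# Accumulated update: num_mb smaller passes, scaling each minibatch's
# mean gradient by 1/num_mb before summing (as gradient accumulation does).
accum_grad = 0.0
for i in range(num_mb):
    mb = batch[i * minibatch_size:(i + 1) * minibatch_size]
    accum_grad += (sum(mb) / minibatch_size) / num_mb

assert abs(full_grad - accum_grad) < 1e-12  # same update, lower peak memory
```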
Code Evidence
Model-size-dependent configuration from `examples/hh/ppo_hh.py:71-106`:
```python
# Small model (125M-1B)
batch_size = 32
num_rollouts = 128

# Medium model (6B)
batch_size = 4
chunk_size = 16

# Large model (20B)
batch_size = 1
num_rollouts = 16
chunk_size = 4
ppo_epochs = 2
```
Minibatch divisibility assertion from `trlx/trainer/accelerate_base_trainer.py:49-54`:
```python
if config.train.minibatch_size:
    assert config.train.batch_size % config.train.minibatch_size == 0, \
        "Minibatch size must divide batch size"
    self.mb_size = config.train.minibatch_size
else:
    self.mb_size = config.train.batch_size
self.num_mb = config.train.batch_size // self.mb_size
```
Global batch size computation from `trlx/trlx.py:100`:
```python
batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1))
```
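For example, with the 6B-tier per-device `batch_size=4` launched across 8 GPUs, this line yields a global batch size of 32. A minimal sketch using the same environment-variable convention (launchers such as torchrun set `WORLD_SIZE` in real runs; here it is set manually):

```python
import os

# Simulate an 8-GPU launch; distributed launchers set WORLD_SIZE themselves.
os.environ["WORLD_SIZE"] = "8"

per_device_batch_size = 4  # 6B-tier setting from the table above
global_batch_size = per_device_batch_size * int(os.environ.get("WORLD_SIZE", 1))
print(global_batch_size)  # 32
```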