# Heuristic: Microsoft DeepSpeedExamples RLHF Stability Constraints
| Knowledge Sources | Details |
|---|---|
| Domains | RLHF, Optimization, Debugging |
| Last Updated | 2026-02-07 13:00 GMT |
## Overview
Critical stability constraint for RLHF PPO training: generation and training batch sizes must be equal, with single PPO epoch and single generation batch, to prevent training divergence.
## Description
During RLHF Step 3 (PPO fine-tuning), the training process generates text with the actor model and then trains on that generated data. A subtle but critical bug occurs when the generation batch size differs from the training batch size, or when multiple PPO epochs are used: the `log_probs` and `old_log_probs` tensors diverge within consecutive iterations, causing the PPO loss to become unstable and eventually diverge. This is a non-negotiable constraint discovered through empirical testing.
## Usage
Apply this heuristic immediately when setting up RLHF Step 3 training. If you observe NaN losses, reward collapse, or training divergence during PPO fine-tuning, verify these constraints first. This is the most common cause of RLHF training failure in DeepSpeed-Chat.
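As a quick guard, the three constraints can be checked programmatically before launching. A minimal sketch in Python, assuming `args` is the parsed argparse namespace with the same attribute names as the DeepSpeed-Chat flags (the helper function itself is hypothetical, not part of DeepSpeed-Chat):

```python
def check_ppo_stability(args):
    """Raise ValueError if the RLHF Step 3 stability constraints are violated."""
    problems = []
    if args.per_device_generation_batch_size != args.per_device_training_batch_size:
        problems.append("generation batch size != training batch size")
    if args.ppo_epochs != 1:
        problems.append("ppo_epochs must be 1")
    if args.generation_batches != 1:
        problems.append("generation_batches must be 1")
    if problems:
        raise ValueError("Unstable PPO config: " + "; ".join(problems))
```

Calling this right after argument parsing fails fast instead of letting the run diverge hours in.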
## The Insight (Rule of Thumb)
- Action: Set the following three constraints in your training script:
  - `per_device_generation_batch_size` must equal `per_device_training_batch_size`
  - `ppo_epochs = 1`
  - `generation_batches = 1`
- Value (recommended): `per_device_generation_batch_size=4`, `per_device_training_batch_size=4`, `ppo_epochs=1`, `generation_batches=1`
- Trade-off: Limits sample diversity per PPO update; compensate with gradient accumulation if a larger effective batch is needed.
- Severity: CRITICAL. Violating any of these constraints will cause training divergence.
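Because the per-device batch sizes are pinned, gradient accumulation is the usual way to recover a larger effective batch. Illustrative arithmetic only; the accumulation steps and GPU count below are example values, not recommendations from the source:

```python
# Effective samples per optimizer step when the batch size is pinned to 4.
per_device_training_batch_size = 4   # fixed by the stability constraint
gradient_accumulation_steps = 8      # tunable knob (example value)
num_gpus = 8                         # example cluster size (assumption)

effective_batch = (per_device_training_batch_size
                   * gradient_accumulation_steps
                   * num_gpus)
# 4 * 8 * 8 = 256 samples contribute to each optimizer step
```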
## Reasoning
The PPO algorithm computes the importance-sampling ratio between the current policy's log-probabilities (`log_probs`) and the log-probabilities recorded when the data was generated (`old_log_probs`). When `generation_batches > 1` or `ppo_epochs > 1`, `old_log_probs` become stale relative to the policy being updated within the same iteration, so the importance-sampling ratio can explode. When the generation and training batch sizes differ, the generated samples and the training tensors are misaligned, and `log_probs` are compared against the wrong `old_log_probs`. Both scenarios lead to unbounded policy updates and eventual divergence.
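To see why staleness matters, recall that PPO's importance ratio is `exp(log_probs - old_log_probs)`. A toy numerical sketch (the log-prob values are invented for illustration, not taken from a real run):

```python
import math

# With a single PPO epoch, old_log_probs come from the same policy that
# just generated the rollout, so the ratio starts near 1.
old_log_probs = -2.00
log_probs_fresh = -2.05
ratio = math.exp(log_probs_fresh - old_log_probs)
# exp(-0.05) ~ 0.95: well inside PPO's typical clip range [0.8, 1.2]

# After extra epochs on the same rollout, the policy drifts while
# old_log_probs stay fixed ("stale"), and the ratio can leave the
# clipping region entirely.
log_probs_stale = -0.50
ratio_stale = math.exp(log_probs_stale - old_log_probs)
# exp(1.5) ~ 4.48: far outside the clip range, producing unbounded updates
```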
From `applications/DeepSpeed-Chat/training/README.md`:
The divergence is specifically traced to how `log_probs` and `old_log_probs` interact across consecutive PPO iterations. The recommended configuration has been empirically validated on OPT-1.3B through OPT-66B model scales.
Code Evidence:
Training script configuration from `training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_1.3b.sh:42-45`:
```shell
--per_device_generation_batch_size 4 \
--per_device_training_batch_size 4 \
--generation_batches 1 \
--ppo_epochs 1 \
```
Warning about non-default ZeRO stages from `applications/DeepSpeed-Chat/e2e_rlhf.py:101-102`:
```python
warnings.warn("Non-default zero stages may result in OOM errors or worse performance.")
```