
Heuristic:Microsoft DeepSpeedExamples RLHF Stability Constraints

From Leeroopedia



Knowledge Sources
Domains RLHF, Optimization, Debugging
Last Updated 2026-02-07 13:00 GMT

Overview

Critical stability constraint for RLHF PPO training: the generation and training batch sizes must be equal, and both the number of PPO epochs and the number of generation batches must be 1, to prevent training divergence.

Description

During RLHF Step 3 (PPO fine-tuning), the actor model generates text and is then trained on that generated data. Training becomes unstable when the generation batch size differs from the training batch size, when more than one PPO epoch is run, or when more than one generation batch is collected: the `log_probs` and `old_log_probs` tensors can drift apart even across two consecutive iterations, so the PPO loss becomes unstable and eventually diverges. The constraint was found through empirical testing and should be treated as non-negotiable.

Usage

Apply this heuristic when first setting up RLHF Step 3 training. If you observe NaN losses, reward collapse, or training divergence during PPO fine-tuning, verify these constraints first; violating them is a common cause of RLHF training failure in DeepSpeed-Chat.

The Insight (Rule of Thumb)

  • Action: Set the following three constraints in your training script:
    • `per_device_generation_batch_size` == `per_device_training_batch_size`
    • `ppo_epochs` = 1
    • `generation_batches` = 1
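  • Value: `per_device_generation_batch_size=4`, `per_device_training_batch_size=4`, `ppo_epochs=1`, `generation_batches=1` (the values used in the OPT-1.3B single-node script).
  • Trade-off: Each batch of generated experience drives only one actor update, which is less sample-efficient than standard PPO; if a larger effective batch is needed, use gradient accumulation rather than reusing experience.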
  • Severity: CRITICAL - violating any of these constraints will cause training divergence.
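A pre-flight check can enforce the three constraints before a Step 3 run is launched. This is a minimal sketch, not part of DeepSpeed-Chat: `PPOStep3Config` and `stability_violations` are hypothetical names, and only the four field names mirror the command-line flags listed above.

```python
from dataclasses import dataclass


@dataclass
class PPOStep3Config:
    """Hypothetical container whose fields mirror the Step 3 flags."""
    per_device_generation_batch_size: int
    per_device_training_batch_size: int
    generation_batches: int
    ppo_epochs: int


def stability_violations(cfg: PPOStep3Config) -> list[str]:
    """Describe every violated constraint; an empty list means the config is safe."""
    problems = []
    if cfg.per_device_generation_batch_size != cfg.per_device_training_batch_size:
        problems.append(
            f"per_device_generation_batch_size ({cfg.per_device_generation_batch_size}) "
            f"!= per_device_training_batch_size ({cfg.per_device_training_batch_size})"
        )
    if cfg.ppo_epochs != 1:
        problems.append(f"ppo_epochs is {cfg.ppo_epochs}, expected 1")
    if cfg.generation_batches != 1:
        problems.append(f"generation_batches is {cfg.generation_batches}, expected 1")
    return problems


if __name__ == "__main__":
    recommended = PPOStep3Config(4, 4, 1, 1)
    risky = PPOStep3Config(8, 4, 1, 2)
    print(stability_violations(recommended))  # → []
    for msg in stability_violations(risky):
        print("violation:", msg)
```

Returning a list of messages rather than raising on the first problem lets a launcher report every misconfigured flag at once.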

Reasoning

The PPO algorithm computes the ratio between the current policy's log-probabilities (`log_probs`) and the log-probabilities recorded when the experience was generated (`old_log_probs`). With `generation_batches > 1`, with `ppo_epochs > 1`, or with a generation batch larger than the training batch (so that one generation batch is split into several training mini-batches), the actor is updated more than once on the same experience. After the first update, `old_log_probs` are stale and the importance-sampling ratio can grow very large. All three cases therefore lead to oversized policy updates and eventual divergence.
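The three constraints can be read as one condition: the actor should take exactly one optimizer step per buffer of generated experience. The sketch below counts those steps under stated assumptions about the training loop: `generation_batches` generation batches are buffered, each is split into mini-batches of `per_device_training_batch_size`, the buffer is replayed `ppo_epochs` times, and every mini-batch triggers one optimizer step (no gradient accumulation). The function name is hypothetical, and this accounting is an assumption about the loop rather than a copy of DeepSpeed-Chat's code.

```python
def actor_updates_per_rollout(
    per_device_generation_batch_size: int,
    per_device_training_batch_size: int,
    generation_batches: int,
    ppo_epochs: int,
) -> int:
    """Count actor optimizer steps taken on one buffer of generated experience."""
    if per_device_generation_batch_size % per_device_training_batch_size != 0:
        # This sketch only models generation batches that split evenly into training batches.
        raise ValueError("generation batch size must be a multiple of the training batch size")
    minibatches_per_generation_batch = (
        per_device_generation_batch_size // per_device_training_batch_size
    )
    minibatches_in_buffer = generation_batches * minibatches_per_generation_batch
    # Assumption: each pass over the buffer performs one optimizer step per mini-batch.
    return ppo_epochs * minibatches_in_buffer


if __name__ == "__main__":
    for cfg in [(4, 4, 1, 1), (8, 4, 1, 1), (4, 4, 2, 1), (4, 4, 1, 2)]:
        steps = actor_updates_per_rollout(*cfg)
        # Only the first step runs with the same weights that produced old_log_probs.
        print(cfg, "steps:", steps, "stale steps:", steps - 1)
```

Under this model only the recommended setting yields a single step, so there are no stale steps; violating any one constraint adds at least one update taken with an actor that has already moved away from the policy that generated the data.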

From `applications/DeepSpeed-Chat/training/README.md`:

The divergence is specifically traced to how `log_probs` and `old_log_probs` interact across consecutive PPO iterations. The recommended configuration has been empirically validated on OPT-1.3B through OPT-66B model scales.
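To see why drifting log-probabilities are harmful, here is a pure-Python sketch of the standard clipped PPO policy loss over the tokens of one response. The function name `clipped_policy_loss`, the per-token list inputs, and `clip_range=0.2` are illustrative assumptions, not DeepSpeed-Chat's actual signature.

```python
import math
from typing import Sequence


def clipped_policy_loss(
    log_probs: Sequence[float],
    old_log_probs: Sequence[float],
    advantages: Sequence[float],
    mask: Sequence[float],
    clip_range: float = 0.2,
) -> tuple[float, float]:
    """Return (masked mean clipped PPO loss, largest unclipped ratio among masked tokens)."""
    total, denom, max_ratio = 0.0, 0.0, 0.0
    for lp, old_lp, adv, m in zip(log_probs, old_log_probs, advantages, mask):
        ratio = math.exp(lp - old_lp)  # importance-sampling ratio pi_new / pi_old
        unclipped = -adv * ratio
        clipped = -adv * min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
        total += max(unclipped, clipped) * m  # pessimistic choice between the two terms
        denom += m
        if m:
            max_ratio = max(max_ratio, ratio)
    return total / denom, max_ratio


if __name__ == "__main__":
    old = [-1.5, -3.5, -2.0]
    adv = [0.5, -0.3, 1.0]
    mask = [1.0, 1.0, 1.0]
    # One update per rollout: the actor has not moved yet, so log_probs == old_log_probs.
    print(clipped_policy_loss(old, old, adv, mask))  # ratio is exactly 1.0
    # After replaying the same experience, the log-probabilities have drifted.
    print(clipped_policy_loss([-0.5, -0.5, -0.2], old, adv, mask))
```

With a single update per rollout the ratio is exactly 1 (assuming a deterministic forward pass, for example with dropout disabled), so clipping never activates. Once the experience is reused, the second token, which has a negative advantage and a ratio of about 20, dominates the loss: taking the pessimistic maximum means clipping bounds the ratio only where it would improve the objective, not where it inflates the loss.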

Code Evidence:

Training script configuration from `training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_1.3b.sh:42-45`:

   --per_device_generation_batch_size 4 \
   --per_device_training_batch_size 4 \
   --generation_batches 1 \
   --ppo_epochs 1 \

Warning about non-default ZeRO stages from `applications/DeepSpeed-Chat/e2e_rlhf.py:101-102`:

warnings.warn("Non-default zero stages may result in OOM errors or worse performance.")
