Heuristic: Disable Dropout for RL Training (Hugging Face TRL)
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Reinforcement_Learning |
| Last Updated | 2026-02-06 17:00 GMT |
Overview
Disable dropout in both policy and reference models during RL-based training (GRPO, DPO, PPO) to ensure consistent log-probability computation across forward passes.
Description
In reinforcement learning training methods like GRPO and DPO, the trainer computes log-probabilities from both a policy model and a reference model. Dropout introduces stochasticity into these computations, meaning the same input can produce different logprobs across forward passes. This inconsistency can destabilize the KL divergence computation between policy and reference, leading to noisy gradients and potentially divergent training. Disabling dropout ensures deterministic forward passes, making the reward signal and KL penalty more reliable.
Usage
Apply this heuristic when training with any RL-based method that uses a reference model (GRPO with `beta > 0`, DPO, PPO). Set `disable_dropout=True` in your training config. TRL's `GRPOTrainer` supports this via the `disable_dropout` argument in `GRPOConfig`.
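As a minimal sketch, a GRPO config with dropout disabled might look like the following. The `output_dir` and `beta` values here are illustrative assumptions, not from the source; only the `disable_dropout` flag is the subject of this heuristic:

```python
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="grpo-run",   # hypothetical output path
    beta=0.04,               # beta > 0 means a reference model (and KL penalty) is used
    disable_dropout=True,    # zero out dropout in both policy and reference models
)
```

Passing this config to `GRPOTrainer` applies the dropout-disabling step shown in the code evidence below.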
The Insight (Rule of Thumb)
- **Action:** Set `disable_dropout=True` in `GRPOConfig` (or the equivalent config).
- **Value:** Recursively sets all dropout modules to `p=0.0` in both the policy and reference models.
- **Trade-off:** Loss of dropout regularization during training. For fine-tuning pre-trained models this is typically acceptable, since the base model was already trained with dropout.
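The mechanics are simple: walk the module tree and zero every dropout probability in place. A minimal sketch of such a helper, assuming a PyTorch model (TRL ships an equivalent `disable_dropout_in_model` utility; this is an illustration, not TRL's exact code):

```python
import torch.nn as nn

def disable_dropout_in_model(model: nn.Module) -> None:
    """Set p=0.0 on every dropout module in the model, in place."""
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = 0.0
```

Call this once on the policy model and once on the reference model before training begins; with `p=0.0`, the dropout layers become no-ops even in train mode.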
Reasoning
The GRPO, DPO, and PPO algorithms all compute a ratio between the policy model's log-probabilities and the reference model's log-probabilities. If dropout is active, the same input produces different outputs on each forward pass. This means:
- The log-probability ratio becomes noisy, adding variance to the gradient signal.
- The KL divergence estimate between policy and reference becomes unreliable.
- Training may diverge or converge to suboptimal solutions.
By disabling dropout, all forward passes become deterministic (for the same input), giving clean reward signals and stable training dynamics.
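The effect is easy to demonstrate: with dropout active, two forward passes over the same input disagree, so any log-probability derived from the logits disagrees too; with `p=0.0`, repeated passes match exactly. A small PyTorch sketch with a toy model (not a TRL policy):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 64), nn.Dropout(0.5))
x = torch.randn(1, 16)

model.train()                    # dropout active: each pass samples a new mask
noisy_a, noisy_b = model(x), model(x)

model[1].p = 0.0                 # dropout disabled: forward passes are deterministic
clean_a, clean_b = model(x), model(x)
```

Here `noisy_a` and `noisy_b` differ because each forward pass draws a fresh dropout mask, while `clean_a` and `clean_b` are identical, which is exactly the determinism the policy/reference log-probability ratio relies on.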
Code evidence from `trl/trainer/grpo_trainer.py:572-576`:
```python
# Disable dropout in the models
if args.disable_dropout:
    disable_dropout_in_model(model)
    if self.ref_model is not None:
        disable_dropout_in_model(self.ref_model)
```