Heuristic: OpenRLHF Gradient Checkpointing Memory Tip
| Knowledge Sources | |
|---|---|
| Domains | Optimization, LLMs, Deep_Learning |
| Last Updated | 2026-02-07 10:00 GMT |
Overview
Enable gradient checkpointing with `use_reentrant=False` to cut peak VRAM usage by roughly 50-60% when training large models.
Description
Gradient checkpointing (also called activation checkpointing) reduces peak GPU memory during training by discarding most intermediate activations after the forward pass and recomputing them on the fly during backpropagation. OpenRLHF consistently exposes `use_reentrant` as a configurable parameter (defaulting to `False`, matching modern PyTorch) across all training scripts. This technique is essential for training 7B+ parameter models on consumer or otherwise limited-VRAM hardware.
Usage
Use this heuristic when you encounter CUDA out of memory errors during training, or when you need to maximize batch size on limited GPU memory. Enable via the `--gradient_checkpointing` flag on any OpenRLHF training command. Applies to SFT, DPO, RM, KD, KTO, PRM, and PPO training.
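For example, on a DPO run the flag drops into an existing launch command unchanged. The `deepspeed --module` launcher pattern below follows OpenRLHF's documented usage, but the elided arguments are placeholders for whatever the command already contains:

```shell
# Existing launch command, unchanged except for the added flag.
# (Module path shown for illustration; keep your own arguments.)
deepspeed --module openrlhf.cli.train_dpo \
    ... \
    --gradient_checkpointing
```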
The Insight (Rule of Thumb)
- Action: Add `--gradient_checkpointing` to any training command.
- Value: Boolean flag; optionally control reentrant mode with `--gradient_checkpointing_use_reentrant` (default: False).
- Trade-off: Reduces peak VRAM by ~50-60% but increases training time by ~20-30% due to activation recomputation.
- Compatibility: Works with all Transformer models. Requires `use_cache=False` during training (already enforced by OpenRLHF).
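The VRAM numbers above can be sanity-checked with back-of-envelope arithmetic. The sketch below is pure Python with illustrative constants (the per-layer `factor` and the sqrt-of-layers checkpoint strategy are assumptions, not measurements); note that activation savings alone exceed 60%, but weights, gradients, and optimizer state are untouched, which is why end-to-end savings land nearer the 50-60% the heuristic quotes:

```python
import math

def activation_bytes(batch, seq, hidden, layers, bytes_per=2, factor=16):
    """Rough activation footprint of a Transformer without checkpointing.

    `factor` bundles the several intermediate tensors each block keeps
    (attention scores, MLP activations, etc.); it is an illustrative
    constant, not a measured one.
    """
    per_layer = factor * batch * seq * hidden * bytes_per
    return per_layer * layers

def checkpointed_bytes(batch, seq, hidden, layers, bytes_per=2, factor=16):
    """Under a sqrt(L) checkpoint-placement strategy, only ~sqrt(layers)
    boundary activations are stored; the rest are recomputed in backward."""
    per_layer = factor * batch * seq * hidden * bytes_per
    return per_layer * math.ceil(math.sqrt(layers))

# 7B-class shape: 32 layers, hidden 4096, seq 2048, micro-batch 4, fp16
full = activation_bytes(4, 2048, 4096, 32)
ckpt = checkpointed_bytes(4, 2048, 4096, 32)
print(f"full activations:         {full / 2**30:.0f} GiB")
print(f"checkpointed activations: {ckpt / 2**30:.0f} GiB")
# Total VRAM also holds weights/gradients/optimizer state, so the
# end-to-end reduction is smaller than the activation-only reduction.
```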
Reasoning
Deep Transformer models store large activation tensors (batch x sequence_length x hidden_size) during the forward pass for use in backpropagation. These activations are the primary VRAM bottleneck. By recomputing them during the backward pass, peak memory is significantly reduced. The `use_reentrant=False` option is the modern PyTorch default and is safer with complex control flow, avoiding subtle bugs that can occur with the reentrant variant.
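The key correctness property, that recomputation changes memory and compute but not gradients, can be demonstrated directly with `torch.utils.checkpoint`. The tiny two-linear-layer module below is a stand-in for a Transformer block, not OpenRLHF code:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Tiny module standing in for a Transformer layer (illustrative only)
torch.manual_seed(0)
layer = torch.nn.Sequential(
    torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8)
)
x = torch.randn(4, 8, requires_grad=True)

# Plain forward: all intermediate activations are kept for backward
layer(x).sum().backward()
grad_plain = x.grad.clone()

# Checkpointed forward: activations are dropped and recomputed in backward.
# use_reentrant=False selects the modern, non-reentrant implementation.
x.grad = None
checkpoint(layer, x, use_reentrant=False).sum().backward()
grad_ckpt = x.grad.clone()

# Gradients are identical; only peak memory and wall-clock time differ
assert torch.allclose(grad_plain, grad_ckpt)
```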
Code evidence from `openrlhf/cli/train_dpo.py:54-56`:
```python
if args.gradient_checkpointing:
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
    )
```
This identical pattern appears in all training entry points: `train_sft.py`, `train_rm.py`, `train_kto.py`, `train_prm.py`, `train_kd.py`, and `train_ppo_ray.py`.