Heuristic:OpenRLHF Gradient Checkpointing Memory Tip

From Leeroopedia



Knowledge Sources
Domains Optimization, LLMs, Deep_Learning
Last Updated 2026-02-07 10:00 GMT

Overview

Enable gradient checkpointing with `use_reentrant=False` to reduce VRAM usage by 50-60% when training large models.

Description

Gradient checkpointing (also called activation checkpointing) reduces peak GPU memory during training by not storing all intermediate activations for the backward pass; instead, activations are recomputed on the fly during backpropagation. OpenRLHF exposes `use_reentrant` as a configurable parameter across all training scripts, defaulting to `False` in line with modern PyTorch. This technique is essential for training 7B+ parameter models on consumer or otherwise limited-VRAM hardware.
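To see why activations dominate VRAM, a back-of-envelope estimate helps. The sketch below uses illustrative assumptions (a per-layer multiplier of ~10 for attention/MLP intermediates, bf16 storage), not OpenRLHF's exact accounting:

```python
# Rough peak activation memory estimate (illustrative assumptions, not
# OpenRLHF's exact accounting): a transformer keeps on the order of
# (batch * seq_len * hidden_size) elements per layer, times a per-layer
# multiplier covering attention and MLP intermediates.

def activation_memory_gib(batch, seq_len, hidden_size, num_layers,
                          bytes_per_elem=2, per_layer_multiplier=10):
    """Rough peak activation memory in GiB."""
    elems = batch * seq_len * hidden_size * num_layers * per_layer_multiplier
    return elems * bytes_per_elem / 2**30

# A 7B-class model: 32 layers, hidden 4096, seq 2048, batch 4, bf16 (2 bytes).
full = activation_memory_gib(4, 2048, 4096, 32)
# With checkpointing, only layer-boundary activations are kept and the
# intermediates are recomputed, so the multiplier drops toward ~1.
ckpt = activation_memory_gib(4, 2048, 4096, 32, per_layer_multiplier=1)
print(f"without checkpointing: {full:.1f} GiB, with: {ckpt:.1f} GiB")
# -> without checkpointing: 20.0 GiB, with: 2.0 GiB
```

Even with crude constants, the order of magnitude explains the ~50-60% (or better) peak-VRAM reductions observed in practice.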

Usage

Use this heuristic when you encounter CUDA out of memory errors during training, or when you need to maximize batch size on limited GPU memory. Enable via the `--gradient_checkpointing` flag on any OpenRLHF training command. Applies to SFT, DPO, RM, KD, KTO, PRM, and PPO training.

The Insight (Rule of Thumb)

  • Action: Add `--gradient_checkpointing` to any training command.
  • Value: Boolean flag; optionally control reentrant mode with `--gradient_checkpointing_use_reentrant` (default: False).
  • Trade-off: Reduces peak VRAM by ~50-60% but increases training time by ~20-30% due to activation recomputation.
  • Compatibility: Works with all Transformer models. Requires `use_cache=False` during training (already enforced by OpenRLHF).
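Both flags named above are boolean switches. A minimal sketch of how they could be declared with `argparse` (illustrative; see OpenRLHF's CLI scripts for the actual definitions):

```python
# Minimal sketch of the two boolean flags (illustrative, not copied from
# OpenRLHF's argument parser).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--gradient_checkpointing", action="store_true",
                    help="Trade ~20-30% extra compute for ~50-60% less peak VRAM")
parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true",
                    help="Use the legacy reentrant implementation (default: False)")

args = parser.parse_args(["--gradient_checkpointing"])
print(args.gradient_checkpointing, args.gradient_checkpointing_use_reentrant)
# -> True False
```

Because both default to off, adding `--gradient_checkpointing` alone gives the recommended configuration: checkpointing enabled with the non-reentrant implementation.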

Reasoning

Deep Transformer models store large activation tensors (shape: batch × sequence_length × hidden_size) during the forward pass for use in backpropagation. These activations are the primary VRAM bottleneck. By recomputing them during the backward pass, peak memory is significantly reduced. The `use_reentrant=False` option is the modern PyTorch default and is safer with complex control flow, avoiding subtle autograd bugs that can occur with the reentrant variant.
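The memory/compute trade-off can be made concrete with a toy model of segment checkpointing (conceptual only; PyTorch's `torch.utils.checkpoint` handles this inside autograd):

```python
# Toy model of the checkpointing trade-off: without checkpointing, every
# layer's activation is kept for backward; with segment checkpointing,
# only segment boundaries are kept and each segment's interior is
# recomputed during the backward pass.
import math

def costs(num_layers, segment=None):
    """Return (peak_stored_activations, forward_layer_evaluations)."""
    if segment is None:                  # no checkpointing
        return num_layers, num_layers    # store all, compute each layer once
    boundaries = math.ceil(num_layers / segment)
    # Peak = saved boundaries + one segment held while it is recomputed.
    peak = boundaries + segment
    # Each layer runs once in forward, once more when its segment is redone.
    return peak, 2 * num_layers

print(costs(32))             # -> (32, 32): high memory, no recompute
print(costs(32, segment=8))  # -> (12, 64): ~1/3 the memory, 2x forward compute
```

The roughly 2x forward work maps onto the ~20-30% wall-clock slowdown cited above, since the forward pass is only part of a full training step.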

Code evidence from `openrlhf/cli/train_dpo.py:54-56`:

if args.gradient_checkpointing:
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
    )

This identical pattern appears in all training entry points: `train_sft.py`, `train_rm.py`, `train_kto.py`, `train_prm.py`, `train_kd.py`, and `train_ppo_ray.py`.
