

Heuristic:Huggingface Trl Gradient Checkpointing Use Reentrant

From Leeroopedia




Knowledge Sources
Domains: Optimization, Training
Last Updated: 2026-02-06 17:00 GMT

Overview

Set use_reentrant=False in gradient_checkpointing_kwargs for reliable gradient checkpointing behavior in all TRL trainers.

Description

PyTorch's gradient checkpointing (torch.utils.checkpoint) has two modes: reentrant and non-reentrant. The reentrant variant (use_reentrant=True) was the historical default, but it has known issues with certain model architectures, PEFT adapters, and hooks, and PyTorch now recommends use_reentrant=False. Hugging Face transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but never switched to the recommended non-reentrant behavior. TRL therefore sets the non-reentrant default itself when running with transformers < 5.0.0, while preserving any value the user has already provided.

Usage

Apply this heuristic whenever you enable gradient_checkpointing=True in any TRL training configuration. It is especially critical when combining gradient checkpointing with PEFT/LoRA adapters: with a frozen base model, the reentrant variant can leave adapter weights inside checkpointed layers without gradients, emitting only a warning while training continues.

The Insight (Rule of Thumb)

  • Action: Ensure gradient_checkpointing_kwargs={"use_reentrant": False} is set in your training config.
  • Value: use_reentrant=False (non-reentrant checkpointing).
  • Trade-off: Non-reentrant checkpointing may use slightly more memory in rare edge cases, but provides correct gradient computation across all model configurations.
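  • Compatibility: TRL applies this default automatically for transformers < 5.0.0. For transformers >= 5.0.0, TRL skips the override because the non-reentrant default is expected to land upstream (see https://github.com/huggingface/transformers/pull/43203).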
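A minimal configuration sketch, assuming a recent TRL release in which GRPOConfig inherits gradient_checkpointing and gradient_checkpointing_kwargs from transformers.TrainingArguments (other TRL config classes such as SFTConfig and DPOConfig accept the same two arguments). The output directory name is a placeholder.

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-output",  # placeholder path
    gradient_checkpointing=True,
    # Set explicitly so the behavior does not depend on the installed
    # transformers version (TRL only fills this in for transformers < 5.0.0).
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```

Setting the value explicitly also documents the choice in the training script instead of relying on a version-dependent default.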

Reasoning

The reentrant variant of gradient checkpointing has fundamental limitations: it does not support arbitrary model architectures with hooks, it can produce incorrect gradients with certain PEFT configurations, and it only propagates gradients to parameters inside a checkpointed block if at least one of the block's input tensors requires gradients, which is often not the case when the base model is frozen. The non-reentrant variant handles these cases correctly. TRL's GRPOTrainer (and by extension other RL trainers) combines gradient checkpointing with PEFT adapters and model hooks (e.g., for disabling dropout or casting lm_head to fp32), which makes non-reentrant behavior essential.
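The input-gradient limitation of the reentrant variant can be reproduced with plain torch.utils.checkpoint. This is a minimal sketch, not TRL or PEFT code: frozen_embed, adapter, and head are hypothetical stand-ins for a frozen base model, adapter weights inside a checkpointed layer, and a trainable layer outside it. With use_reentrant=True, PyTorch emits a warning that no input requires grad, backward still succeeds, and the adapter simply receives no gradient.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def adapter_gets_grad(use_reentrant: bool) -> bool:
    """Return True if a trainable layer inside a checkpointed block receives a
    gradient when the block's input does not require grad."""
    torch.manual_seed(0)
    # Hypothetical stand-in for a frozen base model (no trainable weights).
    frozen_embed = nn.Embedding(10, 4).requires_grad_(False)
    # Hypothetical stand-in for adapter weights inside the checkpointed layer.
    adapter = nn.Linear(4, 4)
    # Trainable layer outside the checkpointed block, so backward() still runs.
    head = nn.Linear(4, 1)

    hidden = frozen_embed(torch.tensor([[1, 2, 3]]))  # hidden.requires_grad is False
    out = checkpoint(adapter, hidden, use_reentrant=use_reentrant)

    head(out).sum().backward()
    return adapter.weight.grad is not None


print(adapter_gets_grad(use_reentrant=True))   # False: adapter gets no gradient
print(adapter_gets_grad(use_reentrant=False))  # True
```

The reentrant path runs the block under no_grad and only links it into the autograd graph through inputs that require grad, so parameters used inside it are invisible to backward here; the non-reentrant path records the graph normally and recomputes activations during backward.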

Code evidence from `trl/trainer/grpo_trainer.py:530-536`:

```python
# Module-level imports (not part of lines 530-536), shown for context.
import transformers
from packaging.version import Version

# Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning,
# but the default was never updated once PyTorch switched to recommending use_reentrant=False.
# Until that change lands upstream (see https://github.com/huggingface/transformers/pull/43203)
# and is released (most likely in 5.0.0), we default to the recommended non-reentrant behavior
# here, while preserving any user-provided value.
if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
    args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
    args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)
```
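To make the gating and precedence rules of the excerpt explicit, here is a standalone sketch of the same logic. It assumes a hypothetical Args dataclass in place of the real TRL config object and passes the transformers version as a plain integer tuple instead of a packaging.version.Version; it illustrates the behavior and is not TRL code.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class Args:
    """Hypothetical stand-in for a TRL config object."""
    gradient_checkpointing: bool = False
    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None


def apply_non_reentrant_default(args: Args, transformers_version: Tuple[int, ...]) -> Args:
    """Fill in use_reentrant=False only when gradient checkpointing is enabled
    and transformers is older than 5.0.0."""
    if args.gradient_checkpointing and transformers_version < (5, 0, 0):
        # Replace None with an empty dict so setdefault can be called.
        args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
        # setdefault never overwrites a value the user already chose.
        args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)
    return args


# Usage
print(apply_non_reentrant_default(Args(gradient_checkpointing=True), (4, 57, 0)).gradient_checkpointing_kwargs)
# {'use_reentrant': False}
```

Because setdefault is used, an explicit user choice, even use_reentrant=True, is left untouched, and nothing changes when checkpointing is disabled or transformers is 5.0.0 or newer.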
