Heuristic:Huggingface Trl Gradient Checkpointing Use Reentrant
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Training |
| Last Updated | 2026-02-06 17:00 GMT |
Overview
Set use_reentrant=False in gradient_checkpointing_kwargs for reliable gradient checkpointing behavior in all TRL trainers.
Description
PyTorch's gradient checkpointing (torch.utils.checkpoint) has two modes: reentrant and non-reentrant. The reentrant variant (use_reentrant=True) was the historical default, but it has known issues with certain model architectures, PEFT adapters, and hooks, and PyTorch now recommends use_reentrant=False instead. Hugging Face transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but never switched to the recommended non-reentrant behavior. TRL therefore sets the non-reentrant default itself for its trainers when running with transformers < 5.0.0.
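The gradient_checkpointing_kwargs dict ends up as keyword arguments to torch.utils.checkpoint.checkpoint. The minimal sketch below approximates that wiring with functools.partial. This is a simplification for illustration, not transformers' actual code. It checks that, when the checkpointed input requires grad, both modes reproduce the gradients of an uncheckpointed forward pass. The modes only diverge in the edge cases covered under Reasoning.

```python
import functools

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def first_layer_grad(gradient_checkpointing_kwargs=None):
    """Return the gradient of the first Linear weight, with or without checkpointing.

    None runs the block normally. A dict such as {"use_reentrant": False} is bound
    onto torch's checkpoint function (an approximation of how the kwargs are consumed).
    """
    torch.manual_seed(0)  # identical weights and inputs on every call
    block = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
    x = torch.randn(2, 8, requires_grad=True)

    if gradient_checkpointing_kwargs is None:
        out = block(x)
    else:
        checkpoint_fn = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
        out = checkpoint_fn(block, x)  # activations inside `block` are recomputed in backward

    out.pow(2).sum().backward()
    return block[0].weight.grad


baseline = first_layer_grad()
non_reentrant = first_layer_grad({"use_reentrant": False})
reentrant = first_layer_grad({"use_reentrant": True})
print(torch.allclose(baseline, non_reentrant), torch.allclose(baseline, reentrant))  # True True
```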
Usage
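Apply this heuristic whenever you enable gradient_checkpointing=True in any TRL training configuration. It matters most when gradient checkpointing is combined with PEFT/LoRA adapters. Because the base model's weights are frozen, the inputs to checkpointed blocks often do not require gradients. In that case the reentrant variant can leave the trainable parameters inside those blocks without gradients, and training continues without any error.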
The Insight (Rule of Thumb)
- Action: Ensure `gradient_checkpointing_kwargs={"use_reentrant": False}` is set in your training config.
- Value: `use_reentrant=False` (non-reentrant checkpointing).
- Trade-off: Non-reentrant checkpointing may use slightly more memory in rare edge cases, but it computes correct gradients in the model configurations where the reentrant variant fails.
- Compatibility: TRL applies this default automatically for transformers < 5.0.0. For transformers >= 5.0.0, TRL leaves the setting untouched, on the expectation that non-reentrant checkpointing becomes the upstream default there (see https://github.com/huggingface/transformers/pull/43203).
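A minimal config sketch, assuming a recent TRL release. GRPOConfig serves as the example. gradient_checkpointing and gradient_checkpointing_kwargs are inherited from transformers.TrainingArguments, so SFTConfig, DPOConfig, and the other TRL configs accept the same two arguments. The output_dir value is a placeholder. Setting the key explicitly makes the behavior independent of the installed transformers version.

```python
from trl import GRPOConfig

# Assumes a recent TRL release; both checkpointing arguments come from
# transformers.TrainingArguments, which every TRL config subclasses.
training_args = GRPOConfig(
    output_dir="grpo-checkpointing-demo",  # hypothetical output path
    gradient_checkpointing=True,
    # Explicit value: does not rely on TRL's version-dependent default.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```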
Reasoning
The reentrant variant of gradient checkpointing has fundamental limitations. It does not support arbitrary model architectures with hooks, and it can produce incorrect gradients with certain PEFT configurations. It also only propagates gradients into a checkpointed segment when at least one of the segment's tensor inputs has requires_grad=True. The non-reentrant variant handles all of these cases correctly. TRL's GRPOTrainer (and by extension the other RL trainers) combines gradient checkpointing with PEFT adapters and model hooks, for example to disable dropout or to cast lm_head to fp32, which makes non-reentrant behavior essential.
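The PEFT failure mode can be reproduced without PEFT itself. The sketch below is a toy stand-in, and none of its names come from TRL. A frozen nn.Embedding plays the frozen base model, one trainable nn.Linear sits inside the checkpoint, and a trainable head follows it. Because the embedding is frozen, the tensor entering the checkpointed block does not require grad. The reentrant variant therefore returns an output that is disconnected from the block's parameters. Backward still succeeds through the head, but the block never receives a gradient, and PyTorch only emits a warning.

```python
import warnings
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def block_grad_after_backward(use_reentrant: bool) -> Optional[torch.Tensor]:
    """Frozen embedding -> checkpointed trainable block -> trainable head.

    Runs one backward pass and returns whatever gradient reached the block's weight.
    """
    torch.manual_seed(0)
    embed = nn.Embedding(10, 4)
    embed.weight.requires_grad_(False)  # frozen base weights, as with PEFT/LoRA
    block = nn.Linear(4, 4)             # trainable, inside the checkpointed segment
    head = nn.Linear(4, 1)              # trainable, outside the segment

    hidden = embed(torch.tensor([[1, 2, 3]]))  # requires_grad is False here
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # reentrant mode warns that no input requires grad
        hidden = checkpoint(block, hidden, use_reentrant=use_reentrant)

    head(hidden).sum().backward()  # succeeds in both modes because `head` is trainable
    return block.weight.grad


print(block_grad_after_backward(use_reentrant=True))                # None: block silently skipped
print(block_grad_after_backward(use_reentrant=False) is not None)   # True
```

This is the situation that transformers' enable_input_require_grads() works around for reentrant checkpointing. The non-reentrant variant does not need that workaround.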
Code evidence from `trl/trainer/grpo_trainer.py:530-536`:
```python
# Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning,
# but the default was never updated once PyTorch switched to recommending use_reentrant=False.
# Until that change lands upstream (see https://github.com/huggingface/transformers/pull/43203)
# and is released (most likely in 5.0.0), we default to the recommended non-reentrant behavior
# here, while preserving any user-provided value.
if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
    args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
    args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)
```
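To see what the quoted lines do in isolation, the sketch below copies them into a hypothetical helper, apply_nonreentrant_default. The helper takes a config-like object and a transformers version string, and types.SimpleNamespace stands in for the real config. It assumes the excerpt's Version is packaging.version.Version. The calls show the injected default, a preserved user choice, and the >= 5.0.0 case where nothing changes.

```python
from types import SimpleNamespace

from packaging.version import Version


def apply_nonreentrant_default(args, transformers_version: str):
    """Hypothetical standalone copy of the TRL defaulting logic quoted above."""
    if args.gradient_checkpointing and Version(transformers_version) < Version("5.0.0"):
        args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
        # setdefault only fills the key when the user has not already set it.
        args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)
    return args


# 1. No kwargs on transformers 4.x: the non-reentrant default is injected.
a = apply_nonreentrant_default(
    SimpleNamespace(gradient_checkpointing=True, gradient_checkpointing_kwargs=None), "4.46.0"
)
print(a.gradient_checkpointing_kwargs)  # {'use_reentrant': False}

# 2. An explicit user choice is preserved, even if it is True.
b = apply_nonreentrant_default(
    SimpleNamespace(gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": True}),
    "4.46.0",
)
print(b.gradient_checkpointing_kwargs)  # {'use_reentrant': True}

# 3. transformers >= 5.0.0: the config is left untouched.
c = apply_nonreentrant_default(
    SimpleNamespace(gradient_checkpointing=True, gradient_checkpointing_kwargs=None), "5.0.0"
)
print(c.gradient_checkpointing_kwargs)  # None
```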