Heuristic: Axolotl Gradient Checkpointing Reentrant Rules
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Memory_Management, Debugging |
| Last Updated | 2026-02-06 22:33 GMT |
Overview
Rules for configuring `use_reentrant` in gradient checkpointing to avoid silent correctness bugs and CheckpointErrors with frozen parameters, QLoRA, and distributed RL.
Description
PyTorch's gradient checkpointing has two implementations: reentrant (legacy, default in Axolotl for standard cases) and non-reentrant (newer, required for certain configurations). The choice between them is critical: using the wrong one causes either silent gradient computation errors with frozen parameters or explicit CheckpointErrors with QLoRA + DeepSpeed ZeRO3. Axolotl's validation layer enforces these rules at config time.
Usage
Apply these rules when `gradient_checkpointing: true` is set in the training config, especially when combined with `unfrozen_parameters` (partial freezing), `adapter: qlora` with DeepSpeed, or distributed RL training.
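As a concrete illustration, a config that partially freezes the model would combine these keys roughly as follows (a sketch using the option names cited above; the specific layer patterns are placeholders):

```yaml
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false  # required by Rule 1 when unfrozen_parameters is set
unfrozen_parameters:
  - lm_head.*           # placeholder patterns, not a recommendation
  - model.layers.3[0-9].*
```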
The Insight (Rule of Thumb)
- Rule 1 - Frozen Parameters: When using `unfrozen_parameters` (partially frozen model), MUST set `gradient_checkpointing_kwargs.use_reentrant: False`. Using `True` causes silent incorrect gradients. Reference: transformers issue #21381.
- Rule 2 - Default for Standard Training: When no `unfrozen_parameters`, no custom `gradient_checkpointing_kwargs`, and no RL, Axolotl defaults to `use_reentrant: True`.
- Rule 3 - QLoRA + ZeRO3 Warning: QLoRA + DeepSpeed ZeRO3 + `use_reentrant: False` may cause `CheckpointError: Recomputed values have different metadata`. This is a known upstream issue.
- Rule 4 - Distributed RL + QLoRA: Multi-GPU RL training with QLoRA MUST use `use_reentrant: False`. The reentrant implementation is broken upstream in TRL for this combination.
- Rule 5 - use_cache Disabled: Gradient checkpointing automatically disables `use_cache=True` during training (incompatible).
- Trade-off: `use_reentrant: True` is faster but incompatible with frozen params. `use_reentrant: False` is universally safe but slightly slower.
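The rules above can be condensed into a small decision helper. This is an illustrative sketch, not Axolotl's actual API: the function name is hypothetical, the key names mirror the config options cited in this document, and cases the rules leave unspecified (e.g. RL without multi-GPU QLoRA) are simplified to the standard default.

```python
def resolve_use_reentrant(cfg: dict) -> bool:
    """Return the use_reentrant value implied by Rules 1-4 (hypothetical helper)."""
    explicit = (cfg.get("gradient_checkpointing_kwargs") or {}).get("use_reentrant")
    multi_gpu = (cfg.get("capabilities") or {}).get("n_gpu", 1) > 1
    distributed_rl_qlora = (
        bool(cfg.get("rl")) and cfg.get("adapter") == "qlora" and multi_gpu
    )

    # Rule 1: partially frozen models must use the non-reentrant path.
    if cfg.get("unfrozen_parameters"):
        if explicit is True:
            raise ValueError("use_reentrant must be False with unfrozen_parameters")
        return False

    # Rule 4: distributed RL + QLoRA requires the non-reentrant path.
    if distributed_rl_qlora:
        if explicit is True:
            raise ValueError("use_reentrant: True unsupported for distributed RL + QLoRA")
        return False

    # An explicit setting wins in the remaining cases.
    if explicit is not None:
        return explicit

    # Rule 2: standard training defaults to the faster reentrant path
    # (simplification: the real default also checks cfg.rl).
    return True
```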
Reasoning
The reentrant implementation of `torch.utils.checkpoint` uses `torch.autograd.Function` which does not properly handle frozen parameters (parameters with `requires_grad=False`). When some layers are frozen and others are not, the reentrant version may silently compute incorrect gradients for the unfrozen layers. The non-reentrant version uses `torch.autograd.graph.saved_tensors_hooks` which correctly tracks gradient flow regardless of frozen state.
The QLoRA + ZeRO3 incompatibility with non-reentrant mode stems from tensor metadata changes during the DeepSpeed ZeRO3 parameter gathering/scattering process, which causes mismatches when the checkpoint function tries to verify recomputed values against stored metadata.
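The metadata-verification failure mode can be illustrated with a toy model (this is not DeepSpeed or PyTorch internals, just the shape of the check): the non-reentrant checkpoint records metadata such as shape and dtype for each saved tensor and compares it after recomputation, so any pass that changes a tensor's metadata between the two runs trips the check.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TensorMeta:
    """Minimal stand-in for the metadata the checkpoint verifies."""
    shape: tuple
    dtype: str

def verify_recomputed(saved: TensorMeta, recomputed: TensorMeta) -> None:
    """Raise if recomputation produced different metadata than was saved."""
    if saved != recomputed:
        raise RuntimeError(
            "CheckpointError: Recomputed values have different metadata: "
            f"saved={saved}, recomputed={recomputed}"
        )

# Matching metadata passes silently; a ZeRO3-style shape change would not.
verify_recomputed(
    TensorMeta((4096, 4096), "bfloat16"),
    TensorMeta((4096, 4096), "bfloat16"),
)
```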
Code Evidence
Frozen parameter validation from `src/axolotl/utils/schemas/validation.py:439-450`:
```python
@model_validator(mode="before")
@classmethod
def check_use_reentrant_mismatch(cls, data):
    if (
        data.get("unfrozen_parameters")
        and data.get("gradient_checkpointing_kwargs")
        and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
        is True
    ):
        # https://github.com/huggingface/transformers/issues/21381
        raise ValueError(
            "`use_reentrant` must be false when used with partially frozen model."
        )
```
Default reentrant setting from `src/axolotl/utils/config/__init__.py:234-240`:
```python
if (
    cfg.gradient_checkpointing
    and cfg.unfrozen_parameters is None
    and cfg.gradient_checkpointing_kwargs is None
    and cfg.rl is None
):
    cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
```
QLoRA + ZeRO3 warning from `src/axolotl/utils/schemas/validation.py:630-646`:
```python
if (
    data.get("adapter") == "qlora"
    and data.get("gradient_checkpointing_kwargs", {})
    and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is False
    and data.get("deepspeed", "") is not None
    and "zero3" in data.get("deepspeed", "")
):
    LOG.warning(
        "qlora + zero3 with use_reentrant: false may result in a "
        "CheckpointError about recomputed values"
    )
```
Distributed RL + QLoRA enforcement from `src/axolotl/utils/schemas/validation.py:727-747`:
```python
if (
    data.get("rl")
    and data.get("gradient_checkpointing")
    and data.get("gradient_checkpointing_kwargs")
    and data.get("gradient_checkpointing_kwargs").get("use_reentrant")
    and data.get("load_in_4bit")
    and data.get("adapter") == "qlora"
    and data.get("capabilities")
    and data.get("capabilities").get("n_gpu", 1) > 1
):
    raise ValueError(
        "The `use_reentrant: True` implementation of gradient checkpointing "
        "is not supported for distributed RL training with QLoRA."
    )
```