
Heuristic: Axolotl Gradient Checkpointing Reentrant Rules

From Leeroopedia




Knowledge Sources
Domains Optimization, Memory_Management, Debugging
Last Updated 2026-02-06 22:33 GMT

Overview

Rules for configuring `use_reentrant` in gradient checkpointing to avoid silent correctness bugs and CheckpointErrors with frozen parameters, QLoRA, and distributed RL.

Description

PyTorch's gradient checkpointing has two implementations: reentrant (legacy, default in Axolotl for standard cases) and non-reentrant (newer, required for certain configurations). The choice between them is critical: using the wrong one causes either silent gradient computation errors with frozen parameters or explicit CheckpointErrors with QLoRA + DeepSpeed ZeRO3. Axolotl's validation layer enforces these rules at config time.

Usage

Apply these rules whenever `gradient_checkpointing: true` is set in the training config, especially when combined with `unfrozen_parameters` (partial freezing), `adapter: qlora` with DeepSpeed, or distributed RL training.
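For example, a config that partially freezes the model would pair gradient checkpointing with the non-reentrant implementation. This is a minimal sketch written as a Python dict mirroring Axolotl's YAML schema; the parameter patterns are hypothetical, illustrative values:

```python
# Illustrative config fragment (Python dict standing in for Axolotl YAML).
# Because `unfrozen_parameters` partially freezes the model, `use_reentrant`
# must be False (see Rule 1 below).
config = {
    "gradient_checkpointing": True,
    # Hypothetical regex patterns for the layers left trainable:
    "unfrozen_parameters": ["lm_head.*", "model.layers.3[0-9].*"],
    "gradient_checkpointing_kwargs": {"use_reentrant": False},
}
```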

The Insight (Rule of Thumb)

  • Rule 1 - Frozen Parameters: When using `unfrozen_parameters` (partially frozen model), MUST set `gradient_checkpointing_kwargs.use_reentrant: False`. Using `True` causes silent incorrect gradients. Reference: transformers issue #21381.
  • Rule 2 - Default for Standard Training: When no `unfrozen_parameters`, no custom `gradient_checkpointing_kwargs`, and no RL, Axolotl defaults to `use_reentrant: True`.
  • Rule 3 - QLoRA + ZeRO3 Warning: QLoRA + DeepSpeed ZeRO3 + `use_reentrant: False` may cause `CheckpointError: Recomputed values have different metadata`. This is a known upstream issue.
  • Rule 4 - Distributed RL + QLoRA: Multi-GPU RL training with QLoRA MUST use `use_reentrant: False`. The reentrant implementation is broken upstream in TRL for this combination.
  • Rule 5 - use_cache Disabled: Gradient checkpointing automatically disables `use_cache=True` during training (incompatible).
  • Trade-off: `use_reentrant: True` is faster but incompatible with frozen params. `use_reentrant: False` is universally safe but slightly slower.
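The rules above can be restated as a small config-time checker. This is a simplified sketch of the decision logic, not Axolotl's actual validation code (the `n_gpu` key is a stand-in for Axolotl's `capabilities.n_gpu`):

```python
def check_reentrant_rules(cfg):
    """Simplified restatement of Rules 1, 3, and 4 as config-time checks."""
    problems = []
    gc_kwargs = cfg.get("gradient_checkpointing_kwargs") or {}
    reentrant = gc_kwargs.get("use_reentrant")

    # Rule 1: partial freezing forbids the reentrant implementation.
    if cfg.get("unfrozen_parameters") and reentrant is True:
        problems.append("error: use_reentrant must be False with unfrozen_parameters")

    # Rule 3: QLoRA + ZeRO3 + non-reentrant may hit CheckpointError (warn only).
    if (cfg.get("adapter") == "qlora" and reentrant is False
            and "zero3" in (cfg.get("deepspeed") or "")):
        problems.append("warning: qlora + zero3 + use_reentrant False may CheckpointError")

    # Rule 4: distributed RL + QLoRA forbids the reentrant implementation.
    if (cfg.get("rl") and reentrant and cfg.get("adapter") == "qlora"
            and cfg.get("n_gpu", 1) > 1):
        problems.append("error: use_reentrant True unsupported for distributed RL + QLoRA")

    return problems
```

Note that Rule 3 only warns while Rules 1 and 4 hard-fail, matching the severity Axolotl assigns to each case.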

Reasoning

The reentrant implementation of `torch.utils.checkpoint` uses `torch.autograd.Function` which does not properly handle frozen parameters (parameters with `requires_grad=False`). When some layers are frozen and others are not, the reentrant version may silently compute incorrect gradients for the unfrozen layers. The non-reentrant version uses `torch.autograd.graph.saved_tensors_hooks` which correctly tracks gradient flow regardless of frozen state.

The QLoRA + ZeRO3 incompatibility with non-reentrant mode stems from tensor metadata changes during the DeepSpeed ZeRO3 parameter gathering/scattering process, which causes mismatches when the checkpoint function tries to verify recomputed values against stored metadata.

Code Evidence

Frozen parameter validation from `src/axolotl/utils/schemas/validation.py:439-450`:

@model_validator(mode="before")
@classmethod
def check_use_reentrant_mismatch(cls, data):
    if (
        data.get("unfrozen_parameters")
        and data.get("gradient_checkpointing_kwargs")
        and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
        is True
    ):
        # https://github.com/huggingface/transformers/issues/21381
        raise ValueError(
            "`use_reentrant` must be false when used with partially frozen model."
        )

Default reentrant setting from `src/axolotl/utils/config/__init__.py:234-240`:

if (
    cfg.gradient_checkpointing
    and cfg.unfrozen_parameters is None
    and cfg.gradient_checkpointing_kwargs is None
    and cfg.rl is None
):
    cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
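The default above can be restated as a small helper. This is a simplified sketch of the resolution logic operating on a plain dict, not Axolotl's actual config object:

```python
def resolve_gc_kwargs(cfg):
    # Only a plain training run (no partial freezing, no explicit kwargs,
    # no RL) gets the faster reentrant mode by default.
    if (cfg.get("gradient_checkpointing")
            and cfg.get("unfrozen_parameters") is None
            and cfg.get("gradient_checkpointing_kwargs") is None
            and cfg.get("rl") is None):
        return {"use_reentrant": True}
    # Otherwise, whatever the user set (possibly nothing) is left alone.
    return cfg.get("gradient_checkpointing_kwargs")
```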

QLoRA + ZeRO3 warning from `src/axolotl/utils/schemas/validation.py:630-646`:

if (
    data.get("adapter") == "qlora"
    and data.get("gradient_checkpointing_kwargs", {})
    and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is False
    and data.get("deepspeed", "") is not None
    and "zero3" in data.get("deepspeed", "")
):
    LOG.warning(
        "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
    )

Distributed RL + QLoRA enforcement from `src/axolotl/utils/schemas/validation.py:727-747`:

if (
    data.get("rl")
    and data.get("gradient_checkpointing")
    and data.get("gradient_checkpointing_kwargs")
    and data.get("gradient_checkpointing_kwargs").get("use_reentrant")
    and data.get("load_in_4bit")
    and data.get("adapter") == "qlora"
    and data.get("capabilities")
    and data.get("capabilities").get("n_gpu", 1) > 1
):
    raise ValueError(
        "The `use_reentrant: True` implementation of gradient checkpointing "
        "is not supported for distributed RL training with QLoRA."
    )
