
Heuristic:Huggingface Peft Gradient Checkpointing With Quantization

From Leeroopedia




Knowledge Sources
Domains: Optimization, Quantization, Training
Last Updated: 2026-02-07 06:44 GMT

Overview

When using gradient checkpointing with quantized models, set `use_reentrant=False` to avoid the input requires_grad hack and improve compatibility.

Description

The `prepare_model_for_kbit_training()` function enables gradient checkpointing for quantized models. When `use_reentrant=True` (historically PyTorch's default for gradient checkpointing), a workaround is needed: either calling `model.enable_input_require_grads()` or registering a forward hook that forces `output.requires_grad_(True)` on the input embeddings. Setting `use_reentrant=False` eliminates the need for this hack entirely, yielding cleaner execution and correct backward-pass behavior with frozen base weights.
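The forward-hook workaround mentioned above can be sketched in plain PyTorch. This is a minimal illustration, not PEFT's actual code path; the embedding layer here stands in for a quantized model's (frozen) input embeddings:

```python
import torch

# Frozen embedding layer, standing in for a quantized model's input embeddings
emb = torch.nn.Embedding(10, 4)
emb.weight.requires_grad_(False)

# The "hack" needed under use_reentrant=True: force embedding outputs to
# require grad so reentrant checkpointing can backpropagate through
# checkpointed segments that contain trainable adapter weights.
def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

emb.register_forward_hook(make_inputs_require_grad)

out = emb(torch.tensor([1, 2, 3]))
assert out.requires_grad  # hook forced this on, despite the frozen weight
```

With `use_reentrant=False`, none of this is necessary, which is exactly what the heuristic exploits.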

Usage

Apply this heuristic whenever you are:

  • Fine-tuning a model loaded with `BitsAndBytesConfig` (4-bit or 8-bit)
  • Fine-tuning a GPTQ, AQLM, EETQ, HQQ, or TorchAO quantized model
  • Getting CUDA OOM errors during backward pass with quantized training
  • Seeing unexpected gradient issues with quantized models

The Insight (Rule of Thumb)

  • Action: Pass `gradient_checkpointing_kwargs={"use_reentrant": False}` to `prepare_model_for_kbit_training()` or to `TrainingArguments`.
  • Value: `use_reentrant=False`
  • Trade-off: Negligible. The non-reentrant implementation avoids the `requires_grad` workaround and is the variant PyTorch now recommends; its memory behavior differs slightly but is generally equivalent or better.
  • Compatibility: Requires `transformers > 4.34.1` for the `gradient_checkpointing_kwargs` argument to be supported.
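A typical invocation of either option might look like the following sketch. It assumes `model` was already loaded in 4-bit via `BitsAndBytesConfig`; the `output_dir` value is a placeholder:

```python
from peft import prepare_model_for_kbit_training
from transformers import TrainingArguments

# Option 1: pass the kwarg to PEFT's preparation helper
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

# Option 2: let the Trainer enable checkpointing instead
args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```

Use one or the other, not both; enabling checkpointing twice is redundant.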

Reasoning

PyTorch's `torch.utils.checkpoint.checkpoint` has two modes: reentrant (the historical default) and non-reentrant. The reentrant mode only backpropagates through a checkpointed segment if at least one of its input tensors has `requires_grad=True`, which fails for models with frozen parameters (common in PEFT) unless inputs are forced to require gradients. The non-reentrant mode handles frozen parameters correctly without the workaround. PEFT's code explicitly documents this:

From `src/peft/utils/other.py:185-186`:

# When having `use_reentrant=False` + gradient_checkpointing,
# there is no need for this hack
if "use_reentrant" not in gradient_checkpointing_kwargs or \
    gradient_checkpointing_kwargs["use_reentrant"]:
    # For backward compatibility
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)
        model.get_input_embeddings().register_forward_hook(
            make_inputs_require_grad
        )

Transformers version check from `src/peft/utils/other.py:198-207`:

_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
    inspect.signature(
        model.gradient_checkpointing_enable
    ).parameters
)

if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0:
    warnings.warn(
        "gradient_checkpointing_kwargs is not supported in this "
        "version of transformers. The passed kwargs will be "
        "ignored. if you want to use that feature, please upgrade "
        "to the latest version of transformers.",
        FutureWarning,
    )
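The signature probe above can be exercised in isolation. The two model classes below are hypothetical stand-ins for transformers versions with and without `gradient_checkpointing_kwargs` support:

```python
import inspect

# Stand-in for transformers > 4.34.1: the method accepts the kwargs dict
class NewModel:
    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        pass

# Stand-in for an older transformers release
class OldModel:
    def gradient_checkpointing_enable(self):
        pass

def supports_gc_kwargs(model):
    # Same signature probe PEFT uses before forwarding the kwargs
    return "gradient_checkpointing_kwargs" in inspect.signature(
        model.gradient_checkpointing_enable
    ).parameters

assert supports_gc_kwargs(NewModel())
assert not supports_gc_kwargs(OldModel())
```

Probing the signature rather than pinning a version string keeps PEFT compatible with forks that backport the argument.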
