# Heuristic: Hugging Face PEFT Gradient Checkpointing With Quantization
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Quantization, Training |
| Last Updated | 2026-02-07 06:44 GMT |
## Overview
When using gradient checkpointing with quantized models, set `use_reentrant=False` to avoid the input `requires_grad` hack and improve compatibility.
## Description
The `prepare_model_for_kbit_training()` function enables gradient checkpointing for quantized models. When `use_reentrant=True` (historically PyTorch's default for gradient checkpointing), a workaround is needed: either calling `model.enable_input_require_grads()` or registering a forward hook that forces `output.requires_grad_(True)` on the input embeddings' output. Setting `use_reentrant=False` eliminates the need for this hack entirely, resulting in cleaner execution and a better-behaved backward pass.
## Usage
Apply this heuristic whenever you are:
- Fine-tuning a model loaded with `BitsAndBytesConfig` (4-bit or 8-bit)
- Fine-tuning a GPTQ, AQLM, EETQ, HQQ, or TorchAO quantized model
- Getting CUDA OOM errors during backward pass with quantized training
- Seeing unexpected gradient issues with quantized models
## The Insight (Rule of Thumb)
- Action: Pass `gradient_checkpointing_kwargs={"use_reentrant": False}` to `prepare_model_for_kbit_training()` or to `TrainingArguments`.
- Value: `use_reentrant=False`
- Trade-off: Negligible. The non-reentrant implementation is generally preferred and avoids the pitfalls of the `requires_grad` hack; memory behavior differs slightly but is typically equivalent or better.
- Compatibility: Requires `transformers > 4.34.1` for the `gradient_checkpointing_kwargs` argument to be supported.
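A minimal sketch of the action above. The model id, quantization settings, and LoRA config are illustrative placeholders, not taken from the source; this requires a GPU environment with `bitsandbytes` installed.

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Illustrative model id; substitute your own quantized base model.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# Pass the non-reentrant flag so PEFT skips the requires_grad hack.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))

# Equivalent when letting the Trainer enable checkpointing instead:
args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```

Use one of the two paths, not both: either let `prepare_model_for_kbit_training` enable checkpointing, or let `TrainingArguments` do it.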
## Reasoning
PyTorch's `torch.utils.checkpoint.checkpoint` has two modes: reentrant (default) and non-reentrant. The reentrant mode has known issues with models that have frozen parameters (common in PEFT), requiring input tensors to have `requires_grad=True`. The non-reentrant mode handles this correctly without requiring the workaround. PEFT's code explicitly documents this:
From `src/peft/utils/other.py:185-186`:
```python
# When having `use_reentrant=False` + gradient_checkpointing,
# there is no need for this hack
if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
    # For backward compatibility
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
```
Transformers version check from `src/peft/utils/other.py:198-207`:
```python
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
    inspect.signature(model.gradient_checkpointing_enable).parameters
)

if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0:
    warnings.warn(
        "gradient_checkpointing_kwargs is not supported in this version of transformers. "
        "The passed kwargs will be ignored. if you want to use that feature, please "
        "upgrade to the latest version of transformers.",
        FutureWarning,
    )
```