Heuristic: Hugging Face alignment-handbook gradient checkpointing and `use_cache` toggle
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep_Learning |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Memory optimization: when `gradient_checkpointing` is enabled, `use_cache` must be set to `False` during training and restored to `True` after training so that inference benefits from KV caching.
Description
Gradient checkpointing trades compute for memory: intermediate activations are not stored during the forward pass and are instead recomputed during the backward pass. The KV cache (`use_cache`) is incompatible with this because the cache holds exactly the kind of activations that gradient checkpointing intentionally discards. The alignment-handbook handles the toggle automatically: `get_model` sets `use_cache=False` when `gradient_checkpointing=True`, and `sft.py` restores `use_cache=True` after training for fast inference.
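The toggle can be illustrated with a minimal sketch. The `ModelConfig` stand-in and `resolve_use_cache` helper here are illustrative, not the handbook's actual API; only the boolean logic mirrors `model_utils.py`:

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    use_cache: bool = True  # transformers model configs default to True


def resolve_use_cache(gradient_checkpointing: bool) -> bool:
    # Mirrors the handbook's rule: cache off while checkpointing is on.
    return False if gradient_checkpointing else True


# Training with gradient checkpointing: the cache must be disabled.
cfg = ModelConfig(use_cache=resolve_use_cache(gradient_checkpointing=True))
assert cfg.use_cache is False

# After training, restore the cache for fast generation (as sft.py does).
cfg.use_cache = True
```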
Usage
Apply this when training any model with gradient checkpointing enabled (which is the default in all alignment-handbook recipes). The code handles this automatically, but understanding the mechanism is important for debugging generation issues after training.
The Insight (Rule of Thumb)
- Action: Set `use_cache=False` during training when `gradient_checkpointing=True`; restore `use_cache=True` after training.
- Value: Boolean toggle, applied automatically by the code.
- Trade-off: Gradient checkpointing reduces VRAM usage at the cost of roughly 20% slower training. Forgetting to restore `use_cache` after training results in slow autoregressive inference.
Reasoning
The KV cache stores the key and value tensors from each attention layer so that autoregressive generation only computes attention projections for the newest token. During training with gradient checkpointing, activations are discarded after the forward pass and recomputed during backward; key/value tensors cached on the first forward pass would be stale during that recomputation, so the two features cannot be combined (Transformers itself warns and forces `use_cache=False` in this situation). All alignment-handbook recipes set `gradient_checkpointing: true`, and the `get_model` function enforces the `use_cache` toggle.
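A back-of-the-envelope sketch of why restoring the cache matters for generation speed. The operation counts are schematic, assuming one key/value projection per token position; `kv_projections` is a hypothetical helper, not part of any library:

```python
def kv_projections(new_tokens: int, prompt_len: int, use_cache: bool) -> int:
    """Count key/value projections during autoregressive generation (schematic)."""
    ops = 0
    seq_len = prompt_len
    for _ in range(new_tokens):
        if use_cache:
            ops += 1            # only the newest token's K/V are computed
        else:
            ops += seq_len + 1  # K/V recomputed for the entire sequence
        seq_len += 1
    return ops


# Generating 100 tokens from a 500-token prompt:
with_cache = kv_projections(100, 500, use_cache=True)      # linear: 100
without_cache = kv_projections(100, 500, use_cache=False)  # quadratic: 55050
```

Linear versus quadratic growth in sequence length is why inference feels dramatically slower when `use_cache` is left `False`.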
The SFT script explicitly restores use_cache after training:
```python
# From scripts/sft.py:149-151
# Restore k,v cache for fast inference
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir)
```
Code evidence from `src/alignment/model_utils.py:48`:
```python
use_cache=False if training_args.gradient_checkpointing else True,
```
All recipe configs enable gradient checkpointing with use_reentrant control:
```yaml
# From recipes/zephyr-7b-beta/sft/config_full.yaml:32-34
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: False
```
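The `use_reentrant: False` setting selects PyTorch's non-reentrant checkpointing implementation. A minimal sketch using `torch.utils.checkpoint` directly (the toy `block` and tensor shapes are illustrative):

```python
import torch
from torch.utils.checkpoint import checkpoint

w = torch.randn(4, 4, requires_grad=True)


def block(x):
    # Activations inside this block are discarded after the forward pass
    # and recomputed during backward.
    return torch.relu(x @ w)


x = torch.randn(2, 4)  # plain input; does not itself require grad
# use_reentrant=False is the recommended variant: unlike the reentrant
# implementation, it works even when no *input* tensor requires grad.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
assert w.grad is not None
```

With `use_reentrant=True`, checkpointing a block whose inputs all have `requires_grad=False` can silently skip gradient computation, which is one reason the handbook recipes pin the non-reentrant variant.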