Heuristic: Haotian Liu LLaVA Gradient Checkpointing Memory Optimization
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep_Learning |
| Last Updated | 2026-02-13 23:00 GMT |
Overview
Enable gradient checkpointing with a forward hook fallback to reduce VRAM usage by ~50% during LLaVA training, at the cost of ~20% slower training speed.
Description
LLaVA enables gradient checkpointing in all of its training scripts (`--gradient_checkpointing True`). The implementation includes a fallback mechanism: if the model does not natively support `enable_input_require_grads()`, a forward hook is registered on the embedding layer to manually call `requires_grad_(True)` on its outputs. This ensures gradient checkpointing works across all supported model architectures (LLaMA, Mistral, MPT). For quantized training (4-bit/8-bit), `prepare_model_for_kbit_training` from `peft` handles the gradient checkpointing setup automatically.
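The fallback can be illustrated in a few lines of plain PyTorch (an illustrative sketch, not LLaVA's code verbatim): when the embedding layer is frozen, as in adapter-style training, its outputs do not require grad, which would break gradient checkpointing; a forward hook restores the grad requirement.

```python
import torch

# Frozen embedding, as in adapter/LoRA-style training: its output
# has no grad_fn and does not require grad.
emb = torch.nn.Embedding(10, 4)
emb.weight.requires_grad_(False)

out = emb(torch.tensor([[1, 2, 3]]))
assert not out.requires_grad  # nothing upstream requires grad

# The fallback: hook the embedding so its outputs require grad,
# letting checkpointed downstream blocks build an autograd graph.
def make_inputs_require_grad(module, inputs, output):
    output.requires_grad_(True)

emb.register_forward_hook(make_inputs_require_grad)
out = emb(torch.tensor([[1, 2, 3]]))
assert out.requires_grad
```

Without this hook, every checkpointed block downstream of the frozen embedding would see inputs with `requires_grad=False` and produce no gradients.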
Usage
Use this heuristic whenever training LLaVA models, especially when VRAM is a bottleneck. All official training scripts set `--gradient_checkpointing True` by default. This is particularly important for 13B parameter models and full finetuning where activation memory dominates VRAM usage.
The Insight (Rule of Thumb)
- Action: Always set `--gradient_checkpointing True` in training arguments. For quantized (4-bit/8-bit) training, `prepare_model_for_kbit_training` handles it automatically when the flag is set.
- Value: ~50% reduction in peak VRAM usage.
- Trade-off: ~20% slower training speed due to recomputation of activations during backward pass.
- Compatibility: Must also set `model.config.use_cache = False` during training (done automatically in `train.py`). Re-enable cache after training for inference.
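The memory figure can be sanity-checked with back-of-envelope arithmetic. The sketch below uses assumed numbers (40 layers, hidden size 5120, 8 stored activation tensors per layer), not LLaVA's measured profile:

```python
# Rough activation-memory estimate; all model dimensions here are
# illustrative assumptions, not LLaVA's exact configuration.
def activation_gib(batch, seq_len, hidden, layers, acts_per_layer, bytes_per_el=2):
    """Approximate activation footprint in GiB (fp16/bf16 by default)."""
    elements = batch * seq_len * hidden * layers * acts_per_layer
    return elements * bytes_per_el / 2**30

# Hypothetical 13B-class config: 40 layers, hidden 5120, seq len 2048.
full = activation_gib(batch=16, seq_len=2048, hidden=5120, layers=40, acts_per_layer=8)
# With checkpointing, roughly one stored tensor per layer (the layer input);
# everything else is recomputed during the backward pass.
ckpt = activation_gib(batch=16, seq_len=2048, hidden=5120, layers=40, acts_per_layer=1)

print(f"no checkpointing: ~{full:.0f} GiB, with checkpointing: ~{ckpt:.1f} GiB")
```

Activation memory alone shrinks far more than 50% in this toy model; the end-to-end peak VRAM saving is smaller (~50%) because parameters, gradients, and optimizer state are unaffected by checkpointing.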
Reasoning
Large multimodal models like LLaVA-13B require storing massive activation tensors during the forward pass. Gradient checkpointing recomputes activations during the backward pass instead of storing them, trading compute for memory. The fallback hook mechanism is needed because some model architectures (e.g., older versions) do not implement `enable_input_require_grads()` natively. Without this fallback, gradient checkpointing would fail silently on those models.
The official V1.5 training scripts use per-device batch sizes of 32 (pretraining) and 16 (finetuning) with gradient checkpointing enabled, which would be impossible without it on typical GPU configurations.
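The compute-for-memory trade described above can be demonstrated in plain PyTorch (an illustrative sketch, not LLaVA code): a checkpointed function saves only its inputs, re-runs the forward during backward, and produces identical gradients.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layer = torch.nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)

# Baseline: intermediate activations are kept for the backward pass.
layer(x).relu().sum().backward()
g_ref = x.grad.clone()
x.grad = None

# Checkpointed: only the input is saved; the forward is recomputed
# during backward (more compute, less stored activation memory).
checkpoint(lambda t: layer(t).relu(), x, use_reentrant=False).sum().backward()
assert torch.allclose(g_ref, x.grad)  # gradients match the baseline
```

`use_reentrant=False` is the non-reentrant checkpoint variant recommended in recent PyTorch releases; it is also what allows checkpointing to interact cleanly with the embedding hook fallback above.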
Code Evidence
Gradient checkpointing with fallback from `train.py:852-858`:
```python
if training_args.gradient_checkpointing:
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)
        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
```
Quantized model preparation from `train.py:847-850`:
```python
if training_args.bits in [4, 8]:
    from peft import prepare_model_for_kbit_training
    model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
```
Cache disabled during training from `train.py:842`:
```python
model.config.use_cache = False
```
V1.5 pretrain script using gradient checkpointing (`scripts/v1_5/pretrain.sh:32`):
```shell
--gradient_checkpointing True
```