
Heuristic:Haotian liu LLaVA Gradient Checkpointing Memory Optimization

From Leeroopedia
Knowledge Sources
Domains Optimization, Deep_Learning
Last Updated 2026-02-13 23:00 GMT

Overview

Enable gradient checkpointing, with a forward-hook fallback that keeps gradients flowing into checkpointed layers, to reduce VRAM usage by ~50% during LLaVA training at the cost of ~20% slower training.

Description

LLaVA enables gradient checkpointing in all its training scripts (`--gradient_checkpointing True`). This is a standard Hugging Face `TrainingArguments` flag: the `Trainer` turns checkpointing on in the model, while `train.py` makes sure gradients can flow into the checkpointed layers. It does this by calling the model's `enable_input_require_grads()` when the method is available; otherwise it falls back to registering a forward hook on the input embedding layer that calls `requires_grad_(True)` on the layer's output. This keeps gradient checkpointing working across the supported language-model backbones (LLaMA, Mistral, MPT). For quantized training (4-bit/8-bit), `prepare_model_for_kbit_training` from peft also sets up gradient checkpointing when the flag is enabled.
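A minimal sketch of the fallback hook on a toy, frozen `nn.Embedding` (the layer sizes and token ids are arbitrary assumptions, not LLaVA's). Hugging Face's own `enable_input_require_grads()` registers an equivalent hook on the input embeddings, so both paths should have the same effect:

```python
import torch
import torch.nn as nn

# Toy stand-in for an LLM's token embedding layer (sizes are arbitrary).
embed = nn.Embedding(num_embeddings=100, embedding_dim=8)
embed.weight.requires_grad_(False)  # frozen, as when the LLM backbone is not trained


def make_inputs_require_grad(module, input, output):
    # Same body as the fallback in train.py: mark the embedding output as
    # requiring grad so autograd tracks everything computed from it.
    output.requires_grad_(True)


ids = torch.tensor([[1, 2, 3]])
before = embed(ids).requires_grad  # False: neither the input nor the weight requires grad

handle = embed.register_forward_hook(make_inputs_require_grad)
after = embed(ids).requires_grad  # True: the hook flipped the flag on the output
handle.remove()  # detach the hook once it is no longer needed

print(before, after)  # False True
```

The hook returns `None`, so PyTorch keeps the original output tensor; only its `requires_grad` flag changes.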

Usage

Use this heuristic whenever training LLaVA models, especially when VRAM is the bottleneck. All official training scripts pass `--gradient_checkpointing True`. It matters most for 13B-parameter models and for full finetuning, where activation memory dominates VRAM usage.

The Insight (Rule of Thumb)

  • Action: Always set `--gradient_checkpointing True` in training arguments. For quantized (4-bit/8-bit) training, `prepare_model_for_kbit_training` handles it automatically when the flag is set.
  • Value: ~50% reduction in peak VRAM usage.
  • Trade-off: ~20% slower training due to recomputation of activations during the backward pass.
  • Compatibility: Must also set `model.config.use_cache = False` during training (done automatically in `train.py`). Re-enable the KV cache after training for inference.
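To make the speed trade-off concrete, this sketch counts how often a small, hypothetical block runs its forward pass with and without `torch.utils.checkpoint.checkpoint`. The extra call during backward is the recomputation that costs training time; the block and tensor shapes are illustrative only:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

calls = {"n": 0}  # counts how many times the block's forward runs
block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))


def run_block(x):
    calls["n"] += 1
    return block(x)


x = torch.randn(4, 16, requires_grad=True)

# Plain forward/backward: the block runs once and its activations are stored.
run_block(x).sum().backward()
plain_calls = calls["n"]

calls["n"] = 0
x.grad = None
# Checkpointed: activations inside the block are not kept after the forward
# pass; the block is re-run during backward to recompute them.
checkpoint(run_block, x, use_reentrant=False).sum().backward()
ckpt_calls = calls["n"]

print(plain_calls, ckpt_calls)  # 1 2
```

In a real model each checkpointed decoder layer pays this extra forward, which is where the slowdown comes from.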

Reasoning

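Large multimodal models like LLaVA-13B must store large activation tensors during the forward pass. Gradient checkpointing recomputes activations during the backward pass instead of storing them, trading compute for memory. The input-gradient setup is needed because PyTorch's reentrant checkpointing only propagates gradients through a checkpointed segment if at least one of the segment's inputs requires grad; when the embedding weights are frozen (as with LoRA or a frozen backbone), the hidden states entering the decoder layers do not. The forward-hook fallback covers models that do not implement `enable_input_require_grads()` natively (e.g., older versions). Without this fallback, gradient checkpointing would fail almost silently on those models: training still runs, but the parameters inside the checkpointed layers receive no gradients, and PyTorch emits only a warning.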
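The failure mode can be reproduced on a toy model, assuming reentrant checkpointing (`use_reentrant=True`, the behavior PyTorch has historically used by default) and a frozen embedding. `layer` and `head` are hypothetical stand-ins for a decoder layer and the LM head:

```python
import warnings

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
embed = nn.Embedding(100, 8)
embed.weight.requires_grad_(False)  # frozen embeddings
layer = nn.Linear(8, 8)             # trainable, inside the checkpoint
head = nn.Linear(8, 1)              # trainable, outside the checkpoint
ids = torch.tensor([[1, 2, 3]])


def train_step():
    layer.zero_grad(set_to_none=True)
    head.zero_grad(set_to_none=True)
    h = checkpoint(layer, embed(ids), use_reentrant=True)
    head(h).sum().backward()  # succeeds either way: head's weights require grad
    return layer.weight.grad


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # PyTorch warns that gradients will be None
    grad_without_hook = train_step()


def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)


embed.register_forward_hook(make_inputs_require_grad)
grad_with_hook = train_step()

print(grad_without_hook is None, grad_with_hook is not None)  # True True
```

With the hook in place the checkpoint's input requires grad, so the recomputed segment joins the autograd graph and `layer.weight` receives a gradient.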

The official V1.5 training scripts use per-device batch sizes of 32 (pretraining) and 16 (finetuning) with gradient checkpointing enabled; without it, these batch sizes would likely not fit in memory on typical GPU configurations.
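A rough, back-of-the-envelope sketch of the activation memory that remains per GPU with checkpointing, assuming every decoder layer is checkpointed so that only its input hidden state is kept. The model dimensions (hidden size 5120 and 40 layers, as in a 13B LLaMA/Vicuna backbone), the 2048-token sequence length, and bf16 storage are assumptions for illustration; the function name is hypothetical:

```python
def checkpoint_boundary_bytes(batch, seq_len, hidden, num_layers, bytes_per_elem=2):
    """Hypothetical estimate: memory for the hidden states kept at decoder-layer
    boundaries when every layer is checkpointed, i.e. one (batch, seq_len, hidden)
    tensor per layer."""
    return batch * seq_len * hidden * bytes_per_elem * num_layers


# Assumed 13B finetuning setting: per-device batch 16, 2048 tokens, bf16.
gib = checkpoint_boundary_bytes(batch=16, seq_len=2048, hidden=5120, num_layers=40) / 2**30
print(gib)  # 12.5
```

Without checkpointing, each layer would additionally keep its attention and MLP intermediates for the backward pass, so the total would be considerably larger than this figure.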

Code Evidence

Gradient checkpointing with fallback from `train.py:852-858`:

if training_args.gradient_checkpointing:
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)
        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

Quantized model preparation from `train.py:847-850`:

if training_args.bits in [4, 8]:
    from peft import prepare_model_for_kbit_training
    model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

Cache disabled during training from `train.py:842`:

model.config.use_cache = False
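One way to follow the "disable for training, re-enable for inference" rule without forgetting the second step is a small context manager. `kv_cache_disabled` is a hypothetical helper, and a `SimpleNamespace` stands in for a model so the sketch runs without Transformers installed:

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def kv_cache_disabled(model):
    """Hypothetical helper: turn off the KV cache for training, then restore it."""
    previous = model.config.use_cache
    model.config.use_cache = False
    try:
        yield model
    finally:
        model.config.use_cache = previous  # restored even if training raises


# Stand-in for a Hugging Face model: only `config.use_cache` is used here.
model = SimpleNamespace(config=SimpleNamespace(use_cache=True))
with kv_cache_disabled(model):
    during = model.config.use_cache  # False while "training"
after = model.config.use_cache       # True again for inference
```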

V1.5 pretrain script using gradient checkpointing (`scripts/v1_5/pretrain.sh:32`):

--gradient_checkpointing True
