Heuristic:Haotian liu LLaVA Use Cache Training Inference Toggle

From Leeroopedia
Knowledge Sources
Domains Optimization, Deep_Learning
Last Updated 2026-02-13 23:00 GMT

Overview

Disable KV-cache (`use_cache=False`) during training for memory efficiency and compatibility with gradient checkpointing; re-enable it (`use_cache=True`) for inference speed.

Description

The LLaVA training pipeline explicitly toggles the `use_cache` configuration flag: `train.py` sets it to `False` before training starts (line 842) and restores it to `True` once training completes (line 972). During inference (CLI, model worker, evaluation scripts), `use_cache=True` is always passed to `model.generate()`. The toggle matters because the KV cache stores past key-value pairs to speed up autoregressive generation, whereas during training it only consumes memory and is incompatible with gradient checkpointing.

Usage

Apply this heuristic in any training code: always disable `use_cache` before calling `trainer.train()`. In inference code, always enable `use_cache=True` in `model.generate()` calls. This is non-negotiable when gradient checkpointing is enabled.
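One way to make the toggle hard to forget is to wrap training in a small context manager. The sketch below is illustrative and is not LLaVA's code: `kv_cache_disabled` is a hypothetical helper, and a `SimpleNamespace` stands in for a real model so the snippet runs without any ML libraries. It assumes only that the model exposes `config.use_cache`.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def kv_cache_disabled(model):
    """Hypothetical helper: switch off `config.use_cache` while training runs."""
    previous = model.config.use_cache
    model.config.use_cache = False
    try:
        yield model
    finally:
        # Restore the earlier setting even if training raises.
        model.config.use_cache = previous


# Stand-in for a real model object; only `config.use_cache` is assumed.
model = SimpleNamespace(config=SimpleNamespace(use_cache=True))

with kv_cache_disabled(model):
    # A call such as trainer.train() would go here.
    flag_during_training = model.config.use_cache

flag_after_training = model.config.use_cache
```

Note the design difference: this helper restores whatever value was set before, whereas LLaVA's `train.py` assigns `True` unconditionally after training. Either approach leaves the model ready for cached generation when the flag started as `True`.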

The Insight (Rule of Thumb)

  • Action: Set `model.config.use_cache = False` before training. Set `use_cache=True` in `model.generate()` during inference.
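  • Value: Training: saves memory and avoids the conflict with gradient checkpointing. Inference: faster generation through KV-cache reuse, cited here as roughly 2-3x, although the actual gain depends on the model and the sequence length.
  • Trade-off: None for training; the toggle is mandatory rather than an optional optimization. At inference time the cache does occupy memory that grows with batch size and sequence length.
  • Failure mode: Leaving `use_cache=True` during training with gradient checkpointing is unsupported. Depending on the model implementation it may raise a runtime error or produce incorrect gradients; Hugging Face Transformers models typically log a warning and force `use_cache=False` instead.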
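The inference-side benefit can be made concrete by counting how many token positions need their keys and values projected, per layer, over a whole generation. This is a back-of-the-envelope sketch with hypothetical function names, not a benchmark: it ignores the attention computation itself and says nothing about wall-clock speedup.

```python
def kv_projections_without_cache(prompt_len: int, new_tokens: int) -> int:
    """Every decoding step re-encodes the entire sequence produced so far."""
    if new_tokens <= 0:
        return 0
    return sum(prompt_len + step for step in range(new_tokens))


def kv_projections_with_cache(prompt_len: int, new_tokens: int) -> int:
    """The prompt is encoded once; each later step projects only the newest token."""
    if new_tokens <= 0:
        return 0
    return prompt_len + (new_tokens - 1)


# Illustrative sizes: a 100-token prompt followed by 50 generated tokens.
without_cache = kv_projections_without_cache(prompt_len=100, new_tokens=50)
with_cache = kv_projections_with_cache(prompt_len=100, new_tokens=50)
```

Without the cache the count grows quadratically with the number of generated tokens; with it the count grows linearly, which is why the gap widens on longer outputs.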

Reasoning

During training, the model processes each full sequence in a single forward pass (teacher forcing): all positions are computed in parallel, with the causal mask restricting each token to the tokens before it, so nothing is reused across steps and caching past key-value pairs brings no benefit. Storing the cache anyway wastes VRAM proportional to `batch_size * num_layers * seq_len * hidden_size`, once for keys and once for values. During inference (autoregressive generation), tokens are produced one at a time and each new token attends to all previous tokens, so caching their keys and values avoids recomputing them at every step. The cache is also incompatible with gradient checkpointing because checkpointing recomputes activations during the backward pass, which conflicts with cached states.
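To give the memory proportionality a sense of scale, the estimate can be written as a small function. This is a rough sketch under stated assumptions: standard multi-head attention in which keys and values each take `hidden_size` values per token (no grouped-query attention), 2 bytes per value (fp16 or bf16), and example dimensions chosen to resemble a 7B-parameter LLaMA-style decoder. `kv_cache_bytes` is a hypothetical name, not part of any library.

```python
def kv_cache_bytes(batch_size: int, num_layers: int, seq_len: int,
                   hidden_size: int, bytes_per_value: int = 2) -> int:
    """Approximate KV-cache size in bytes for standard multi-head attention."""
    # Factor of 2: one tensor for keys and one for values in every layer.
    return 2 * batch_size * num_layers * seq_len * hidden_size * bytes_per_value


# Assumed example dimensions: 32 layers, hidden size 4096, one 2048-token sequence.
example_bytes = kv_cache_bytes(batch_size=1, num_layers=32, seq_len=2048, hidden_size=4096)
example_gib = example_bytes / 2**30  # 1.0
```

Under these assumptions a single 2048-token sequence accounts for about 1 GiB, and the figure scales linearly with batch size, which is memory that training gains nothing from holding.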

Code Evidence

Cache disabled at start of training from `train.py:842`:

model.config.use_cache = False

Cache re-enabled after training from `train.py:972`:

model.config.use_cache = True

Cache enabled during CLI inference from `cli.py:95-104`:

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image_size],
        do_sample=True if args.temperature > 0 else False,
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
        streamer=streamer,
        use_cache=True)

Cache enabled in model worker from `model_worker.py:176-184`:

thread = Thread(target=model.generate, kwargs=dict(
    inputs=input_ids,
    do_sample=do_sample,
    temperature=temperature,
    top_p=top_p,
    max_new_tokens=max_new_tokens,
    streamer=streamer,
    use_cache=True,
    **image_args
))
