Heuristic: mlfoundations/open_flamingo KV Cache Classification Optimization
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Evaluation, LLMs |
| Last Updated | 2026-02-08 03:30 GMT |
Overview
Key-value caching optimization for classification evaluation that pre-computes context representations once and reuses them across all class name evaluations, avoiding redundant forward passes.
Description
During classification evaluation (ImageNet, Hateful Memes), the model must compute log-probabilities for every class name given the same context (images + in-context demonstrations). Without caching, the full context would be re-processed for each of the potentially thousands of class names. KV-caching pre-computes the context representation once using `cache_media()` and a single forward pass with `use_cache=True`, then reuses the cached key-value pairs for each class name evaluation. This turns an O(N*C) computation into O(N+C) where N is context length and C is number of classes.
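The effect can be sketched with a toy scorer (pure Python, no ML framework; `score_classes` and its internals are illustrative stand-ins, not OpenFlamingo's API) that counts how many times the context gets processed with and without caching:

```python
# Toy illustration of KV-cache reuse for classification scoring.
# All names here are illustrative, not part of the OpenFlamingo codebase.

def score_classes(context_tokens, class_names, use_cache):
    """Return (scores, context_encodings): scores per class, and how many
    times the shared context had to be processed."""
    context_encodings = 0
    cache = None
    scores = {}
    for name in class_names:
        if use_cache:
            if cache is None:
                # encode the context once and keep its "KV" state around
                cache = list(context_tokens)
                context_encodings += 1
        else:
            # re-encode the identical context for every class name
            cache = list(context_tokens)
            context_encodings += 1
        # stand-in for log-prob scoring: only the class-name tokens are new work
        scores[name] = -len(cache) - len(name)
    return scores, context_encodings

classes = ["cat", "dog", "tench"]
_, without = score_classes(["<img>", "A", "photo", "of"], classes, use_cache=False)
_, cached = score_classes(["<img>", "A", "photo", "of"], classes, use_cache=True)
print(without, cached)  # 3 context encodings without caching, 1 with
```

The scoring itself is a placeholder; the point is that the shared context is encoded once rather than once per class.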
Usage
Apply this heuristic during Classification Evaluation (ImageNet with 1000 classes, Hateful Memes with 2 classes). It is enabled by default and can be disabled with `--no_caching_for_classification`. The optimization is most impactful when the number of in-context examples is large and there are many class names.
The Insight (Rule of Thumb)
- Action: Use `cache_media()` to pre-encode vision features, then use `past_key_values` from a single context forward pass for all class name evaluations.
- Value: Replaces `num_classes` full-context forward passes per batch with one context forward pass plus short per-token passes over the class-name tokens; each cached pass processes a single token instead of the entire context.
- Trade-off: Uses more GPU memory to store cached KV pairs. Can be disabled if memory is constrained.
Reasoning
In classification scoring, the context (images + demonstrations) is identical for every class name. Computing the full forward pass for each class name is wasteful. By caching the key-value states from the context, subsequent class name evaluations only need to process the class name tokens (typically 1-3 tokens) rather than the entire context (potentially hundreds of tokens). For ImageNet with 1000 classes and 8-shot demonstrations, this avoids ~999 redundant full context computations per sample.
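Under illustrative assumptions (an 8-shot context of roughly 400 tokens and 2-token class names; these numbers are not measured from the codebase), the per-sample token-count savings work out as:

```python
# Back-of-envelope token counts per sample (illustrative numbers, not measured).
context_len = 400   # assumed: 8 demonstrations + query prompt
num_classes = 1000  # ImageNet
class_len = 2       # assumed average class-name length in tokens

# Without caching: every class re-processes the full context plus its name.
without_cache = num_classes * (context_len + class_len)
# With caching: the context is processed once, then only class-name tokens.
with_cache = context_len + num_classes * class_len

print(without_cache, with_cache)  # 402000 vs 2400, roughly a 167x reduction
```

The exact ratio depends on context length and tokenizer, but the O(N*C) vs O(N+C) shape of the two expressions is what the heuristic exploits.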
Code Evidence
KV-cache classification from `open_flamingo/eval/models/open_flamingo.py:170-184`:
```python
# Cache the context
if use_cache:
    # reserve the last token in the context for the main forward pass
    self.cache_media(
        input_ids=ctx_input_ids,
        vision_x=batch_images,
    )
    precomputed = self.__call__(
        vision_x=None,
        lang_x=ctx_input_ids,
        attention_mask=ctx_attention_mask,
        clear_conditioned_layers=False,
        use_cache=True,
    )
    precomputed_logits = precomputed.logits
    precomputed_pkvs = precomputed.past_key_values
```
Token-by-token forward with cached KVs from `open_flamingo/eval/models/open_flamingo.py:286-309`:
```python
# loop to handle updating past_key_values
logits = []
for token_idx in range(lang_x.shape[1]):
    _lang_x = lang_x[:, token_idx].reshape((-1, 1))
    ...
    with torch.inference_mode():
        with self.autocast():
            outputs = self.model(
                vision_x=vision_x,
                lang_x=_lang_x,
                attention_mask=_attention_mask,
                clear_conditioned_layers=False,
                past_key_values=past_key_values,
                use_cache=True,
            )
    past_key_values = outputs.past_key_values
    logits.append(outputs.logits)
```
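Once the per-position logits are collected, each class is scored by summing the log-probabilities of its name tokens. A minimal sketch in plain Python (list-based logits stand in for tensors, and it assumes the offset is already handled so position `i`'s distribution scores token `i`; `class_logprob` is an illustrative helper, not a function in the repository):

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    z = math.log(sum(math.exp(x - m) for x in logits)) + m
    return [x - z for x in logits]

def class_logprob(per_position_logits, class_token_ids):
    """Sum log P(token) over the class-name tokens, assuming
    per_position_logits[i] is the distribution that scores class token i."""
    return sum(
        log_softmax(step)[tok]
        for step, tok in zip(per_position_logits, class_token_ids)
    )

# Uniform logits over a 2-token vocabulary give each token probability 1/2,
# so a 1-token class name scores log(1/2).
print(class_logprob([[0.0, 0.0]], [0]))
```

The predicted class is then the argmax of these summed log-probabilities across all class names.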