
Heuristic: mlfoundations/open_flamingo KV-Cache Classification Optimization

From Leeroopedia




Knowledge Sources
Domains Optimization, Evaluation, LLMs
Last Updated 2026-02-08 03:30 GMT

Overview

Key-value caching optimization for classification evaluation that pre-computes context representations once and reuses them across all class name evaluations, avoiding redundant forward passes.

Description

During classification evaluation (ImageNet, Hateful Memes), the model must compute log-probabilities for every class name given the same context (images + in-context demonstrations). Without caching, the full context would be re-processed for each of potentially thousands of class names. KV-caching pre-computes the context representation once using `cache_media()` and a single forward pass with `use_cache=True`, then reuses the cached key-value pairs for each class name evaluation. This turns an O(N*C) computation into O(N+C), where N is the context length and C is the number of classes.
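The reuse argument above can be sketched with a toy single-head attention: because causal attention over the appended class token only needs the context's key-value pairs, encoding the context once and appending each class token gives the same outputs as re-encoding everything per class. A minimal sketch, assuming identity K/V projections and a call counter in place of the real model (none of these names come from the repo):

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def attend(q, keys, values):
    # Scaled dot-product attention for a single query vector.
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    w = softmax(scores)
    return [sum(wi * v[j] for wi, v in zip(w, values)) for j in range(len(values[0]))]

proj_calls = 0
def kv_project(x):
    # Stand-in for the K/V projection; counts how often token states are (re)encoded.
    global proj_calls
    proj_calls += 1
    return x

context = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5], [-0.2, 0.4]]  # N = 4 context tokens
classes = [[0.7, -0.3], [-0.5, 0.6]]                           # C = 2 class tokens

# Without caching: the context is re-encoded for every class name -> (N + 1) * C calls.
proj_calls = 0
outputs_full = []
for c in classes:
    kv = [kv_project(t) for t in context + [c]]
    outputs_full.append(attend(c, kv, kv))
no_cache_calls = proj_calls

# With caching: encode the context once, append only the class token -> N + C calls.
proj_calls = 0
cache = [kv_project(t) for t in context]  # analogue of past_key_values
outputs_cached = []
for c in classes:
    kv = cache + [kv_project(c)]
    outputs_cached.append(attend(c, kv, kv))
cache_calls = proj_calls
```

With N = 4 and C = 2 the encode count drops from 10 to 6 while the attention outputs stay identical; the gap widens rapidly as N and C grow.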

Usage

Apply this heuristic during Classification Evaluation (ImageNet with 1000 classes, Hateful Memes with 2 classes). It is enabled by default and can be disabled with `--no_caching_for_classification`. The optimization is most impactful when the number of in-context examples is large and there are many class names.

The Insight (Rule of Thumb)

  • Action: Use `cache_media()` to pre-encode vision features, then use `past_key_values` from a single context forward pass for all class name evaluations.
  • Value: Replaces `num_classes` full-context forward passes per batch with a single context pass plus `num_classes` short passes that process only the class-name tokens (typically 1-3 tokens each).
  • Trade-off: Uses more GPU memory to store cached KV pairs. Can be disabled if memory is constrained.

Reasoning

In classification scoring, the context (images + demonstrations) is identical for every class name. Computing the full forward pass for each class name is wasteful. By caching the key-value states from the context, subsequent class name evaluations only need to process the class name tokens (typically 1-3 tokens) rather than the entire context (potentially hundreds of tokens). For ImageNet with 1000 classes and 8-shot demonstrations, this avoids ~999 redundant full context computations per sample.
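The savings can be estimated with simple token arithmetic; the context and class-name lengths below are illustrative assumptions, not measured values from the repo:

```python
# Back-of-envelope token counts for ImageNet-style scoring (1000 classes).
num_classes = 1000
context_tokens = 400   # assumed length of 8 demonstrations + query prompt
class_tokens = 2       # assumed average class-name length in tokens

# Without caching, the full context is re-processed for every class name.
without_cache = num_classes * (context_tokens + class_tokens)
# With caching, the context is processed once; each class adds only its own tokens.
with_cache = context_tokens + num_classes * class_tokens

speedup = without_cache / with_cache
print(without_cache, with_cache, round(speedup, 1))  # 402000 2400 167.5
```

Under these assumptions, caching cuts the tokens processed per sample by roughly two orders of magnitude.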

Code Evidence

KV-cache classification from `open_flamingo/eval/models/open_flamingo.py:170-184`:

# Cache the context
if use_cache:
    # reserve the last token in the context for the main forward pass
    self.cache_media(
        input_ids=ctx_input_ids,
        vision_x=batch_images,
    )
    precomputed = self.__call__(
        vision_x=None,
        lang_x=ctx_input_ids,
        attention_mask=ctx_attention_mask,
        clear_conditioned_layers=False,
        use_cache=True,
    )
    precomputed_logits = precomputed.logits
    precomputed_pkvs = precomputed.past_key_values

Token-by-token forward with cached KVs from `open_flamingo/eval/models/open_flamingo.py:286-309`:

# loop to handle updating past_key_values
logits = []
for token_idx in range(lang_x.shape[1]):
    _lang_x = lang_x[:, token_idx].reshape((-1, 1))
    ...
    with torch.inference_mode():
        with self.autocast():
            outputs = self.model(
                vision_x=vision_x,
                lang_x=_lang_x,
                attention_mask=_attention_mask,
                clear_conditioned_layers=False,
                past_key_values=past_key_values,
                use_cache=True,
            )
    past_key_values = outputs.past_key_values
    logits.append(outputs.logits)
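The per-token logits collected in the loop above are typically reduced to one score per class by summing the log-probabilities of the class-name tokens under the cached context. A minimal sketch of that reduction (the toy logits and token ids are illustrative, not taken from the repo):

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax over one vocabulary row.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def class_logprob(per_step_logits, class_token_ids):
    # per_step_logits[i] is the vocab-logit row that predicts class_token_ids[i];
    # the class score is the sum of its tokens' log-probabilities.
    return sum(log_softmax(row)[tok] for row, tok in zip(per_step_logits, class_token_ids))

# Toy vocabulary of 4 tokens; a class name spelled with token ids [2, 1].
per_step_logits = [[0.1, 0.2, 1.5, -0.3],   # step predicting token 2
                   [0.0, 2.0, -1.0, 0.5]]   # step predicting token 1
score = class_logprob(per_step_logits, [2, 1])
```

The predicted class is then simply the arg-max of these summed scores across all class names.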
