Principle:Romsto Speculative Decoding KV Cache Pruning
| Knowledge Sources | |
|---|---|
| Domains | Inference_Optimization, Memory_Management, Transformer_Architecture |
| Last Updated | 2026-02-14 04:30 GMT |
Overview
A cache management technique that removes key-value entries corresponding to rejected speculative draft tokens from the transformer KV-cache to maintain consistency between cache state and the accepted token sequence.
Description
KV-cache pruning addresses a bookkeeping problem that arises when speculative decoding runs with the KV-cache enabled. During standard autoregressive generation, the KV-cache grows monotonically: each new token's key and value projections are appended and never removed. In speculative decoding, however, draft tokens can be rejected, which leaves the KV-cache holding entries for tokens that are no longer part of the sequence. These stale entries must be removed so that, in subsequent rounds, the model does not attend to keys and values computed for discarded tokens.
The pruning operation truncates the sequence dimension of the KV-cache by removing the last N entries, where N is the number of rejected draft tokens, plus one additional entry for the correction position in the case of the target model. Pruning must be applied to both the drafter and the target model caches, and it must handle the different cache formats in use: a tuple of per-layer (key, value) tuples for older models, and DynamicCache objects in newer versions of the Hugging Face transformers library.
Usage
Use this principle whenever implementing speculative decoding or NASD with KV-cache enabled. After determining the number of accepted draft tokens, prune the cache to discard the entries for rejected drafts before the next speculative round begins. Without pruning, the cache would retain stale entries that corrupt every subsequent generation step.
Theoretical Basis
In a transformer with L layers and H attention heads, the KV-cache stores, at each of the L layers, one key tensor and one value tensor of shape (batch_size, num_heads, sequence_length, head_dim), where num_heads = H. Pruning shortens only the sequence_length dimension.
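To make the shape concrete, the following minimal sketch counts the scalar elements held by a cache of this layout and how many are released when trailing positions are pruned. The function name and the example dimensions are illustrative assumptions and are not taken from any particular model.

```python
def kv_cache_num_elements(num_layers, batch_size, num_heads, seq_len, head_dim):
    """Count scalar elements in a KV-cache with one key and one value tensor per layer.

    Each tensor has shape (batch_size, num_heads, seq_len, head_dim).
    """
    per_tensor = batch_size * num_heads * seq_len * head_dim
    return 2 * num_layers * per_tensor  # factor 2: keys and values


# Illustrative dimensions (assumed for this example only).
full = kv_cache_num_elements(num_layers=4, batch_size=1, num_heads=8, seq_len=100, head_dim=64)
pruned = kv_cache_num_elements(num_layers=4, batch_size=1, num_heads=8, seq_len=97, head_dim=64)
freed = full - pruned  # elements released by pruning 3 trailing positions

print(full)   # 409600
print(freed)  # 12288
```

The count of released elements depends only on how many positions are pruned, not on the current sequence length.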
When n out of gamma drafts are accepted:
- The drafter cache has (gamma - n) excess entries that must be removed
- The target cache has (gamma - n + 1) excess entries (gamma draft positions minus n accepted, plus one for the correction position)
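The two counts above can be written as a small helper. This is a sketch of the arithmetic only; the function name `entries_to_discard` and the range check on n are assumptions added for illustration.

```python
def entries_to_discard(gamma, n):
    """Return (drafter_excess, target_excess) when n of gamma drafts are accepted."""
    if not 0 <= n <= gamma:
        raise ValueError("n must satisfy 0 <= n <= gamma")
    drafter_excess = gamma - n
    target_excess = gamma - n + 1  # one extra entry for the correction position
    return drafter_excess, target_excess


print(entries_to_discard(gamma=5, n=3))  # (2, 3)
print(entries_to_discard(gamma=5, n=5))  # (0, 1)
```

When every draft is accepted, the drafter count is 0. A pruning routine built on negative slicing needs to treat that case as a no-op, because a slice ending at `-0` selects nothing.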
Pseudo-code:
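    # Abstract KV-cache pruning
    def prune_kv_cache(cache, num_to_discard):
        """Remove the trailing num_to_discard entries from every layer of the KV-cache."""
        if num_to_discard <= 0:
            # Nothing to prune; slicing with [:-0] would empty the cache.
            return cache
        for layer in cache:
            # Each tensor: (batch_size, num_heads, sequence_length, head_dim)
            layer.key = layer.key[:, :, :-num_to_discard, :]
            layer.value = layer.value[:, :, :-num_to_discard, :]
        return cache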
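A more concrete sketch of the same idea, covering the two cache formats named in the Description, might look like the following. It is not the repository's implementation. It assumes that every layer of the legacy format is a plain (key, value) pair with the sequence on axis 2, and that a newer cache object exposes a `crop` method that accepts a negative length to remove trailing tokens, as `DynamicCache.crop` does in Hugging Face transformers. NumPy arrays stand in for framework tensors so the example runs on its own; the same slicing applies to `torch.Tensor`. The name `prune_cache` is hypothetical.

```python
import numpy as np


def prune_cache(cache, num_to_discard):
    """Drop the last num_to_discard sequence positions from a KV-cache.

    Supports two layouts:
      * an object exposing crop(max_length), as DynamicCache does;
      * a legacy tuple of (key, value) pairs, one pair per layer.
    """
    if cache is None or num_to_discard <= 0:
        return cache  # nothing to prune; also avoids the empty slice x[:-0]
    if hasattr(cache, "crop"):
        cache.crop(-num_to_discard)  # negative length removes trailing tokens
        return cache
    # Legacy layout: slice the sequence axis (axis 2) of every key and value.
    return tuple(
        (key[:, :, :-num_to_discard, :], value[:, :, :-num_to_discard, :])
        for key, value in cache
    )


# Two layers, each tensor shaped (batch_size=1, num_heads=2, sequence_length=10, head_dim=4).
legacy = tuple((np.zeros((1, 2, 10, 4)), np.zeros((1, 2, 10, 4))) for _ in range(2))
pruned = prune_cache(legacy, 3)
print(pruned[0][0].shape)  # (1, 2, 7, 4)
```

Returning a new tuple for the legacy format, rather than mutating in place, reflects that tuples are immutable; the caller is expected to rebind the result, for example `target_cache = prune_cache(target_cache, k)`.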