Implementation:Romsto Speculative Decoding Prune Cache
| Knowledge Sources | |
|---|---|
| Domains | Inference_Optimization, Memory_Management |
| Last Updated | 2026-02-14 04:30 GMT |
Overview
Concrete tool for removing trailing KV-cache entries after speculative draft rejection, supporting both tuple and DynamicCache formats.
Description
The prune_cache function is a dispatcher that removes the last N token entries from a transformer KV-cache. It handles two cache formats: the legacy tuple-of-tuples format (used by older HuggingFace models) and the newer DynamicCache format (used by transformers >= 4.36). The function delegates to prune_tuple_cache or prune_dynamic_cache based on the cache type.
For tuple caches, it creates new truncated tensors. For DynamicCache, it modifies the cache in-place by slicing the key_cache and value_cache lists and adjusting the _seen_tokens counter.
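The two pruning strategies and the dispatch step described above can be sketched roughly as follows. This is an illustrative sketch, not the repository's code: the names ToyDynamicCache, prune_tuple_sketch, prune_dynamic_sketch and prune_cache_sketch are hypothetical, the tensor layout (batch, num_heads, seq_len, head_dim) is an assumption, and the guard for non-positive counts and the TypeError are additions made only for this example. ToyDynamicCache is a stand-in that exposes just the attributes named in the description (key_cache, value_cache, _seen_tokens), so the sketch does not depend on any particular transformers version.

```python
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor

TupleCache = Tuple[Tuple[Tensor, Tensor], ...]


class ToyDynamicCache:
    """Hypothetical stand-in exposing only the attributes named in the text."""

    def __init__(self, keys: List[Tensor], values: List[Tensor]) -> None:
        self.key_cache = keys
        self.value_cache = values
        # Assumed layout: (batch, num_heads, seq_len, head_dim).
        self._seen_tokens = keys[0].shape[2] if keys else 0


def prune_tuple_sketch(cache: TupleCache, n: int) -> TupleCache:
    # Guard added for this sketch: slicing with ":-0" would yield an empty tensor.
    if n <= 0:
        return cache
    # Build a new tuple of truncated (key, value) pairs; the input is untouched.
    return tuple((k[:, :, :-n, :], v[:, :, :-n, :]) for k, v in cache)


def prune_dynamic_sketch(cache: ToyDynamicCache, n: int) -> ToyDynamicCache:
    if n <= 0:
        return cache
    # Overwrite each layer's entry in place and adjust the token counter.
    for layer in range(len(cache.key_cache)):
        cache.key_cache[layer] = cache.key_cache[layer][:, :, :-n, :]
        cache.value_cache[layer] = cache.value_cache[layer][:, :, :-n, :]
    cache._seen_tokens -= n
    return cache


def prune_cache_sketch(
    cache: Optional[Union[TupleCache, ToyDynamicCache]], n: int
) -> Optional[Union[TupleCache, ToyDynamicCache]]:
    # Dispatch on the cache type; None passes through unchanged.
    if cache is None:
        return None
    if isinstance(cache, tuple):
        return prune_tuple_sketch(cache, n)
    if isinstance(cache, ToyDynamicCache):
        return prune_dynamic_sketch(cache, n)
    raise TypeError(f"Unsupported cache type: {type(cache).__name__}")


# Two layers, batch 1, 4 heads, 10 cached positions, head_dim 8.
layers = tuple(
    (torch.zeros(1, 4, 10, 8), torch.zeros(1, 4, 10, 8)) for _ in range(2)
)
pruned = prune_cache_sketch(layers, 3)
print(pruned[0][0].shape)  # torch.Size([1, 4, 7, 8])

dyn = ToyDynamicCache(
    [k.clone() for k, _ in layers], [v.clone() for _, v in layers]
)
same = prune_cache_sketch(dyn, 3)
print(same is dyn, dyn._seen_tokens)  # True 7
```

In this sketch the tuple path returns a new tuple and leaves the original at its full length, while the DynamicCache-style path returns the same object it was given.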
Usage
Import this function when running speculative decoding or NASD with the KV-cache enabled. It is called internally after each speculative round to discard the cache entries of rejected draft tokens. The number of entries to discard depends on which cache is being pruned: gamma - n for the drafter cache and gamma - n + 1 for the target cache, where gamma is the number of drafted tokens and n is the position at which the draft was rejected.
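As a worked example of the discard counts, the arithmetic can be written out as below. The helper discard_counts is hypothetical and exists only for illustration; the range check 0 <= n < gamma is an assumption about when a rejection can occur, not something taken from the repository.

```python
from typing import Tuple


def discard_counts(gamma: int, n: int) -> Tuple[int, int]:
    """Return (drafter_discard, target_discard) for a round with a rejection.

    gamma: number of drafted tokens in the round.
    n: position at which the draft was rejected (assumed 0 <= n < gamma).
    """
    if not 0 <= n < gamma:
        raise ValueError("expected 0 <= n < gamma when a draft is rejected")
    # Drafter cache: gamma - n excess entries; target cache: one more.
    return gamma - n, gamma - n + 1


drafter_discard, target_discard = discard_counts(gamma=5, n=2)
print(drafter_discard, target_discard)  # 3 4
```

With gamma = 5 and a rejection at n = 2, the drafter cache loses 3 trailing entries and the target cache loses 4.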
Code Reference
Source Location
- Repository: Speculative-Decoding
- File: utils/caching.py
- Lines: L6-77
Signature
def prune_cache(
    cache: Union[Tuple[Tuple[Tensor, Tensor]], DynamicCache],
    num_tokens_to_discard: int
) -> Union[Tuple[Tuple[Tensor, Tensor]], DynamicCache]:
    """
    Prune the cache by removing the specified number of tokens from the end.
    Args:
        cache: The KV cache to be pruned (tuple-of-tuples or DynamicCache).
        num_tokens_to_discard (int): The number of tokens to discard from the end.
    Returns:
        The pruned KV cache.
    """

def prune_tuple_cache(
    cache: Tuple[Tuple[Tensor, Tensor]],
    num_tokens_to_discard: int
) -> Tuple[Tuple[Tensor, Tensor]]:
    """Prune tuple-of-tuples cache format (most models)."""

def prune_dynamic_cache(
    cache: DynamicCache,
    num_tokens_to_discard: int
) -> DynamicCache:
    """Prune DynamicCache format (newer transformers). Modifies in place."""
Import
from utils.caching import prune_cache
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cache | Union[Tuple[Tuple[Tensor, Tensor]], DynamicCache] | Yes | KV-cache from a transformer model forward pass. May be None, in which case None is returned. |
| num_tokens_to_discard | int | Yes | Number of trailing token positions to remove from the cache |
Outputs
| Name | Type | Description |
|---|---|---|
| pruned_cache | Union[Tuple[Tuple[Tensor, Tensor]], DynamicCache] | KV-cache with the last num_tokens_to_discard entries removed from the sequence dimension. For DynamicCache, the same instance modified in place. |
Usage Examples
Pruning After Draft Rejection
from utils.caching import prune_cache
# After speculative decoding rejects at position n out of gamma drafts:
# Drafter has (gamma - n) excess entries
drafter_cache = prune_cache(drafter_cache, gamma - n)
# Target has (gamma - n + 1) excess entries
target_cache = prune_cache(target_cache, gamma - n + 1)
Handling None Cache
from utils.caching import prune_cache
# Safe to call with None (returns None)
cache = prune_cache(None, 5) # Returns None