

Implementation:Romsto Speculative Decoding Prune Cache

From Leeroopedia
Knowledge Sources
Domains: Inference_Optimization, Memory_Management
Last Updated: 2026-02-14 04:30 GMT

Overview

Concrete tool for removing trailing KV-cache entries after speculative draft rejection, supporting both tuple and DynamicCache formats.

Description

The prune_cache function is a dispatcher that removes the last N token positions from every layer of a transformer KV-cache. It handles two cache formats: the legacy tuple-of-tuples format (one (key, value) pair per layer, used by older Hugging Face models) and the newer DynamicCache format (used by transformers >= 4.36). Based on the cache type, it delegates to prune_tuple_cache or prune_dynamic_cache.

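For tuple caches, it builds a new tuple in which each layer's key and value tensors are truncated along the sequence dimension; the input tuple itself is left unchanged. For DynamicCache, it modifies the cache in place: each per-layer tensor in the key_cache and value_cache lists is replaced by its truncated version, and the _seen_tokens counter is reduced accordingly.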
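The dispatch-and-truncate behaviour described above can be sketched as follows. This is a minimal illustration reconstructed from the description, not the repository's code: NumPy arrays of shape (batch, num_heads, seq_len, head_dim) stand in for torch tensors (slicing along the sequence axis behaves the same way), and ToyDynamicCache is a hypothetical stand-in that exposes only the key_cache, value_cache and _seen_tokens attributes mentioned above, not the transformers class.

```python
from typing import List, Optional, Tuple, Union

import numpy as np

# Legacy format: one (key, value) pair per layer, each shaped
# (batch, num_heads, seq_len, head_dim).
TupleCache = Tuple[Tuple[np.ndarray, np.ndarray], ...]


class ToyDynamicCache:
    """Hypothetical stand-in for transformers' DynamicCache (illustration only)."""

    def __init__(self, keys: List[np.ndarray], values: List[np.ndarray]):
        self.key_cache = keys
        self.value_cache = values
        self._seen_tokens = keys[0].shape[2] if keys else 0


def prune_tuple_cache(cache: TupleCache, num_tokens_to_discard: int) -> TupleCache:
    # Build a new outer tuple; each array is truncated along the sequence axis (dim 2).
    # Using seq_len - k instead of -k keeps every position when k == 0.
    return tuple(
        (k[:, :, : k.shape[2] - num_tokens_to_discard, :],
         v[:, :, : v.shape[2] - num_tokens_to_discard, :])
        for k, v in cache
    )


def prune_dynamic_cache(cache: ToyDynamicCache, num_tokens_to_discard: int) -> ToyDynamicCache:
    # Mutate the existing object: swap each layer's arrays for truncated ones.
    for layer in range(len(cache.key_cache)):
        keep = cache.key_cache[layer].shape[2] - num_tokens_to_discard
        cache.key_cache[layer] = cache.key_cache[layer][:, :, :keep, :]
        cache.value_cache[layer] = cache.value_cache[layer][:, :, :keep, :]
    cache._seen_tokens -= num_tokens_to_discard
    return cache


def prune_cache(
    cache: Optional[Union[TupleCache, ToyDynamicCache]], num_tokens_to_discard: int
) -> Optional[Union[TupleCache, ToyDynamicCache]]:
    # Dispatch on the cache type; None passes straight through.
    if cache is None:
        return None
    if isinstance(cache, tuple):
        return prune_tuple_cache(cache, num_tokens_to_discard)
    if isinstance(cache, ToyDynamicCache):
        return prune_dynamic_cache(cache, num_tokens_to_discard)
    raise TypeError(f"Unsupported cache type: {type(cache).__name__}")


# Tuple cache: a new object is returned and the original keeps its length.
legacy = tuple((np.ones((1, 2, 10, 4)), np.ones((1, 2, 10, 4))) for _ in range(3))
pruned = prune_cache(legacy, 3)
print(pruned[0][0].shape, legacy[0][0].shape)  # (1, 2, 7, 4) (1, 2, 10, 4)

# Dynamic cache: the same instance comes back, already truncated.
dyn = ToyDynamicCache(
    [np.ones((1, 2, 10, 4)) for _ in range(2)],
    [np.ones((1, 2, 10, 4)) for _ in range(2)],
)
same = prune_cache(dyn, 4)
print(same is dyn, dyn.key_cache[0].shape, dyn._seen_tokens)  # True (1, 2, 6, 4) 6
```

The split mirrors the two formats: tuples are immutable, so a new container has to be returned, whereas a mutable cache object can be updated in place and returned for convenience.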

Usage

Import this function when using speculative decoding or NASD with the KV-cache enabled. It is called internally after each speculative round to discard the cache entries that belong to rejected draft tokens. The number of positions to discard differs between the two models: gamma - n for the drafter cache and gamma - n + 1 for the target cache, where gamma is the number of drafted tokens and n is the position at which a draft was rejected.
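The two discard counts can be tabulated with a small helper. discard_counts is a hypothetical name used only for illustration; the formulas are the ones stated above, and the sketch assumes n counts the drafts accepted before the rejection, so n == gamma means every draft was accepted. The last lines show why that zero case deserves care when writing the slicing code.

```python
from typing import Tuple


def discard_counts(gamma: int, n: int) -> Tuple[int, int]:
    """Hypothetical helper: (drafter_discard, target_discard) for one round.

    Assumes 0 <= n <= gamma, with n == gamma meaning all drafts were accepted.
    """
    if not 0 <= n <= gamma:
        raise ValueError(f"n must lie in [0, {gamma}], got {n}")
    return gamma - n, gamma - n + 1


for n in range(5):
    drafter, target = discard_counts(4, n)
    print(f"gamma=4, n={n}: drafter discards {drafter}, target discards {target}")

# With n == gamma the drafter count is 0. Python reads -0 as 0, so a slice
# written as seq[:-k] drops everything when k == 0; seq[:len(seq) - k] does not.
seq = list(range(10))
k = discard_counts(4, 4)[0]
print(seq[:-k])            # []
print(seq[:len(seq) - k])  # all 10 items kept
```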

Code Reference


Signature

from typing import Tuple, Union

from torch import Tensor
from transformers import DynamicCache

def prune_cache(
    cache: Union[Tuple[Tuple[Tensor, Tensor], ...], DynamicCache],
    num_tokens_to_discard: int
) -> Union[Tuple[Tuple[Tensor, Tensor], ...], DynamicCache]:
    """
    Prune the cache by removing the specified number of tokens from the end.

    Args:
        cache: The KV cache to be pruned (tuple-of-tuples or DynamicCache).
        num_tokens_to_discard (int): The number of tokens to discard from the end.

    Returns:
        The pruned KV cache.
    """

def prune_tuple_cache(
    cache: Tuple[Tuple[Tensor, Tensor], ...],
    num_tokens_to_discard: int
) -> Tuple[Tuple[Tensor, Tensor], ...]:
    """Prune tuple-of-tuples cache format (most models)."""

def prune_dynamic_cache(
    cache: DynamicCache,
    num_tokens_to_discard: int
) -> DynamicCache:
    """Prune DynamicCache format (newer transformers). Modifies in place."""

Import

from utils.caching import prune_cache

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| cache | Union[Tuple[Tuple[Tensor, Tensor], ...], DynamicCache] | Yes | KV-cache from a transformer model forward pass. May be None, in which case None is returned. |
| num_tokens_to_discard | int | Yes | Number of trailing token positions to remove from the cache. |

Outputs

| Name | Type | Description |
|---|---|---|
| pruned_cache | Union[Tuple[Tuple[Tensor, Tensor], ...], DynamicCache] | KV-cache with the last num_tokens_to_discard positions removed along the sequence dimension. For DynamicCache, this is the same instance, modified in place. |

Usage Examples

Pruning After Draft Rejection

from utils.caching import prune_cache

# After speculative decoding rejects at position n out of gamma drafts:
# Drafter has (gamma - n) excess entries
drafter_cache = prune_cache(drafter_cache, gamma - n)

# Target has (gamma - n + 1) excess entries
target_cache = prune_cache(target_cache, gamma - n + 1)
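To catch off-by-one errors in these counts, it can help to check the cache's sequence length before and after each prune. The sketch below assumes the legacy tuple layout, with each tensor shaped (batch, num_heads, seq_len, head_dim), and uses NumPy arrays in place of torch tensors; cache_seq_len is a hypothetical helper, not part of utils.caching, and the slicing line is only a stand-in for the prune_cache call. For a DynamicCache, transformers already exposes a get_seq_length() method for the same purpose.

```python
import numpy as np


def cache_seq_len(cache) -> int:
    """Hypothetical helper: sequence length of a tuple-of-tuples KV-cache.

    Assumes every tensor has shape (batch, num_heads, seq_len, head_dim) and
    raises if the layers disagree, which usually signals a partial prune.
    """
    lengths = {tensor.shape[2] for layer in cache for tensor in layer}
    if not lengths:
        raise ValueError("empty cache")
    if len(lengths) != 1:
        raise ValueError(f"inconsistent sequence lengths across layers: {sorted(lengths)}")
    return lengths.pop()


# Toy 2-layer cache holding 12 positions.
cache = tuple((np.zeros((1, 4, 12, 8)), np.zeros((1, 4, 12, 8))) for _ in range(2))
gamma, n = 4, 1
expected = cache_seq_len(cache) - (gamma - n + 1)  # target-side discard count

# Stand-in for prune_cache on the target cache: slice along the sequence axis.
pruned = tuple((k[:, :, :expected, :], v[:, :, :expected, :]) for k, v in cache)
assert cache_seq_len(pruned) == expected
print(cache_seq_len(pruned))  # 8
```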

Handling None Cache

from utils.caching import prune_cache

# Safe to call with None (returns None)
cache = prune_cache(None, 5)  # Returns None
