Principle: NVIDIA TransformerEngine KV Cache Inference
Overview
Managing key-value caches during autoregressive inference to avoid redundant computation.
Description
During autoregressive generation, previously computed key and value tensors are cached so that each new token only requires computing attention with its own query (Q) against all accumulated key/value (K/V) pairs. TransformerEngine's InferenceParams manages this cache with support for both paged and non-paged memory modes, variable-length sequences within a batch, and multiple QKV formats (bshd, sbhd, thd).
The KV cache system integrates into the TE model hierarchy at two levels:
- Memory allocation: performed per-layer in MultiHeadAttention when a layer is first encountered during inference
- Cache population and retrieval: performed per-layer in DotProductAttention during each inference step
The cache management follows a pre-step / step pattern:
- pre_step(step_dict): Called before each model forward pass to update sequence tracking -- which sequences are active, which have finished, and how many new tokens each sequence processes
- step(layer_number, new_k, new_v): Called within each attention layer to copy new K/V tokens into the cache and return the full K/V tensors for attention computation
This two-phase design supports dynamic batching where sequences can be added or removed between steps, which is essential for continuous batching in production inference servers.
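The two-phase design can be sketched with a toy cache. Names mirror the description above, but this is a pedagogical sketch (keyed by sequence ID, single layer), not TransformerEngine's actual InferenceParams API:

```python
import numpy as np

class ToyKVCache:
    """Toy cache illustrating the two-phase pre_step/step pattern.

    Illustrative only -- not TransformerEngine's InferenceParams class.
    """
    def __init__(self):
        self.k = {}          # seq_id -> cached keys   [seq_len, heads, dim]
        self.v = {}          # seq_id -> cached values [seq_len, heads, dim]
        self.step_lens = {}  # seq_id -> new tokens this step

    def pre_step(self, step_dict):
        # Phase 1: evict finished sequences, record per-sequence step lengths.
        for seq_id in list(self.k):
            if seq_id not in step_dict:
                del self.k[seq_id], self.v[seq_id]
        self.step_lens = dict(step_dict)

    def step(self, new_kv):
        # Phase 2: append new K/V and return full tensors for attention.
        # new_kv: seq_id -> (k_new, v_new), each [n_new, heads, dim].
        full = {}
        for seq_id, (k_new, v_new) in new_kv.items():
            assert k_new.shape[0] == self.step_lens[seq_id]
            self.k[seq_id] = (k_new if seq_id not in self.k
                              else np.concatenate([self.k[seq_id], k_new]))
            self.v[seq_id] = (v_new if seq_id not in self.v
                              else np.concatenate([self.v[seq_id], v_new]))
            full[seq_id] = (self.k[seq_id], self.v[seq_id])
        return full
```

Because sequence membership is re-declared at every pre_step, sequences can join or leave the batch between steps without touching other sequences' cached tokens.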
Theoretical Basis
In autoregressive generation, the attention output for token t depends on key and value projections from all tokens 1 through t:
Attention(Q_t, K_{1:t}, V_{1:t}) = softmax(Q_t * K_{1:t}^T / sqrt(d)) * V_{1:t}
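A direct NumPy transcription of this formula (single head, no batch dimension, for clarity):

```python
import numpy as np

def attend(q, K, V):
    """Single-query attention over cached K/V, per the formula above."""
    d = q.shape[-1]
    scores = K @ q / np.sqrt(d)        # [t]: Q_t * K_{1:t}^T / sqrt(d)
    w = np.exp(scores - scores.max())  # max-subtraction for stability
    w = w / w.sum()                    # softmax over cached positions 1..t
    return w @ V                       # [d]: weighted sum of cached values
```

At each decoding step only q is new; K and V each grow by one cached row.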
Without caching, generating a sequence of length n requires:
- Step 1: Compute K/V for token 1 (1 operation)
- Step 2: Recompute K/V for tokens 1-2 (2 operations)
- Step t: Recompute K/V for tokens 1-t (t operations)
- Total: 1 + 2 + ... + n = O(n^2/2) K/V computations; each step t also recomputes the full t x t prefix self-attention at O(t^2) cost, yielding O(n^3/3) total work
With KV caching:
- Step 1: Compute and cache K/V for token 1
- Step t: Compute K/V for token t only, append to cache, attend over cached K/V
- Total: n K/V computations and O(n^2/2) attention work
This reduces total computation from O(n^3/3) to O(n^2/2), a significant speedup for long sequences.
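The asymptotics can be checked with exact operation counts. The sketch below counts the no-cache case as t K/V projections plus a t x t prefix self-attention per step, and the cached case as one projection plus attention over t cached positions:

```python
def op_counts(n):
    """Operation counts for generating n tokens, with and without caching."""
    no_cache_kv   = sum(t for t in range(1, n + 1))      # n(n+1)/2
    no_cache_attn = sum(t * t for t in range(1, n + 1))  # ~ n^3/3
    cached_kv     = n                                    # one per token
    cached_attn   = sum(t for t in range(1, n + 1))      # ~ n^2/2
    return no_cache_kv, no_cache_attn, cached_kv, cached_attn
```

For n = 100, the attention-work ratio is already about 67x in favor of caching, and it grows linearly with n.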
Paged vs. Non-Paged KV Cache:
| Aspect | Non-Paged | Paged |
|---|---|---|
| Memory layout | Contiguous [batch, max_seq, heads, dim] | Pages of [total_pages, page_size, heads, dim] |
| Memory efficiency | Allocates for max sequence length upfront | Allocates pages on demand as sequences grow |
| Fragmentation | No fragmentation but may waste memory | No wasted memory but page table overhead |
| Dynamic batching | Requires reindexing when batch composition changes | Uses page table for flexible sequence management |
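The paged column's on-demand allocation and page-table indirection can be sketched as follows. The class, its field names, and the page size are illustrative assumptions, not TE's internal layout:

```python
import numpy as np

PAGE = 4  # tokens per page; real page sizes are larger (assumed value)

class ToyPagedKV:
    """Toy page-table KV allocator; naming/layout are illustrative only."""
    def __init__(self, total_pages, heads, dim):
        # Pages of [total_pages, page_size, heads, dim], per the table above.
        self.pages_k = np.zeros((total_pages, PAGE, heads, dim))
        self.free = list(range(total_pages))
        self.table = {}  # seq_id -> ordered list of page indices
        self.lens = {}   # seq_id -> tokens stored

    def append(self, seq_id, k_new):
        # Allocate a fresh page only when the current one fills up.
        self.table.setdefault(seq_id, [])
        n = self.lens.get(seq_id, 0)
        for tok in k_new:  # tok: [heads, dim]
            if n % PAGE == 0:
                self.table[seq_id].append(self.free.pop())
            page = self.table[seq_id][n // PAGE]
            self.pages_k[page, n % PAGE] = tok
            n += 1
        self.lens[seq_id] = n

    def free_seq(self, seq_id):
        # Return pages to the free list; other sequences need no reindexing.
        self.free.extend(self.table.pop(seq_id))
        del self.lens[seq_id]
```

This shows the trade-off in the table: at most one partially filled page per sequence is wasted, but every cache read must go through the page table.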
Usage
Use this principle when performing autoregressive text generation with TransformerEngine models. KV caching is required for efficient inference with the TE Gemma example and any other TE-based model used in generation mode.
The typical workflow is:
- Create an InferenceParams instance with cache configuration
- Before each forward pass, call pre_step() with a dictionary mapping sequence IDs to their step lengths
- Pass the InferenceParams to each TransformerLayer forward call
- The cache is automatically managed within each attention layer
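The call ordering of this workflow can be sketched with stand-in objects. MockInferenceParams and MockLayer below are hypothetical stubs that only record calls; they are not TransformerEngine classes, and real TE signatures may differ:

```python
class MockInferenceParams:
    """Stub recording the order of cache-management calls."""
    def __init__(self):
        self.calls = []
    def pre_step(self, step_dict):
        self.calls.append(("pre_step", dict(step_dict)))

class MockLayer:
    """Stub transformer layer; real layers would read/update the cache."""
    def __init__(self, n):
        self.n = n
    def forward(self, hidden, inference_params):
        inference_params.calls.append(("layer", self.n))
        return hidden

params = MockInferenceParams()
layers = [MockLayer(i) for i in range(2)]

# Prefill: sequences 0 and 1 process 5 and 3 prompt tokens;
# then one decode step of 1 new token each.
for step_dict in ({0: 5, 1: 3}, {0: 1, 1: 1}):
    params.pre_step(step_dict)        # before the forward pass
    h = "hidden"
    for layer in layers:
        h = layer.forward(h, params)  # cache managed inside each layer
```

The key invariant is that pre_step runs exactly once per model forward pass, while the per-layer cache update runs once per layer per pass.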