
Principle:NVIDIA TransformerEngine KV Cache Inference

From Leeroopedia


Overview

Managing key-value caches during autoregressive inference to avoid redundant computation.

Description

During autoregressive generation, previously computed key and value tensors are cached so that each new token only requires computing attention with its own query (Q) against all accumulated key/value (K/V) pairs. TransformerEngine's InferenceParams manages this cache with support for both paged and non-paged memory modes, variable-length sequences within a batch, and multiple QKV formats (bshd, sbhd, thd).

The KV cache system integrates into the TE model hierarchy at two levels:

  • Memory allocation: Performed per-layer in MultiHeadAttention when a layer is first encountered during inference
  • Cache population and retrieval: Performed per-layer in DotProductAttention during each inference step

The cache management follows a pre-step / step pattern:

  1. pre_step(step_dict): Called before each model forward pass to update sequence tracking -- which sequences are active, which have finished, and how many new tokens each sequence processes
  2. step(layer_number, new_k, new_v): Called within each attention layer to copy new K/V tokens into the cache and return the full K/V tensors for attention computation

This two-phase design supports dynamic batching where sequences can be added or removed between steps, which is essential for continuous batching in production inference servers.
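The two-phase pattern can be sketched in plain Python. This is an illustrative toy (SimpleKVCache is an invented name, and for brevity it holds a single sequence per layer), not TransformerEngine's InferenceParams, though the method names mirror the ones described above:

```python
# Toy sketch of the pre-step / step pattern; NOT TE's implementation.
class SimpleKVCache:
    def __init__(self, num_layers):
        # Per-layer cache: layer_number -> list of (k, v) pairs, one per token.
        self.cache = {layer: [] for layer in range(num_layers)}
        # Sequence tracking: seq_id -> number of cached tokens.
        self.seq_lens = {}

    def pre_step(self, step_dict):
        """Update sequence tracking before a forward pass.

        step_dict maps seq_id -> number of new tokens this step; sequences
        absent from step_dict are treated as finished and evicted.
        """
        finished = set(self.seq_lens) - set(step_dict)
        for seq_id in finished:
            del self.seq_lens[seq_id]
        for seq_id, n_new in step_dict.items():
            self.seq_lens[seq_id] = self.seq_lens.get(seq_id, 0) + n_new

    def step(self, layer_number, new_k, new_v):
        """Append new K/V tokens for one layer; return the full cached K/V."""
        self.cache[layer_number].extend(zip(new_k, new_v))
        ks = [k for k, _ in self.cache[layer_number]]
        vs = [v for _, v in self.cache[layer_number]]
        return ks, vs
```

Note how pre_step handles bookkeeping only (which sequences are live, how long they are), while step touches the actual tensors; this separation is what lets sequences join or leave the batch between steps.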

Theoretical Basis

In autoregressive generation, the attention output for token t depends on key and value projections from all tokens 1 through t:

Attention(Q_t, K_{1:t}, V_{1:t}) = softmax(Q_t * K_{1:t}^T / sqrt(d)) * V_{1:t}
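This can be checked numerically: attending with K/V accumulated step by step in a cache gives exactly the same output as rebuilding K/V from scratch at every step, because the output for token t depends only on Q_t and the values of K_{1:t}, V_{1:t}. A minimal pure-Python sketch with toy values (identity "projections", d = 2):

```python
import math

def attend(q, ks, vs):
    """Single-query attention: softmax(q . k / sqrt(d)) weighted sum of vs."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in ks]
    m = max(scores)  # subtract max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    weights = [w / z for w in weights]
    return [sum(w * v[i] for w, v in zip(weights, vs)) for i in range(len(vs[0]))]

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Cached path: append each token's K/V once, attend with the new Q only.
cache_k, cache_v, cached_outs = [], [], []
for x in tokens:
    cache_k.append(x)
    cache_v.append(x)
    cached_outs.append(attend(x, cache_k, cache_v))

# Naive path: rebuild all K/V from scratch every step (redundant work).
naive_outs = []
for t in range(len(tokens)):
    ks = [tok for tok in tokens[: t + 1]]  # recomputed each step
    vs = [tok for tok in tokens[: t + 1]]
    naive_outs.append(attend(tokens[t], ks, vs))

assert cached_outs == naive_outs  # identical results, less work
```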

Without caching, generating a sequence of length n requires:

  • Step 1: Compute K/V for token 1 (1 operation)
  • Step 2: Recompute K/V for tokens 1-2 (2 operations)
  • Step t: Recompute K/V for tokens 1-t (t operations)
  • Total: 1 + 2 + ... + n = n(n+1)/2 = O(n^2) K/V computations; each step also recomputes attention for all t queries against t keys, O(t^2) work per step, yielding O(n^3) total work

With KV caching:

  • Step 1: Compute and cache K/V for token 1
  • Step t: Compute K/V for token t only, append to cache, attend over cached K/V
  • Total: n K/V computations and 1 + 2 + ... + n = O(n^2) attention work

This reduces total computation from O(n^3) to O(n^2), a significant speedup for long sequences.
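The K/V operation counts behind this comparison can be verified directly (the function names here are illustrative):

```python
# Count K/V projection operations for a sequence of length n:
# without caching, step t recomputes K/V for all t tokens seen so far;
# with caching, each step computes K/V for exactly one new token.

def kv_ops_without_cache(n):
    return sum(t for t in range(1, n + 1))  # 1 + 2 + ... + n = n(n+1)/2

def kv_ops_with_cache(n):
    return n  # one K/V projection per generated token

n = 1024
assert kv_ops_without_cache(n) == n * (n + 1) // 2  # quadratic growth
assert kv_ops_with_cache(n) == n                    # linear growth
```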

Paged vs. Non-Paged KV Cache:

Aspect            | Non-Paged                                  | Paged
------------------|--------------------------------------------|----------------------------------------------
Memory layout     | Contiguous [batch, max_seq, heads, dim]    | Pages of [total_pages, page_size, heads, dim]
Memory efficiency | Allocates for max sequence length upfront  | Allocates pages on demand as sequences grow
Fragmentation     | No fragmentation but may waste memory      | No wasted memory but page-table overhead
Dynamic batching  | Requires reindexing when batch composition changes | Uses page table for flexible sequence management
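The paged column's on-demand allocation can be sketched as follows. PagedKVCache and its methods are invented names for illustration, not TE's actual paged-cache API; the sketch tracks only page bookkeeping, not tensor storage:

```python
# Toy paged allocator: fixed-size pages handed out on demand and tracked
# in a per-sequence page table, as in the "Paged" column above.
class PagedKVCache:
    def __init__(self, total_pages, page_size):
        self.page_size = page_size
        self.free_pages = list(range(total_pages))  # indices into page pool
        self.page_table = {}  # seq_id -> list of page indices
        self.seq_lens = {}    # seq_id -> tokens stored

    def append(self, seq_id, n_tokens):
        """Reserve space for n_tokens, allocating new pages only as needed."""
        pages = self.page_table.setdefault(seq_id, [])
        used = self.seq_lens.get(seq_id, 0)
        needed = -(-(used + n_tokens) // self.page_size)  # ceiling division
        while len(pages) < needed:
            if not self.free_pages:
                raise MemoryError("KV cache out of pages")
            pages.append(self.free_pages.pop())
        self.seq_lens[seq_id] = used + n_tokens

    def release(self, seq_id):
        """Return a finished sequence's pages to the free pool."""
        self.free_pages.extend(self.page_table.pop(seq_id, []))
        self.seq_lens.pop(seq_id, None)
```

Because sequences reference pages only through the page table, a sequence can grow or finish without moving any other sequence's data, which is why the paged layout avoids the reindexing cost listed for the non-paged case.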

Usage

Use this principle when performing autoregressive text generation with TransformerEngine models. KV caching is required for efficient inference with the TE Gemma example and any other TE-based model used in generation mode.

The typical workflow is:

  1. Create an InferenceParams instance with cache configuration
  2. Before each forward pass, call pre_step() with a dictionary mapping sequence IDs to their step lengths
  3. Pass the InferenceParams to each TransformerLayer forward call
  4. The cache is automatically managed within each attention layer
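The four-step workflow above can be exercised as a runnable toy, with stand-in classes (ToyInferenceParams and ToyLayer are illustrative names, not TE's actual API): create the params object once, call pre_step before each forward pass, and hand the params to every layer, which manages its own cache slice internally.

```python
class ToyInferenceParams:
    def __init__(self):
        self.kv_cache = {}  # layer_number -> list of cached tokens
        self.seq_lens = {}  # seq_id -> generated length

    def pre_step(self, step_dict):
        # Step 2: update sequence tracking before the forward pass.
        for seq_id, n_new in step_dict.items():
            self.seq_lens[seq_id] = self.seq_lens.get(seq_id, 0) + n_new

class ToyLayer:
    def __init__(self, layer_number):
        self.layer_number = layer_number

    def forward(self, token, inference_params):
        # Step 4: the layer appends to its own cache slice; the caller
        # never touches the cache directly.
        cache = inference_params.kv_cache.setdefault(self.layer_number, [])
        cache.append(token)
        return token  # a real layer would run attention over `cache` here

params = ToyInferenceParams()              # step 1: create once
layers = [ToyLayer(i) for i in range(2)]
for token in ["The", "cat", "sat"]:
    params.pre_step({0: 1})                # step 2: one new token for seq 0
    for layer in layers:
        token = layer.forward(token, inference_params=params)  # step 3
```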

Related

Sources

Domains
