Implementation: NVIDIA TransformerEngine InferenceParams
Overview
A concrete tool provided by TransformerEngine for managing KV caches during autoregressive inference.
Description
InferenceParams manages KV cache allocation, population, and retrieval across Transformer layers during inference. It supports both paged and non-paged cache modes, configurable QKV formats (bshd, sbhd, thd), and per-layer cache allocation. The class delegates actual cache storage and manipulation to a KVCacheManager implementation -- either NonPagedKVCacheManager or PagedKVCacheManager -- selected based on the is_paged parameter.
The class maintains cumulative sequence length tensors (cu_seqlens_q and cu_seqlens_kv) that are updated during pre_step() and used by fused attention kernels. It also tracks a pre_step_seqlens buffer containing the running length of each sequence before the current step, which is used for applying RoPE embeddings during inference.
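To make the cumulative sequence length tensors concrete: cu_seqlens is a prefix sum over per-sequence lengths, so entry i gives the offset where sequence i begins inside a packed tensor. A minimal pure-Python sketch (not the TE implementation, which builds these as CUDA tensors):

```python
def cu_seqlens(seqlens):
    """Prefix-sum per-sequence lengths into cumulative offsets.

    cu_seqlens[i] is the total number of tokens in the first i
    sequences; fused attention kernels use consecutive pairs to
    locate each sequence's slice in a packed (thd) layout.
    """
    out = [0]
    for n in seqlens:
        out.append(out[-1] + n)
    return out

# A batch of three sequences with lengths 5, 3, and 7:
print(cu_seqlens([5, 3, 7]))  # [0, 5, 8, 15]
```

Sequence 1 occupies tokens [5, 8) of the packed buffer, and the final entry (15) is the total token count.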
The integration point within the TE model hierarchy is:
TransformerLayer
  -> MultiHeadAttention
       if layer_number not in inference_params.cache_manager.cache:
           inference_params.allocate_memory(layer_number)
       -> DotProductAttention
            if inference_params is not None:
                k_cache, v_cache, ... = inference_params.step(layer_number, new_k, new_v, qkv_format)
                output = attention(new_q, k_cache, v_cache, ...)
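The allocate/pre_step/step contract above can be imitated with a toy, pure-Python non-paged cache. This is a hedged sketch of the control flow only; the class and method names here are illustrative and the real manager stores pre-allocated CUDA tensors per layer, not Python lists:

```python
class TinyNonPagedCache:
    """Toy single-sequence KV cache mimicking the allocate/pre_step/step flow."""

    def __init__(self, max_seq_len):
        self.max_seq_len = max_seq_len
        self.buffers = {}   # layer_number -> fixed-size buffer
        self.seqlen = 0     # tokens cached before the current step
        self.offset = 0     # write offset recorded by pre_step

    def allocate(self, layer_number):
        # Pre-allocate the full-length buffer once per layer.
        self.buffers[layer_number] = [None] * self.max_seq_len

    def pre_step(self, step_len):
        # Record where the new tokens go, then advance the running length.
        self.offset = self.seqlen
        self.seqlen += step_len

    def step(self, layer_number, new_k):
        # Copy new tokens into the cache and return the full cache so far.
        buf = self.buffers[layer_number]
        buf[self.offset:self.offset + len(new_k)] = new_k
        return buf[:self.seqlen]

cache = TinyNonPagedCache(max_seq_len=8)
cache.allocate(0)
cache.pre_step(3)                            # prefill with 3 tokens
print(cache.step(0, ["a", "b", "c"]))        # ['a', 'b', 'c']
cache.pre_step(1)                            # decode one token
print(cache.step(0, ["d"]))                  # ['a', 'b', 'c', 'd']
```

Note how pre_step runs once per iteration (for all layers) while step runs once per layer, mirroring the split between MultiHeadAttention and DotProductAttention above.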
Source
transformer_engine/pytorch/attention/inference.py, class InferenceParams at lines 55-228, __init__ at lines 127-228.
Import
from transformer_engine.pytorch.attention import InferenceParams
Signature
class InferenceParams:
    def __init__(
        self,
        max_batch_size: int,
        max_sequence_length: int,
        num_heads_kv: int = None,
        head_dim_k: int = None,
        dtype: torch.dtype = None,
        head_dim_v: int = None,
        is_paged: bool = False,
        total_num_pages: int = None,
        page_size: int = None,
        max_ctx_len: int = None,
        qkv_format: str = "bshd",
        custom_cache_manager: KVCacheManager = None,
    ):
        ...

    def reset(self):
        """Reset InferenceParams state"""
        ...

    def allocate_memory(self, layer_number: int):
        """Allocate KV cache memory for a specific layer"""
        ...

    def pre_step(self, step_dict: OrderedDict):
        """Update tracked sequences and prepare cumulative seqlens for step()"""
        ...

    def step(
        self,
        layer_number: int,
        new_k: torch.Tensor,
        new_v: torch.Tensor,
        qkv_format: str,
    ) -> tuple:
        """Copy new KV tokens to cache and return full cache tensors"""
        ...

    def get_seqlens_pre_step(self) -> torch.Tensor:
        """Get cached sequence lengths before the current step"""
        ...

    def convert_paged_to_nonpaged(self, layer_number: int) -> tuple:
        """Convert paged KV cache to non-paged format for a layer"""
        ...
I/O
Constructor Input:
| Parameter | Type | Default | Description |
|---|---|---|---|
| max_batch_size | int | required | Maximum batch size during inference |
| max_sequence_length | int | required | Maximum sequence length during inference |
| num_heads_kv | int | None | Number of KV attention heads (required since TE 2.2) |
| head_dim_k | int | None | Head dimension for keys (required since TE 2.2) |
| dtype | torch.dtype | None | Data type for KV cache tensors (required since TE 2.2) |
| head_dim_v | int | None | Head dimension for values; defaults to head_dim_k |
| is_paged | bool | False | Whether to use paged KV cache |
| total_num_pages | int | None | Total pages for paged cache; required when is_paged=True |
| page_size | int | None | Page size for paged cache; required when is_paged=True |
| max_ctx_len | int | None | Maximum context length; required when qkv_format="thd" |
| qkv_format | str | "bshd" | Format of incoming QKV tensors: "bshd", "sbhd", or "thd" |
| custom_cache_manager | KVCacheManager | None | Custom cache manager subclass |
pre_step() Input:
- step_dict: OrderedDict[int, int] -- Mapping of sequence IDs to step lengths for the current iteration
step() Input:
- layer_number: int -- Layer number for cache lookup
- new_k: torch.Tensor -- New key tokens for the current iteration
- new_v: torch.Tensor -- New value tokens for the current iteration
- qkv_format: str -- Format of the new K/V tensors
step() Output:
- k_cache: torch.Tensor -- Full key cache including new tokens
- v_cache: torch.Tensor -- Full value cache including new tokens
- cu_seqlens_q: torch.Tensor -- Cumulative query sequence lengths, shape [batch_size + 1]
- cu_seqlens_kv: torch.Tensor -- Cumulative KV sequence lengths, shape [batch_size + 1]
- max_sequence_length: int -- Maximum sequence length
- qkv_format: str -- Updated QKV format (e.g. "thd" becomes "thd_2bshd")
Key Parameters
- max_batch_size: Determines the size of pre-allocated cumulative sequence length buffers and the maximum batch dimension of the KV cache.
- max_sequence_length: Determines the sequence dimension of the KV cache. For paged mode, must be divisible by page_size.
- num_heads_kv: Number of key-value attention heads, supporting grouped-query attention configurations.
- head_dim_k: Per-head dimension for keys; head_dim_v defaults to the same value if not specified.
- dtype: Data type for the cache tensors (typically torch.bfloat16).
- is_paged: Selects between NonPagedKVCacheManager (contiguous allocation) and PagedKVCacheManager (page-based allocation).
- qkv_format: The incoming tensor format. The cache always stores in "bshd" format internally; if the input format differs, the output format is annotated (e.g., "thd_2bshd").
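The paged-mode sizing constraint above (max_sequence_length divisible by page_size, with total_num_pages covering the full batch) amounts to simple arithmetic. A small helper sketch, with a hypothetical function name not part of the TE API:

```python
def required_pages(max_batch_size, max_sequence_length, page_size):
    """Compute total_num_pages for a worst-case paged KV cache.

    Every sequence may grow to max_sequence_length, so each of the
    max_batch_size sequences needs max_sequence_length / page_size pages.
    """
    if max_sequence_length % page_size != 0:
        raise ValueError("max_sequence_length must be divisible by page_size")
    return max_batch_size * (max_sequence_length // page_size)

# 4 sequences, 2048 tokens each, 256-token pages -> 4 * 8 pages:
print(required_pages(4, 2048, 256))  # 32
```

This is the value one would pass as total_num_pages when constructing InferenceParams with is_paged=True.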
Notes
- Since TransformerEngine 2.2, num_heads_kv, head_dim_k, and dtype are required parameters.
- The pre_step_seqlens buffer is specifically designed for CUDA Graphs compatibility -- it is updated in-place using .copy_() to avoid pointer changes.
- For paged mode, total_num_pages must equal max_batch_size * (max_sequence_length // page_size).
- The convert_paged_to_nonpaged() method enables interoperability between paged and non-paged attention backends when needed.
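Conceptually, converting a paged cache to non-paged form means gathering a sequence's pages, in page-table order, into one contiguous buffer and trimming the padding in the final page. A hedged pure-Python illustration of that idea (names are hypothetical; the real method operates on CUDA tensors and handles the whole batch):

```python
def gather_pages(pages, page_table, seq_len):
    """Flatten a sequence's paged KV storage into a contiguous buffer.

    pages      -- pool mapping page_id -> fixed-size page of tokens
    page_table -- the sequence's page ids, in logical order
    seq_len    -- actual token count (last page may be partly filled)
    """
    flat = []
    for page_id in page_table:
        flat.extend(pages[page_id])   # pages need not be adjacent in the pool
    return flat[:seq_len]             # drop padding in the last page

# 5 tokens stored across pages 7, 2, 5 of a pool with 2-token pages:
pool = {7: [0, 1], 2: [2, 3], 5: [4, None]}
print(gather_pages(pool, [7, 2, 5], seq_len=5))  # [0, 1, 2, 3, 4]
```

The pages live at arbitrary slots in the pool; only the page table imposes the logical token order, which is why a contiguous copy is needed before a non-paged attention backend can consume the cache.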