
Implementation:NVIDIA TransformerEngine InferenceParams

From Leeroopedia


Overview

A concrete TransformerEngine tool for managing KV caches during autoregressive inference.

Description

InferenceParams manages KV cache allocation, population, and retrieval across Transformer layers during inference. It supports both paged and non-paged cache modes, configurable QKV formats (bshd, sbhd, thd), and per-layer cache allocation. The class delegates actual cache storage and manipulation to a KVCacheManager implementation -- either NonPagedKVCacheManager or PagedKVCacheManager -- selected based on the is_paged parameter.

The class maintains cumulative sequence length tensors (cu_seqlens_q and cu_seqlens_kv) that are updated during pre_step() and used by fused attention kernels. It also tracks a pre_step_seqlens buffer containing the running length of each sequence before the current step, which is used for applying RoPE embeddings during inference.

The integration point within the TE model hierarchy is:

TransformerLayer
  -> MultiHeadAttention
       if layer_number not in inference_params.cache_manager.cache:
           inference_params.allocate_memory(layer_number)
       -> DotProductAttention
            if inference_params is not None:
                k_cache, v_cache, ... = inference_params.step(layer_number, new_k, new_v, qkv_format)
            output = attention(new_q, k_cache, v_cache, ...)
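The flow above can be sketched with a toy, framework-free cache. All names (allocate_memory, pre_step, step, pre_step_seqlens) mirror the TE API for illustration only; this is not TE code, and real TE allocates device tensors rather than Python lists.

```python
from collections import OrderedDict

class ToyNonPagedCache:
    """Minimal sketch of the non-paged allocate/pre_step/step flow."""

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_batch_size = max_batch_size
        self.max_seqlen = max_sequence_length
        self.cache = {}  # layer_number -> (k_cache, v_cache)
        self.seqlens = [0] * max_batch_size            # tokens cached per sequence
        self.pre_step_seqlens = [0] * max_batch_size   # lengths before current step

    def allocate_memory(self, layer_number):
        # One (batch, seq) slot grid per layer; TE allocates tensors instead.
        k = [[None] * self.max_seqlen for _ in range(self.max_batch_size)]
        v = [[None] * self.max_seqlen for _ in range(self.max_batch_size)]
        self.cache[layer_number] = (k, v)

    def pre_step(self, step_dict):
        # step_dict: sequence index -> tokens added this iteration.
        # Snapshot lengths before the step, then advance running lengths,
        # so every layer's step() uses the same insertion offsets.
        self.pre_step_seqlens = list(self.seqlens)
        for b, n in step_dict.items():
            self.seqlens[b] += n

    def step(self, layer_number, new_k, new_v):
        # Copy this iteration's tokens into the cache at the pre-step offsets.
        k_cache, v_cache = self.cache[layer_number]
        for b, (ks, vs) in enumerate(zip(new_k, new_v)):
            start = self.pre_step_seqlens[b]
            k_cache[b][start:start + len(ks)] = ks
            v_cache[b][start:start + len(vs)] = vs
        return k_cache, v_cache

params = ToyNonPagedCache(max_batch_size=2, max_sequence_length=8)
params.allocate_memory(layer_number=1)
params.pre_step(OrderedDict([(0, 3), (1, 2)]))  # seq 0 adds 3 tokens, seq 1 adds 2
k_cache, v_cache = params.step(1, [[10, 11, 12], [20, 21]], [[1, 2, 3], [4, 5]])
```

Note that pre_step runs once per iteration while step runs once per layer; snapshotting pre_step_seqlens is what keeps every layer writing to the same offsets.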

Source

transformer_engine/pytorch/attention/inference.py, class InferenceParams at lines 55-228, __init__ at lines 127-228.

Import

from transformer_engine.pytorch.attention import InferenceParams

Signature

class InferenceParams:
    def __init__(
        self,
        max_batch_size: int,
        max_sequence_length: int,
        num_heads_kv: int = None,
        head_dim_k: int = None,
        dtype: torch.dtype = None,
        head_dim_v: int = None,
        is_paged: bool = False,
        total_num_pages: int = None,
        page_size: int = None,
        max_ctx_len: int = None,
        qkv_format: str = "bshd",
        custom_cache_manager: KVCacheManager = None,
    ):
        ...

    def reset(self):
        """Reset InferenceParams state"""
        ...

    def allocate_memory(self, layer_number: int):
        """Allocate KV cache memory for a specific layer"""
        ...

    def pre_step(self, step_dict: OrderedDict):
        """Update tracked sequences and prepare cumulative seqlens for step()"""
        ...

    def step(
        self,
        layer_number: int,
        new_k: torch.Tensor,
        new_v: torch.Tensor,
        qkv_format: str,
    ) -> tuple:
        """Copy new KV tokens to cache and return full cache tensors"""
        ...

    def get_seqlens_pre_step(self) -> torch.Tensor:
        """Get cached sequence lengths before the stepping"""
        ...

    def convert_paged_to_nonpaged(self, layer_number: int) -> tuple:
        """Convert paged KV cache to non-paged format for a layer"""
        ...

I/O

Constructor Input:

Parameter | Type | Default | Description
max_batch_size | int | required | Maximum batch size during inference
max_sequence_length | int | required | Maximum sequence length during inference
num_heads_kv | int | None | Number of KV attention heads (required since TE 2.2)
head_dim_k | int | None | Head dimension for keys (required since TE 2.2)
dtype | torch.dtype | None | Data type for KV cache tensors (required since TE 2.2)
head_dim_v | int | None | Head dimension for values; defaults to head_dim_k
is_paged | bool | False | Whether to use paged KV cache
total_num_pages | int | None | Total pages for paged cache; required when is_paged=True
page_size | int | None | Page size for paged cache; required when is_paged=True
max_ctx_len | int | None | Maximum context length; required when qkv_format="thd"
qkv_format | str | "bshd" | Format of incoming QKV tensors: "bshd", "sbhd", or "thd"
custom_cache_manager | KVCacheManager | None | Custom cache manager subclass
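A hedged construction sketch matching the constructor parameters above. The argument values are illustrative, not defaults; running this requires transformer_engine and a CUDA-capable environment, so it is shown unverified.

```python
import torch
from transformer_engine.pytorch.attention import InferenceParams

# Non-paged cache sized for batch 8, 2048-token sequences (illustrative values)
params = InferenceParams(
    max_batch_size=8,
    max_sequence_length=2048,
    num_heads_kv=8,        # required since TE 2.2
    head_dim_k=128,        # required since TE 2.2
    dtype=torch.bfloat16,  # required since TE 2.2
)

# Paged variant: 8 * (2048 // 64) = 256 pages of 64 tokens each
paged_params = InferenceParams(
    max_batch_size=8,
    max_sequence_length=2048,
    num_heads_kv=8,
    head_dim_k=128,
    dtype=torch.bfloat16,
    is_paged=True,
    page_size=64,
    total_num_pages=256,
)
```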

pre_step() Input:

  • step_dict: OrderedDict[int, int] -- Mapping of sequence IDs to step lengths for the current iteration
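For intuition, the cumulative sequence-length tensors are running sums over per-sequence step lengths. A minimal sketch of how cu_seqlens_q could be derived from step_dict; the helper name is hypothetical and this is not TE internals:

```python
from collections import OrderedDict
from itertools import accumulate

def cu_seqlens_from_step_dict(step_dict, max_batch_size):
    """Cumulative step lengths, shape [max_batch_size + 1]."""
    lens = list(step_dict.values())
    lens += [0] * (max_batch_size - len(lens))  # inactive batch slots add 0 tokens
    return [0] + list(accumulate(lens))

# Prefill 4 tokens for seq 0; seqs 1 and 2 decode 1 token each
step_dict = OrderedDict([(0, 4), (1, 1), (2, 1)])
cu = cu_seqlens_from_step_dict(step_dict, max_batch_size=4)
# cu == [0, 4, 5, 6, 6]
```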

step() Input:

  • layer_number: int -- Layer number for cache lookup
  • new_k: torch.Tensor -- New key tokens for current iteration
  • new_v: torch.Tensor -- New value tokens for current iteration
  • qkv_format: str -- Format of the new K/V tensors

step() Output:

  • k_cache: torch.Tensor -- Full key cache including new tokens
  • v_cache: torch.Tensor -- Full value cache including new tokens
  • cu_seqlens_q: torch.Tensor -- Cumulative query sequence lengths [batch_size + 1]
  • cu_seqlens_kv: torch.Tensor -- Cumulative KV sequence lengths [batch_size + 1]
  • max_sequence_length: int -- Maximum sequence length
  • qkv_format: str -- Updated QKV format (e.g. "thd" becomes "thd_2bshd")

Key Parameters

  • max_batch_size: Determines the size of pre-allocated cumulative sequence length buffers and the maximum batch dimension of the KV cache.
  • max_sequence_length: Determines the sequence dimension of the KV cache. For paged mode, must be divisible by page_size.
  • num_heads_kv: Number of key-value attention heads, supporting grouped-query attention configurations.
  • head_dim_k: Per-head dimension for keys; head_dim_v defaults to the same value if not specified.
  • dtype: Data type for the cache tensors (typically torch.bfloat16).
  • is_paged: Selects between NonPagedKVCacheManager (contiguous allocation) and PagedKVCacheManager (page-based allocation).
  • qkv_format: The incoming tensor format. The cache always stores in "bshd" format internally; if the input format differs, the output format is annotated (e.g., "thd_2bshd").

Notes

  • Since TransformerEngine 2.2, num_heads_kv, head_dim_k, and dtype are required parameters.
  • The pre_step_seqlens buffer is specifically designed for CUDA Graphs compatibility -- it is updated in-place using .copy_() to avoid pointer changes.
  • For paged mode, total_num_pages must equal max_batch_size * (max_sequence_length // page_size).
  • The convert_paged_to_nonpaged() method enables interoperability between paged and non-paged attention backends when needed.
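The paged-mode sizing constraint above can be written out directly; the helper name is illustrative:

```python
def required_total_num_pages(max_batch_size, max_sequence_length, page_size):
    """Pages needed so every sequence in the batch can reach max length."""
    # max_sequence_length must be divisible by page_size in paged mode.
    assert max_sequence_length % page_size == 0
    return max_batch_size * (max_sequence_length // page_size)

# e.g. batch 8, 2048-token sequences, 64-token pages -> 8 * 32 = 256 pages
pages = required_total_num_pages(8, 2048, 64)
```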
