Implementation:Predibase Lorax FlashInfer Attention
| Knowledge Sources | |
|---|---|
| Domains | Attention, GPU_Kernels |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides state management and context managers for FlashInfer-based attention computation, supporting prefill (with ragged and paged KV caches) and decode phases with both standard and CUDA Graph execution paths.
Description
This module manages the lifecycle of FlashInfer attention wrappers using Python ContextVar objects, allowing different attention states to be set and retrieved within request processing scopes. It defines three global context variables:
- prefill_state: Holds a BatchPrefillWithRaggedKVCacheWrapper for prefill attention without paged KV.
- prefill_with_paged_kv_state: Holds a BatchPrefillWithPagedKVCacheWrapper for prefill attention using paged KV cache.
- decode_state: Holds a BatchDecodeWithPagedKVCacheWrapper for autoregressive decode attention.
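The set/read/reset mechanics behind these variables can be sketched with the standard-library `contextvars` module. The variable name mirrors the one listed above, but the `default=None`, the string placeholder standing in for the FlashInfer wrapper, and the helper `current_decode_state` are assumptions made for illustration, not the module's actual code.

```python
from contextvars import ContextVar

# Stand-in for the module-level variable; the real one holds a FlashInfer wrapper.
# default=None is an assumption so that get() works before anything is set.
decode_state = ContextVar("decode_state", default=None)


def current_decode_state():
    """Hypothetical helper: read whatever state is active in this context."""
    return decode_state.get()


# set() returns a token that remembers the previous value.
token = decode_state.set("decode-wrapper-placeholder")
inside = current_decode_state()

# reset() restores the value that was active before the matching set().
decode_state.reset(token)
after = current_decode_state()
```

Because the token restores the earlier value rather than clearing the variable, nested or sequential scopes each see the state that was set for them.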
get_workspace(device): Lazily allocates a shared 128 MB workspace buffer on the specified device, reused across all FlashInfer operations.
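A minimal sketch of the lazy-allocation idea, assuming a single module-level cache. The `alloc` parameter and the `_host_alloc` stand-in are hypothetical and exist only so the example runs without a GPU; a PyTorch version would plausibly allocate a `uint8` tensor on `device` instead.

```python
WORKSPACE_BYTES = 128 * 1024 * 1024  # 128 MB, as described above

# Module-level cache: stays None until the first call.
_workspace = None


def _host_alloc(num_bytes, device):
    """Hypothetical stand-in allocator: ignores `device` and uses host memory."""
    return bytearray(num_bytes)


def get_workspace(device, alloc=_host_alloc):
    """Allocate the buffer on first use, then return the same object every time."""
    global _workspace
    if _workspace is None:
        _workspace = alloc(WORKSPACE_BYTES, device)
    return _workspace
```

Every caller after the first receives the same buffer, so the allocation cost is paid once regardless of how many wrappers are created.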
create_prefill_with_paged_kv_state(device): Creates a BatchPrefillWithPagedKVCacheWrapper with NHD layout and CUDA graphs disabled.
use_prefill_with_paged_kv_state(...): A context manager that sets up the prefill state for paged KV attention. It computes indptr (the cumulative number of KV pages per sequence, derived from the input lengths and page size) and last_page_len (the number of valid entries in each sequence's last page), then calls state.begin_forward() with all attention parameters. On exit, it calls state.end_forward() and resets the context variable.
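The page arithmetic described here can be worked through in plain Python. This is a sketch of the calculation, not the module's tensor code; the function name `paged_kv_layout` is hypothetical, and it assumes every sequence has at least one token.

```python
from typing import List, Tuple


def paged_kv_layout(input_lengths: List[int], page_size: int) -> Tuple[List[int], List[int]]:
    """Return (indptr, last_page_len) for a batch of sequences."""
    indptr = [0]
    last_page_len = []
    for length in input_lengths:
        # Ceiling division: pages needed to hold `length` tokens.
        num_pages = (length + page_size - 1) // page_size
        indptr.append(indptr[-1] + num_pages)
        # Tokens in the final page; a remainder of 0 means that page is full.
        remainder = length % page_size
        last_page_len.append(remainder if remainder != 0 else page_size)
    return indptr, last_page_len


indptr, last_page_len = paged_kv_layout([5, 16, 17], page_size=16)
print(indptr)         # [0, 1, 2, 4]
print(last_page_len)  # [5, 16, 1]
```

Sequence `i` owns the pages in the half-open range `indptr[i]` to `indptr[i + 1]`, which is why the list has one more entry than there are sequences.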
create_prefill_state(device): Creates a BatchPrefillWithRaggedKVCacheWrapper for ragged (non-paged) KV cache prefill.
use_prefill_state(...): A context manager for ragged KV prefill attention. Because queries and keys/values share the same sequence boundaries during prefill, it passes cu_seqlens as both the query/output and the KV index pointers.
create_decode_state(device, num_heads, num_kv_heads): Creates a BatchDecodeWithPagedKVCacheWrapper. Enables tensor cores when the GQA ratio (num_heads / num_kv_heads) exceeds 4.
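The tensor-core rule amounts to a single comparison. The helper name below is hypothetical, and the use of integer division is an assumption; the text above only states that the ratio must exceed 4.

```python
def should_use_tensor_cores(num_heads: int, num_kv_heads: int) -> bool:
    """Hypothetical helper: True when more than 4 query heads share each KV head."""
    return num_heads // num_kv_heads > 4


print(should_use_tensor_cores(32, 8))   # False: the ratio is exactly 4
print(should_use_tensor_cores(64, 8))   # True: the ratio is 8
print(should_use_tensor_cores(32, 32))  # False: plain multi-head attention
```

Note that a ratio of exactly 4, such as 32 query heads with 8 KV heads, does not meet the threshold.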
create_decode_state_cuda_graphs(...): Creates a decode state for use with CUDA Graphs. It takes pre-allocated block_tables, block_tables_ptr, and last_page_len buffers as arguments; these buffers are captured in the graph, so they must be updated in place rather than replaced between replays.
use_decode_state(...): A context manager for decode attention that computes the paged KV index pointers and last page lengths, then calls state.begin_forward() before the yielded block and state.end_forward() after it.
Usage
These functions are called by the model's attention layer implementations. During request processing, the appropriate create_* function is called once at initialization, and the corresponding use_* context manager is entered for each batch. Inside the context manager, attention operations read the active state from the ContextVar to perform the actual FlashInfer kernel calls.
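The create-once, use-per-batch lifecycle can be sketched without a GPU. `FakeWrapper` is a stand-in that only records calls, and this simplified `use_decode_state` is an assumption about the general shape of the pattern rather than a copy of the module's implementation.

```python
from contextlib import contextmanager
from contextvars import ContextVar


class FakeWrapper:
    """Stand-in for a FlashInfer wrapper; it only records lifecycle calls."""

    def __init__(self):
        self.calls = []

    def begin_forward(self, **params):
        self.calls.append("begin_forward")

    def end_forward(self):
        self.calls.append("end_forward")


decode_state = ContextVar("decode_state", default=None)


@contextmanager
def use_decode_state(*, state, **params):
    token = decode_state.set(state)
    try:
        state.begin_forward(**params)
        yield
    finally:
        # Runs even if the body raises, so the state never leaks past the batch.
        state.end_forward()
        decode_state.reset(token)


state = FakeWrapper()        # created once at initialization
for batch_id in range(2):    # context manager entered once per batch
    with use_decode_state(state=state, page_size=16):
        # Attention code reads the active state from the ContextVar.
        assert decode_state.get() is state
        state.calls.append(f"attention batch {batch_id}")
```

Wrapping the body in `try`/`finally` keeps the begin/end calls paired even when an attention call fails partway through a batch.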
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/utils/flashinfer_attention.py
- Lines: 1-229
Signature
def get_workspace(device) -> torch.Tensor
def create_prefill_with_paged_kv_state(*, device: torch.device)
def use_prefill_with_paged_kv_state(
    *, state, block_tables, cu_seqlens, input_lengths, num_heads, num_kv_heads,
    head_size, page_size, dtype, window_left,
)
def create_prefill_state(*, device: torch.device)
def use_prefill_state(
    *, state, cu_seqlens, num_heads, num_kv_heads, head_size,
    query_dtype="float16", window_left,
)
def create_decode_state(*, device, num_heads, num_kv_heads)
def create_decode_state_cuda_graphs(
    *, device, block_tables, block_tables_ptr, last_page_len,
    num_heads, num_kv_heads,
)
def use_decode_state(
    *, state, input_lengths, block_tables, num_heads, num_kv_heads,
    head_size, page_size, dtype, window_left,
)
Import
from lorax_server.utils.flashinfer_attention import (
    create_prefill_with_paged_kv_state,
    use_prefill_with_paged_kv_state,
    create_decode_state,
    use_decode_state,
    prefill_state,
    decode_state,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| device | torch.device | Yes | CUDA device for buffer allocation |
| block_tables | torch.Tensor | Yes (paged) | Block table indices mapping sequences to KV cache pages |
| cu_seqlens | torch.Tensor | Yes (prefill) | Cumulative sequence lengths, used as the query/output index pointers |
| input_lengths | torch.Tensor | Yes | Per-sequence input lengths |
| num_heads | int | Yes | Number of query attention heads |
| num_kv_heads | int | Yes | Number of key/value attention heads (for GQA) |
| head_size | int | Yes | Dimension of each attention head |
| page_size | int | Yes (paged) | Number of tokens per KV cache page |
| dtype | torch.dtype | Yes | Data type for attention computation |
| window_left | int | Yes | Left sliding-window size, forwarded to FlashInfer (-1 disables the window, giving full attention) |
Outputs
| Name | Type | Description |
|---|---|---|
| state | FlashInfer wrapper | Configured FlashInfer batch attention wrapper ready for use within the context manager scope |
| workspace | torch.Tensor (uint8) | Shared 128 MB GPU workspace buffer |
Usage Examples
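# Creating and using FlashInfer decode attention state
import torch

from lorax_server.utils.flashinfer_attention import (
    create_decode_state,
    use_decode_state,
)

# device, input_lengths, block_tables, query, key, and value are supplied by the caller.
# Create the wrapper once at initialization.
state = create_decode_state(device=device, num_heads=32, num_kv_heads=8)

# Enter the context manager once per decode batch.
with use_decode_state(
    state=state,
    input_lengths=input_lengths,
    block_tables=block_tables,
    num_heads=32,
    num_kv_heads=8,
    head_size=128,
    page_size=16,
    dtype=torch.float16,
    window_left=-1,  # no sliding window (FlashInfer's convention for full attention)
):
    # Attention operations use the active decode_state ContextVar.
    # `attention` is a placeholder for the model's attention layer call.
    output = attention(query, key, value)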