

Implementation:Predibase Lorax FlashInfer Attention

From Leeroopedia


Domains: Attention, GPU_Kernels
Last Updated: 2026-02-08 00:00 GMT

Overview

Provides state management and context managers for FlashInfer-based attention computation, supporting prefill (with ragged and paged KV caches) and decode phases with both standard and CUDA Graph execution paths.

Description

This module manages the lifecycle of FlashInfer attention wrappers using Python ContextVar objects, allowing different attention states to be set and retrieved within request processing scopes. It defines three global context variables:

  • prefill_state: Holds a BatchPrefillWithRaggedKVCacheWrapper for prefill attention without paged KV.
  • prefill_with_paged_kv_state: Holds a BatchPrefillWithPagedKVCacheWrapper for prefill attention using paged KV cache.
  • decode_state: Holds a BatchDecodeWithPagedKVCacheWrapper for autoregressive decode attention.
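The ContextVar pattern behind these three variables can be sketched in plain Python. This is an illustration, not the module's code: FakeWrapper and use_state are hypothetical stand-ins for a FlashInfer wrapper and a use_* context manager. The sketch shows that a value set inside a with block is visible to callees without being passed as an argument, and that the previous value is restored on exit, including for nested scopes.

```python
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator, Optional


class FakeWrapper:
    """Hypothetical stand-in for a FlashInfer batch attention wrapper."""

    def __init__(self, name: str) -> None:
        self.name = name


# Module-level variable, analogous to decode_state in the documented module
decode_state: ContextVar[FakeWrapper] = ContextVar("decode_state")


@contextmanager
def use_state(state: FakeWrapper) -> Iterator[None]:
    # set() returns a token that restores the previous value on reset()
    token = decode_state.set(state)
    try:
        yield
    finally:
        decode_state.reset(token)


def attention_layer() -> str:
    # Callees read the active wrapper instead of receiving it as an argument
    return decode_state.get().name


def current_or_none() -> Optional[FakeWrapper]:
    # get(default) avoids LookupError when no scope is active
    return decode_state.get(None)
```

Because each use_state scope resets with its own token, an inner scope cannot leak its wrapper into the outer one.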

get_workspace(device): Lazily allocates a shared 128 MB workspace buffer on the specified device, reused across all FlashInfer operations.
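The lazy allocation can be sketched as a small cache. This is a hypothetical illustration under stated assumptions: a bytearray stands in for the uint8 CUDA tensor, the cache is keyed by a device string (an assumption of this sketch), and the nbytes parameter exists only so the sketch can be exercised cheaply; the documented buffer is always 128 MB.

```python
from typing import Dict

WORKSPACE_BYTES = 128 * 1024 * 1024  # 128 MB, as documented

# Buffers created on first request and reused afterwards (keyed by device: assumed)
_workspace: Dict[str, bytearray] = {}


def get_workspace(device: str, nbytes: int = WORKSPACE_BYTES) -> bytearray:
    """Return the shared workspace for `device`, allocating it on first use."""
    buf = _workspace.get(device)
    if buf is None:
        # Stand-in for allocating an uninitialized uint8 tensor on the GPU
        buf = bytearray(nbytes)
        _workspace[device] = buf
    return buf
```

Every later call returns the same object, so all wrappers that share it pay the allocation cost once.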

create_prefill_with_paged_kv_state(device): Creates a BatchPrefillWithPagedKVCacheWrapper with NHD layout and CUDA graphs disabled.
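The NHD layout stores each KV page as [page_size, num_kv_heads, head_dim]; FlashInfer's other option, HND, stores it as [num_kv_heads, page_size, head_dim]. The hypothetical helper below, which is not part of the module, computes the flat offset of element (token, head, dim) inside one page under each layout, to show which elements end up contiguous.

```python
def page_offset(token: int, head: int, dim: int,
                page_size: int, num_kv_heads: int, head_dim: int,
                layout: str = "NHD") -> int:
    """Flat offset of element (token, head, dim) inside one KV page."""
    if layout == "NHD":
        # [page_size, num_kv_heads, head_dim]: all heads of one token are adjacent
        return (token * num_kv_heads + head) * head_dim + dim
    if layout == "HND":
        # [num_kv_heads, page_size, head_dim]: all tokens of one head are adjacent
        return (head * page_size + token) * head_dim + dim
    raise ValueError(f"unknown layout: {layout}")
```

With NHD, writing the keys or values for one new token touches a single contiguous run of num_kv_heads * head_dim elements.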

use_prefill_with_paged_kv_state(...): A context manager that sets up the prefill state for paged KV attention. It computes indptr (the running total of pages per sequence, where each sequence occupies its input length divided by page_size, rounded up, pages) and last_page_len (the number of valid tokens in each sequence's last page), sets the prefill_with_paged_kv_state context variable, and calls state.begin_forward() with all attention parameters. On exit, it calls state.end_forward() and resets the context variable.
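Both derived values follow from ceiling division. The pure-Python sketch below reproduces that arithmetic with lists instead of int32 CUDA tensors; paged_kv_metadata is a hypothetical name, and input lengths are assumed to be at least 1.

```python
from itertools import accumulate
from typing import List, Tuple


def paged_kv_metadata(input_lengths: List[int],
                      page_size: int) -> Tuple[List[int], List[int]]:
    """Return (indptr, last_page_len) for a batch of sequences."""
    # Pages used by each sequence: ceil(length / page_size)
    pages = [(n + page_size - 1) // page_size for n in input_lengths]
    # indptr[i]:indptr[i + 1] is sequence i's slice of the flat page-index list
    indptr = [0] + list(accumulate(pages))
    # Valid tokens in each sequence's final page, always in 1..page_size
    last_page_len = [(n - 1) % page_size + 1 for n in input_lengths]
    return indptr, last_page_len
```

For example, lengths [20, 5, 32] with page_size 16 use 2, 1, and 2 pages, giving indptr [0, 2, 3, 5] and last_page_len [4, 5, 16]. With page_size 1, every last page holds exactly one token.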

create_prefill_state(device): Creates a BatchPrefillWithRaggedKVCacheWrapper for ragged (non-paged) KV cache prefill.

use_prefill_state(...): A context manager for ragged KV prefill attention. Uses cu_seqlens for both query/output and KV indirection pointers.
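Reusing cu_seqlens for both pointers works because, in ragged prefill without a KV cache, each sequence attends only over its own new tokens, so its query rows and its key/value rows occupy the same range of the packed tensors. A minimal sketch with hypothetical helper names, using lists in place of tensors:

```python
from itertools import accumulate
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def cu_seqlens_from_lengths(lengths: Sequence[int]) -> List[int]:
    """Exclusive prefix sum: [0, l0, l0 + l1, ...]."""
    return [0] + list(accumulate(lengths))


def split_ragged(rows: List[T], cu_seqlens: Sequence[int]) -> List[List[T]]:
    """Slice a packed (ragged) batch back into per-sequence groups."""
    return [rows[cu_seqlens[i]:cu_seqlens[i + 1]]
            for i in range(len(cu_seqlens) - 1)]
```

With lengths [3, 2], cu_seqlens is [0, 3, 5]; slicing the packed queries and the packed keys with these same offsets yields matching per-sequence groups.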

create_decode_state(device, num_heads, num_kv_heads): Creates a BatchDecodeWithPagedKVCacheWrapper. Enables tensor cores when the GQA ratio (num_heads / num_kv_heads) exceeds 4.
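The tensor-core heuristic reduces to a one-line predicate. In this hypothetical helper, which is a sketch rather than the module's code, the divisibility check is an added assumption: query heads are expected to split evenly into KV groups, as GQA requires, so the ratio is a whole number.

```python
def use_tensor_cores(num_heads: int, num_kv_heads: int) -> bool:
    """Enable tensor-core decode kernels when the GQA ratio exceeds 4."""
    if num_kv_heads <= 0 or num_heads % num_kv_heads != 0:
        # Assumption of this sketch: heads must divide evenly into KV groups
        raise ValueError("num_heads must be a positive multiple of num_kv_heads")
    return num_heads // num_kv_heads > 4
```

A configuration with 32 query heads and 8 KV heads has a ratio of exactly 4, so it stays on the non-tensor-core path; 32 and 4 gives a ratio of 8 and enables tensor cores.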

create_decode_state_cuda_graphs(...): Creates a decode state for use with CUDA Graphs. The caller passes in the block_tables, block_tables_ptr, and last_page_len buffers, which are stored as part of the state so that the tensors captured in the graph stay the same across replays.
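A CUDA graph replays kernels against the memory addresses recorded at capture time, so the metadata buffers have to be allocated once at the graph's capacity and then overwritten in place for each batch. The sketch below illustrates that constraint with a hypothetical DecodeGraphBuffers class; Python lists stand in for GPU tensors and list identity stands in for device addresses. In PyTorch, the analogous in-place update would typically copy into a slice with Tensor.copy_.

```python
from typing import List


class DecodeGraphBuffers:
    """Hypothetical fixed-capacity metadata buffers for one captured graph."""

    def __init__(self, max_batch: int, max_pages: int) -> None:
        self.block_tables = [0] * max_pages            # flat page indices
        self.block_tables_ptr = [0] * (max_batch + 1)  # indptr into block_tables
        self.last_page_len = [0] * max_batch           # valid tokens in last page

    def update(self, indices: List[int], indptr: List[int],
               last: List[int]) -> None:
        if (len(indices) > len(self.block_tables)
                or len(indptr) > len(self.block_tables_ptr)
                or len(last) > len(self.last_page_len)):
            raise ValueError("batch exceeds the captured graph's capacity")
        # Slice assignment overwrites in place; the list objects never change
        self.block_tables[: len(indices)] = indices
        self.block_tables_ptr[: len(indptr)] = indptr
        self.last_page_len[: len(last)] = last
```

Rebinding an attribute to a new list (or a new tensor) would leave the captured graph reading stale memory, which is why the update writes through slices.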

use_decode_state(...): A context manager for decode attention. It computes the paged KV indirection (indptr) and last page lengths, sets the decode_state context variable, and wraps the yielded block in state.begin_forward() and state.end_forward(), resetting the context variable on exit.
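The begin/end lifecycle around the yielded block can be sketched with a recording stand-in. RecordingWrapper and use_decode are hypothetical names; the sketch only mirrors the documented order of operations (set the context variable, begin_forward, yield, then end_forward and reset) and shows that cleanup runs even when a layer raises.

```python
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator, List


class RecordingWrapper:
    """Hypothetical wrapper that logs begin_forward/end_forward calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def begin_forward(self, **plan: object) -> None:
        self.calls.append(f"begin(page_size={plan['page_size']})")

    def end_forward(self) -> None:
        self.calls.append("end")


active_decode: ContextVar[RecordingWrapper] = ContextVar("active_decode")


@contextmanager
def use_decode(state: RecordingWrapper, page_size: int) -> Iterator[None]:
    token = active_decode.set(state)
    try:
        # Plan once per batch; attention calls inside the block reuse the plan
        state.begin_forward(page_size=page_size)
        yield
    finally:
        # Runs even if a layer raises, so no scope is left half-configured
        state.end_forward()
        active_decode.reset(token)
```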

Usage

These functions back the model's attention layer implementations. The appropriate create_* function is called once, at initialization, and the corresponding use_* context manager is entered for each batch during request processing. Inside the context manager, attention operations read the active state from the ContextVar to perform the actual FlashInfer kernel calls.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/utils/flashinfer_attention.py
  • Lines: 1-229

Signature

def get_workspace(device) -> torch.Tensor

def create_prefill_with_paged_kv_state(*, device: torch.device)
def use_prefill_with_paged_kv_state(
    *, state, block_tables, cu_seqlens, input_lengths, num_heads, num_kv_heads,
    head_size, page_size, dtype, window_left,
)

def create_prefill_state(*, device: torch.device)
def use_prefill_state(
    *, state, cu_seqlens, num_heads, num_kv_heads, head_size,
    query_dtype="float16", window_left,
)

def create_decode_state(*, device, num_heads, num_kv_heads)
def create_decode_state_cuda_graphs(
    *, device, block_tables, block_tables_ptr, last_page_len,
    num_heads, num_kv_heads,
)
def use_decode_state(
    *, state, input_lengths, block_tables, num_heads, num_kv_heads,
    head_size, page_size, dtype, window_left,
)

Import

from lorax_server.utils.flashinfer_attention import (
    create_prefill_with_paged_kv_state,
    use_prefill_with_paged_kv_state,
    create_decode_state,
    use_decode_state,
    prefill_state,
    decode_state,
)

I/O Contract

Inputs

Name | Type | Required | Description
device | torch.device | Yes (create_*) | CUDA device for buffer allocation
block_tables | torch.Tensor | Yes (paged) | Block table indices mapping sequences to KV cache pages
cu_seqlens | torch.Tensor | Yes (prefill) | Cumulative sequence lengths for query/output indirection (and, for ragged prefill, KV indirection)
input_lengths | torch.Tensor | Yes (paged) | Per-sequence input lengths, used to derive indptr and last_page_len
num_heads | int | Yes | Number of query attention heads
num_kv_heads | int | Yes | Number of key/value attention heads (for GQA)
head_size | int | Yes | Dimension of each attention head
page_size | int | Yes (paged) | Number of tokens per KV cache page
dtype | torch.dtype | Yes (paged) | Data type for attention computation (use_prefill_state instead takes query_dtype, defaulting to "float16")
window_left | int | Yes | Left sliding-window size, passed through to FlashInfer (FlashInfer uses -1 for full attention)

Outputs

Name | Type | Description
state | FlashInfer wrapper | Configured FlashInfer batch attention wrapper ready for use within the context manager scope
workspace | torch.Tensor (uint8) | Shared 128 MB GPU workspace buffer

Usage Examples

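# Creating and using FlashInfer decode attention state
import torch

from lorax_server.utils.flashinfer_attention import (
    create_decode_state,
    decode_state,
    use_decode_state,
)

device = torch.device("cuda")
page_size = 16

# Two sequences with 20 and 5 tokens: ceil(20 / 16) = 2 pages, plus 1 page
input_lengths = torch.tensor([20, 5], dtype=torch.int32, device=device)
# FlashInfer expects flat page indices; per-sequence offsets come from indptr
block_tables = torch.tensor([0, 1, 2], dtype=torch.int32, device=device)

# GQA ratio 32 / 8 = 4 does not exceed 4, so tensor cores stay disabled
state = create_decode_state(device=device, num_heads=32, num_kv_heads=8)

with use_decode_state(
    state=state,
    input_lengths=input_lengths,
    block_tables=block_tables,
    num_heads=32,
    num_kv_heads=8,
    head_size=128,
    page_size=page_size,
    dtype=torch.float16,
    window_left=-1,  # FlashInfer convention: -1 means no sliding window
):
    # Attention layers read the active wrapper from the decode_state ContextVar
    assert decode_state.get() is state
    # output = attention(query, key, value)  # model-specific attention call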
