Implementation:Sgl project Sglang Flash Attention Interface

From Leeroopedia


Knowledge Sources
Domains: Flash Attention, GPU Kernels, LLM Inference
Last Updated: 2026-02-10 00:00 GMT

Overview

Python interface for Flash Attention 3 (FA3) operations with KV cache support. It serves as the primary attention backend for LLM inference in SGLang and dispatches to Flash Attention 4 (FA4) when requested.

Description

flash_attn.py is the primary attention implementation for SGLang inference, providing two main functions that wrap the underlying torch.ops.sgl_kernel.fwd custom op.

flash_attn_with_kvcache supports full incremental decoding with comprehensive feature coverage:

  • In-place KV cache updates: When k and v are provided alongside k_cache and v_cache, the cache is updated in-place starting at positions specified by cache_seqlens, enabling efficient incremental decoding.
  • Paged KV cache: Enabled via the page_table parameter (an int32 tensor mapping batch indices to physical page blocks). In this mode the KV cache is organized as (num_blocks, page_block_size, nheads_k, head_dim).
  • Rotary embeddings: Integrated rotary position encoding via rotary_cos/rotary_sin, with interleaving controlled by rotary_interleaved.
  • GQA/MQA: Full support for grouped-query and multi-query attention; the number of Q heads must be divisible by the number of KV heads.
  • Causal and sliding window masking: The causal mask is aligned to the bottom-right corner of the attention matrix. A sliding window is configured via window_size=(left, right), where -1 means infinite context on that side.
  • FP8 descaling: q_descale, k_descale, and v_descale provide the descaling factors for FP8-quantized attention.
  • Split-KV scheduling: num_splits controls KV splitting for improved parallelism (0 = heuristic auto-tuning).
  • Attention chunking: Optional chunking via attention_chunk for memory savings on long sequences.
  • FA4 dispatch: When ver=4, dispatches to flash_attn_varlen_func_v4 from the FA4 interface module. FA4 does not support in-place KV updates, rotary embedding, or descaling.
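
The masking rules in the list above can be checked against a small reference mask builder. This is a pure-Python sketch that never calls the kernel. The helper name reference_mask is hypothetical, and the window convention (query i sees keys j with i + seqlen_k - seqlen_q - left <= j <= i + seqlen_k - seqlen_q + right) is the one documented for upstream FlashAttention, assumed here to carry over unchanged.

```python
def reference_mask(seqlen_q, seqlen_k, causal=False, window_size=(-1, -1)):
    """Return a seqlen_q x seqlen_k grid of 1 (attend) / 0 (masked).

    Hypothetical helper for illustration only; not part of sgl_kernel.
    """
    left, right = window_size
    if causal:
        # Assumption: causal is equivalent to a right window of 0.
        right = 0
    # Bottom-right alignment: the last query lines up with the last key.
    offset = seqlen_k - seqlen_q
    mask = []
    for i in range(seqlen_q):
        row = []
        for j in range(seqlen_k):
            visible = True
            if left >= 0 and j < i + offset - left:
                visible = False  # too far in the past
            if right >= 0 and j > i + offset + right:
                visible = False  # too far in the future
            row.append(1 if visible else 0)
        mask.append(row)
    return mask


# Two new queries against five cached keys, causal:
# the first query cannot see the final key, the second sees everything.
decode_mask = reference_mask(2, 5, causal=True)
```

With seqlen_q=2 and seqlen_k=5, decode_mask is [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]], which is what "aligned to the bottom-right corner" means in practice.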

flash_attn_varlen_func handles variable-length batched attention:

  • Requires cu_seqlens_q, cu_seqlens_k, max_seqlen_q, and max_seqlen_k for FA3.
  • Supports the same features as the KV cache variant, except paged KV and rotary embedding.
  • Can dispatch to FA4 via the ver=4 parameter.
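
The cumulative-length arguments listed above are typically derived from the per-sequence lengths of a packed batch. The sketch below shows one way to compute them in pure Python; build_cu_seqlens is a hypothetical helper, and the packed layout (sequence b occupying rows cu[b]:cu[b+1] of the token dimension) is an assumption based on the usual varlen convention.

```python
from itertools import accumulate


def build_cu_seqlens(seqlens):
    """Return (cu_seqlens, max_seqlen) for a list of per-sequence lengths.

    Hypothetical helper for illustration only; not part of sgl_kernel.
    """
    # Prefix sums with a leading 0, so there are len(seqlens) + 1 entries.
    cu_seqlens = [0] + list(accumulate(seqlens))
    max_seqlen = max(seqlens, default=0)
    return cu_seqlens, max_seqlen


cu_q, max_q = build_cu_seqlens([3, 5, 2])
# In real use the list would be moved to the GPU as int32, for example:
#   torch.tensor(cu_q, dtype=torch.int32, device="cuda")
```

Here cu_q is [0, 3, 8, 10] and max_q is 5.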

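is_fa3_supported is a cached function that checks GPU compatibility. It requires CUDA >= 12.3 and SM compute capability 8.x or 9.x, which covers A100, A*0 cards (the wildcard presumably standing for models such as A10, A30, and A40), L20, L40/L40s, RTX 4090, H100, and similar GPUs.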
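
The rule described above can be expressed as a small cached predicate. This is a sketch of the stated condition only, not the module's actual implementation: fa3_supported_for is a hypothetical name, and it takes the capability major number and CUDA version string as plain arguments so that it runs without a GPU.

```python
from functools import lru_cache


def _major_minor(version):
    """Parse '12.4' into (12, 4) so versions compare numerically."""
    return tuple(int(part) for part in version.split(".")[:2])


@lru_cache(maxsize=None)
def fa3_supported_for(capability_major, cuda_version):
    """Hypothetical stand-in for the check described in the text."""
    # SM 8.x or 9.x, and CUDA 12.3 or newer.
    return capability_major in (8, 9) and _major_minor(cuda_version) >= (12, 3)
```

Comparing parsed tuples rather than raw strings avoids the pitfall where "12.10" sorts before "12.3" lexicographically.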

The module attempts to import FA4 (flash_attn_varlen_func_v4) from _fa4_interface but gracefully falls back to None if unavailable, allowing FA3-only operation.
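
The optional-import pattern described above can be sketched as follows. The module name hypothetical_fa4_interface_module and the helpers load_optional and pick_backend are placeholders for illustration, and raising an error when ver=4 is requested without FA4 is an assumption; the text does not say how the real module reacts in that case.

```python
import importlib


def load_optional(module_name, attr):
    """Return module.attr if the module can be imported, else None."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, attr, None)


# Placeholder module name: it is not expected to exist, so this yields None.
fa4_func = load_optional("hypothetical_fa4_interface_module",
                         "flash_attn_varlen_func_v4")


def pick_backend(ver, fa4=fa4_func):
    """Choose a backend label from the requested version (illustrative)."""
    if ver == 3:
        return "fa3"
    if ver == 4:
        if fa4 is None:
            # Assumed behavior: fail loudly rather than silently using FA3.
            raise RuntimeError("FA4 requested but not available")
        return "fa4"
    raise ValueError(f"unsupported version: {ver}")
```

Keeping the fallback value as None lets callers test availability with a simple identity check before requesting ver=4.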

Usage

Use flash_attn_with_kvcache for autoregressive decoding where the KV cache is maintained across steps. Use flash_attn_varlen_func for prefill or variable-length batch processing. Set ver=4 to use FA4 on Hopper/Blackwell GPUs for potentially better performance.
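
The cache bookkeeping behind an autoregressive decoding loop can be sketched with plain lists. This models only the indexing (new entries written starting at cache_seqlens) and not the attention itself. append_to_cache is a hypothetical helper, and having the caller advance the lengths after each step is an assumption rather than documented behavior.

```python
def append_to_cache(cache, cache_seqlens, new_entries):
    """Write new_entries[b] into cache[b] starting at cache_seqlens[b].

    cache: one preallocated slot list per batch element (mutated in place).
    Returns the advanced sequence lengths. Illustrative only.
    """
    advanced = []
    for b, tokens in enumerate(new_entries):
        start = cache_seqlens[b]
        end = start + len(tokens)
        if end > len(cache[b]):
            raise ValueError(f"cache overflow for batch element {b}")
        cache[b][start:end] = tokens  # same-length slice: in-place overwrite
        advanced.append(end)
    return advanced


# Two sequences with 2 and 1 cached tokens, each receiving one new token.
k_cache = [["a", "b", None, None], ["x", None, None, None]]
seqlens = append_to_cache(k_cache, [2, 1], [["c"], ["y"]])
```

After the call, each new token sits immediately after that sequence's existing entries and seqlens is [3, 2].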

Code Reference

Source Location

Signature

def is_fa3_supported(device=None) -> bool:

def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    attention_chunk: Optional[int] = None,
    softcap=0.0,
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    score_mod=None,
    aux_tensors=None,
    ver=3,
):

def flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q=None, max_seqlen_k=None,
    seqused_q=None, seqused_k=None,
    page_table=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    score_mod=None,
    aux_tensors=None,
    ver=3,
):

Import

from sgl_kernel import flash_attn_with_kvcache, flash_attn_varlen_func

I/O Contract

Inputs

Name | Type | Required | Description
q | torch.Tensor | Yes | Query tensor: (batch_size, seqlen, nheads, headdim)
k_cache | torch.Tensor | Yes | Key cache: (batch_size, seqlen_cache, nheads_k, headdim) or paged
v_cache | torch.Tensor | Yes | Value cache: same layout as k_cache with headdim_v
k, v | torch.Tensor | No | New keys/values for cache update: (batch_size, seqlen_new, nheads_k, headdim)
cache_seqlens | int or torch.Tensor (int32) | No | Current KV cache sequence lengths per batch
page_table | torch.Tensor (int32) | No | Page table for paged KV cache: (batch_size, max_num_blocks)
cu_seqlens_q | torch.Tensor (int32) | No | Cumulative query sequence lengths
softmax_scale | float | No | QK scaling, defaults to headdim^(-0.5)
causal | bool | No | Enable causal attention mask
window_size | Tuple[int, int] | No | Sliding window (left, right), -1 = infinite
softcap | float | No | Softcapping value, 0.0 = disabled
num_splits | int | No | KV split count, 0 = heuristic
ver | int | No | FA version: 3 (default) or 4
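
To make the page_table entry concrete, the sketch below resolves a logical token position to a physical block and an offset within it. It assumes row b of the table lists the physical blocks of sequence b in logical order, which matches the (batch_size, max_num_blocks) shape given above. locate_token is a hypothetical helper and the table contents are made up.

```python
def locate_token(page_table, page_block_size, batch_idx, token_pos):
    """Map a logical token position to (physical_block, offset_in_block).

    Hypothetical helper for illustration only; not part of sgl_kernel.
    """
    logical_block, offset = divmod(token_pos, page_block_size)
    physical_block = page_table[batch_idx][logical_block]
    return physical_block, offset


# Made-up table: 2 sequences, up to 3 blocks each, 16 tokens per block.
example_table = [[7, 2, 5],
                 [0, 3, 9]]
block, offset = locate_token(example_table, 16, batch_idx=0, token_pos=20)
```

Token 20 of sequence 0 falls in logical block 1, so the lookup returns physical block 2 at offset 4.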

Outputs

Name | Type | Description
out | torch.Tensor | Attention output: (batch_size, seqlen, nheads, headdim)
softmax_lse | torch.Tensor (float32) | Log-sum-exp values: (batch_size, nheads, seqlen), only if return_softmax_lse=True
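
One reason to request softmax_lse is that partial attention results computed over disjoint key ranges can be merged exactly using their log-sum-exp values. The sketch below demonstrates that identity for a single query in pure Python, with softmax_scale fixed at 1 for brevity. Whether the kernel's split-KV path combines partial results in exactly this way is an assumption, and both function names are hypothetical.

```python
import math


def attend(q, keys, values):
    """Single-query softmax attention; returns (output, lse)."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    denom = sum(weights)
    lse = top + math.log(denom)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom
           for d in range(dim)]
    return out, lse


def merge_partials(partials):
    """Combine (output, lse) pairs computed over disjoint key ranges."""
    top = max(lse for _, lse in partials)
    total_lse = top + math.log(sum(math.exp(lse - top) for _, lse in partials))
    dim = len(partials[0][0])
    # Each partial output is reweighted by its share of the softmax mass.
    out = [sum(math.exp(lse - total_lse) * o[d] for o, lse in partials)
           for d in range(dim)]
    return out, total_lse
```

Running attend over two halves of the keys and passing both results to merge_partials reproduces the output of a single attend call over all keys, up to floating-point error.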

Usage Examples

from sgl_kernel import flash_attn_with_kvcache, flash_attn_varlen_func

# Incremental decoding with KV cache
out = flash_attn_with_kvcache(
    q=query,           # (batch, 1, nheads, headdim)
    k_cache=k_cache,   # (batch, max_seqlen, nheads_k, headdim)
    v_cache=v_cache,
    k=new_k,           # (batch, 1, nheads_k, headdim)
    v=new_v,
    cache_seqlens=seqlens,  # (batch,), int32
    causal=True,
)

# Paged KV cache decoding
out = flash_attn_with_kvcache(
    q=query,
    k_cache=paged_k,     # (num_blocks, page_size, nheads_k, headdim)
    v_cache=paged_v,
    page_table=page_tbl,  # (batch, max_num_blocks), int32
    cache_seqlens=seqlens,
    causal=True,
    num_splits=0,  # auto-tune splits
)

# Variable-length prefill with FA4
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_q,
    cu_seqlens_k=cu_k,
    max_seqlen_q=max_q,
    max_seqlen_k=max_k,
    causal=True,
    ver=4,  # use FA4
)
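
The examples above use nheads query heads and nheads_k key/value heads. Under GQA/MQA, each KV head is shared by a contiguous group of query heads. The sketch below shows that grouping. kv_head_for is a hypothetical helper, and the contiguous mapping follows the convention documented for upstream FlashAttention, assumed here to apply as well.

```python
def kv_head_for(q_head, nheads, nheads_k):
    """Return the KV head index that a given query head reads from.

    Hypothetical helper for illustration only; not part of sgl_kernel.
    """
    if nheads % nheads_k != 0:
        raise ValueError("nheads must be divisible by nheads_k")
    group_size = nheads // nheads_k
    return q_head // group_size


# GQA with 6 query heads sharing 2 KV heads.
gqa_mapping = [kv_head_for(h, nheads=6, nheads_k=2) for h in range(6)]
# MQA is the special case nheads_k == 1: every query head uses KV head 0.
mqa_mapping = [kv_head_for(h, nheads=4, nheads_k=1) for h in range(4)]
```

With 6 query heads and 2 KV heads, query heads 0 to 2 read KV head 0 and query heads 3 to 5 read KV head 1.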
