Implementation:Sgl project Sglang Flash Attention Interface
| Knowledge Sources | |
|---|---|
| Domains | Flash Attention, GPU Kernels, LLM Inference |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Python interface to the Flash Attention 3 (FA3) kernels with KV cache support. It is the primary attention backend for LLM inference in SGLang and dispatches to FA4 when requested.
Description
flash_attn.py is the primary attention implementation for SGLang inference, providing two main functions that wrap the underlying torch.ops.sgl_kernel.fwd custom op.
flash_attn_with_kvcache supports incremental decoding and covers the following features:
- In-place KV cache updates: When k and v are provided alongside k_cache and v_cache, the cache is updated in-place starting at positions specified by cache_seqlens, enabling efficient incremental decoding.
- Paged KV cache: Enabled via the page_table parameter, an int32 tensor of shape (batch_size, max_num_blocks) that maps each batch entry to its physical page blocks. In this mode the KV cache is organized as (num_blocks, page_block_size, nheads_k, head_dim).
- Rotary embeddings: Integrated rotary position encoding via rotary_cos/rotary_sin with configurable interleaving.
- GQA/MQA: Grouped-query and multi-query attention are supported; the number of Q heads must be divisible by the number of KV heads.
- Causal and sliding window masking: The causal mask is aligned to the bottom-right corner of the attention matrix. A sliding window is configured via window_size=(left, right), where -1 means unlimited context on that side.
- FP8 descaling: q_descale, k_descale, v_descale for FP8 quantized attention.
- Split-KV scheduling: num_splits controls KV splitting for improved parallelism (0 = heuristic auto-tuning).
- Attention chunking: Optional chunking via attention_chunk for memory savings on long sequences.
- FA4 dispatch: When ver=4, dispatches to flash_attn_varlen_func_v4 from the FA4 interface module. FA4 does not support in-place KV updates, rotary embedding, or descaling.
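The masking rules in the list above can be illustrated without a GPU. The following is a minimal pure-Python sketch of bottom-right aligned causal masking combined with a sliding window; the helper names `is_visible` and `build_mask` are hypothetical and are not part of flash_attn.py. It assumes the usual Flash Attention convention that query `i` is centered on key position `i + seqlen_k - seqlen_q`.

```python
def is_visible(i, j, seqlen_q, seqlen_k, causal=False, window_size=(-1, -1)):
    """Return True if query position i may attend to key position j.

    Hypothetical reference helper: the real masking happens inside the kernel.
    """
    left, right = window_size
    if causal:
        # Causal attention is a window with no look-ahead.
        right = 0
    # Bottom-right alignment: the last query lines up with the last key.
    center = i + (seqlen_k - seqlen_q)
    if left >= 0 and j < center - left:
        return False
    if right >= 0 and j > center + right:
        return False
    return True


def build_mask(seqlen_q, seqlen_k, causal=False, window_size=(-1, -1)):
    """Build a 0/1 mask of shape (seqlen_q, seqlen_k) as nested lists."""
    return [
        [1 if is_visible(i, j, seqlen_q, seqlen_k, causal, window_size) else 0
         for j in range(seqlen_k)]
        for i in range(seqlen_q)
    ]


# Two new queries against five cached keys, causal only.
causal_mask = build_mask(2, 5, causal=True)
# Same shapes, but each query sees at most one key to its left.
windowed_mask = build_mask(2, 5, causal=True, window_size=(1, 0))
```

With two queries and five keys, the first query sees keys 0 to 3 and the second sees all five, which is what bottom-right alignment means in the decoding case where the queries are the newest tokens.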
flash_attn_varlen_func handles variable-length batched attention:
- Requires cu_seqlens_q, cu_seqlens_k, max_seqlen_q, and max_seqlen_k for FA3.
- Supports all the same features as the KV cache variant except paged KV and rotary embedding.
- Can dispatch to FA4 via ver=4 parameter.
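The cumulative-length arguments required above are simple prefix sums over the per-sequence lengths. A minimal sketch, assuming the sequences are packed back to back along the first dimension; `build_cu_seqlens` is a hypothetical helper name, not a function from the module.

```python
from itertools import accumulate


def build_cu_seqlens(seqlens):
    """Return cumulative sequence lengths with a leading 0.

    For lengths [3, 5, 2] the packed tensor holds 10 tokens, and sequence b
    occupies rows cu[b]:cu[b + 1].
    """
    return [0, *accumulate(seqlens)]


seqlens = [3, 5, 2]
cu_seqlens = build_cu_seqlens(seqlens)
max_seqlen = max(seqlens)
```

In a real call these values would be converted to an int32 tensor on the same device as q, for example with `torch.tensor(cu_seqlens, dtype=torch.int32, device=q.device)`.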
is_fa3_supported is a cached function that checks GPU compatibility: it requires CUDA >= 12.3 and SM capability 8.x or 9.x (A100, A*0, L20, L40/L40s, RTX 4090, H100, etc.).
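The compatibility rule can be written as a small pure function. This is a sketch of the rule as described here, not the module's actual code: `fa3_rule` is a hypothetical name, and it takes the CUDA version string and the capability major number as arguments so that it runs without a GPU.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def fa3_rule(cuda_version: str, capability_major: int) -> bool:
    """Hypothetical restatement of the FA3 support rule.

    Assumes cuda_version looks like "12.4" (major.minor).
    """
    major, minor = (int(part) for part in cuda_version.split(".")[:2])
    # CUDA >= 12.3 and an SM 8.x or 9.x device.
    return (major, minor) >= (12, 3) and capability_major in (8, 9)
```

On a real system the two inputs would typically come from `torch.version.cuda` and `torch.cuda.get_device_capability(device)[0]`. Comparing parsed integers rather than raw strings is a choice made for this sketch.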
The module attempts to import FA4 (flash_attn_varlen_func_v4) from _fa4_interface but gracefully falls back to None if unavailable, allowing FA3-only operation.
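The fallback behaviour is the common optional-import pattern. A minimal sketch, assuming nothing about the module beyond what is stated above; `optional_attr` is a hypothetical helper, and the absolute module path in the trailing comment is an assumption about the package layout.

```python
import importlib


def optional_attr(module_name, attr_name):
    """Return module_name.attr_name, or None if it cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Covers ModuleNotFoundError too, which subclasses ImportError.
        return None
    return getattr(module, attr_name, None)


# Assumed usage (module path and attribute name are illustrative only):
# flash_attn_varlen_func_v4 = optional_attr(
#     "sgl_kernel._fa4_interface", "flash_attn_varlen_func"
# )
```

Callers can then test `flash_attn_varlen_func_v4 is None` before honouring a ver=4 request.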
Usage
Use flash_attn_with_kvcache for autoregressive decoding where the KV cache is maintained across steps. Use flash_attn_varlen_func for prefill or variable-length batch processing. Set ver=4 to use FA4 on Hopper/Blackwell GPUs for potentially better performance.
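Because FA4 does not support in-place KV updates, rotary embedding, or descaling, a caller of flash_attn_with_kvcache may want to check its arguments before setting ver=4. The sketch below is one way to do that; `fa4_unsupported_args` and `pick_version` are hypothetical helpers, and how the real function reacts to unsupported arguments is not specified here.

```python
# Arguments of flash_attn_with_kvcache that the FA4 path is described as
# not supporting: in-place KV update, rotary embedding, and FP8 descaling.
FA4_UNSUPPORTED = (
    "k", "v",
    "rotary_cos", "rotary_sin",
    "q_descale", "k_descale", "v_descale",
)


def fa4_unsupported_args(**kwargs):
    """List the passed arguments that would rule out the FA4 path."""
    return [name for name in FA4_UNSUPPORTED if kwargs.get(name) is not None]


def pick_version(want_fa4, **kwargs):
    """Return 4 only when FA4 is requested and no unsupported argument is set."""
    return 4 if want_fa4 and not fa4_unsupported_args(**kwargs) else 3
```

For example, `pick_version(True, rotary_cos=cos)` falls back to 3, while `pick_version(True, causal=True)` returns 4. Falling back silently is a design choice of this sketch; raising an error instead would be equally reasonable.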
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/python/sgl_kernel/flash_attn.py
- Lines: 1-381
Signature
def is_fa3_supported(device=None) -> bool:

def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    attention_chunk: Optional[int] = None,
    softcap=0.0,
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    score_mod=None,
    aux_tensors=None,
    ver=3,
):

def flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q=None, max_seqlen_k=None,
    seqused_q=None, seqused_k=None,
    page_table=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    score_mod=None,
    aux_tensors=None,
    ver=3,
):
Import
from sgl_kernel import flash_attn_with_kvcache, flash_attn_varlen_func
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| q | torch.Tensor | Yes | Query tensor: (batch_size, seqlen, nheads, headdim) |
| k_cache | torch.Tensor | Yes | Key cache: (batch_size, seqlen_cache, nheads_k, headdim) or paged |
| v_cache | torch.Tensor | Yes | Value cache: same layout as k_cache, with last dimension headdim_v |
| k, v | torch.Tensor | No | New keys/values for cache update: (batch_size, seqlen_new, nheads_k, headdim) |
| cache_seqlens | int or torch.Tensor (int32) | No | Current KV cache sequence lengths per batch |
| page_table | torch.Tensor (int32) | No | Page table for paged KV cache: (batch_size, max_num_blocks) |
| cu_seqlens_q | torch.Tensor (int32) | No | Cumulative query sequence lengths |
| softmax_scale | float | No | QK scaling, defaults to headdim^(-0.5) |
| causal | bool | No | Enable causal attention mask |
| window_size | Tuple[int, int] | No | Sliding window (left, right), -1 = infinite |
| softcap | float | No | Softcapping value, 0.0 = disabled |
| num_splits | int | No | KV split count, 0 = heuristic |
| ver | int | No | FA version: 3 (default) or 4 |
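To make the page_table input concrete, the sketch below shows how a logical token position could be resolved to a physical cache location under the paged layout (num_blocks, page_block_size, nheads_k, head_dim). It uses plain Python lists in place of tensors, and `locate_token` and `blocks_needed` are hypothetical helpers that only illustrate the indexing.

```python
def locate_token(page_table_row, page_block_size, pos):
    """Map a logical token position to (physical_block, offset_in_block).

    page_table_row is one batch entry's row of the page table, i.e. the
    physical block ids for that sequence in logical order.
    """
    logical_block, offset = divmod(pos, page_block_size)
    return page_table_row[logical_block], offset


def blocks_needed(cache_seqlen, page_block_size):
    """Number of page blocks required to hold cache_seqlen tokens."""
    return -(-cache_seqlen // page_block_size)  # ceiling division


# One sequence whose logical blocks 0, 1, 2 live in physical blocks 7, 2, 5.
row = [7, 2, 5]
first = locate_token(row, page_block_size=4, pos=0)
middle = locate_token(row, page_block_size=4, pos=5)
last = locate_token(row, page_block_size=4, pos=11)
```

With real tensors, the key vector for that token would then be `k_cache[physical_block, offset]`, of shape (nheads_k, head_dim).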
Outputs
| Name | Type | Description |
|---|---|---|
| out | torch.Tensor | Attention output: (batch_size, seqlen, nheads, headdim) |
| softmax_lse | torch.Tensor (float32) | Log-sum-exp values: (batch_size, nheads, seqlen), only if return_softmax_lse=True |
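The softmax_lse output is the log of the softmax denominator for each query row. One reason it is useful is that attention results computed over disjoint slices of the keys can be merged exactly from their outputs and LSE values. The following pure-Python sketch demonstrates that identity on scalar values; `attend` and `merge` are hypothetical helpers, and this is not a description of how the kernel combines KV splits internally.

```python
import math


def attend(scores, values):
    """Softmax-weighted sum of scalar values; returns (out, lse)."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    out = sum(e * v for e, v in zip(exps, values)) / denom
    return out, m + math.log(denom)


def merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial attention results over disjoint key sets."""
    m = max(lse_a, lse_b)
    w_a, w_b = math.exp(lse_a - m), math.exp(lse_b - m)
    out = (w_a * out_a + w_b * out_b) / (w_a + w_b)
    return out, m + math.log(w_a + w_b)


scores = [0.5, -1.0, 2.0, 0.1]
values = [1.0, 2.0, 3.0, 4.0]

full_out, full_lse = attend(scores, values)
part_a = attend(scores[:2], values[:2])
part_b = attend(scores[2:], values[2:])
merged_out, merged_lse = merge(*part_a, *part_b)
```

Up to floating-point rounding, the merged result equals the result computed over all keys at once.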
Usage Examples
from sgl_kernel import flash_attn_with_kvcache, flash_attn_varlen_func

# Incremental decoding with KV cache
out = flash_attn_with_kvcache(
    q=query,                # (batch, 1, nheads, headdim)
    k_cache=k_cache,        # (batch, max_seqlen, nheads_k, headdim)
    v_cache=v_cache,
    k=new_k,                # (batch, 1, nheads_k, headdim)
    v=new_v,
    cache_seqlens=seqlens,  # (batch,), int32
    causal=True,
)

# Paged KV cache decoding
out = flash_attn_with_kvcache(
    q=query,
    k_cache=paged_k,        # (num_blocks, page_size, nheads_k, headdim)
    v_cache=paged_v,
    page_table=page_tbl,    # (batch, max_num_blocks), int32
    cache_seqlens=seqlens,
    causal=True,
    num_splits=0,           # auto-tune splits
)

# Variable-length prefill with FA4
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_q,
    cu_seqlens_k=cu_k,
    max_seqlen_q=max_q,
    max_seqlen_k=max_k,
    causal=True,
    ver=4,                  # use FA4
)