Implementation:Sgl project Sglang Attention Ops
| Knowledge Sources | |
|---|---|
| Domains | Kernel, Attention, MLA |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Python interface for attention state merging and CUTLASS MLA (Multi-head Latent Attention) decode operations.
Description
The attention.py module provides four functions:
- merge_state and merge_state_v2 merge two sets of attention values and their log-sum-exp states, which is how split-KV attention results are combined. Both convert the log-sum-exp tensors to float32 for numerical stability before delegating to C++ ops.
- cutlass_mla_decode performs MLA decode for DeepSeek-style models. It handles 512-dim latent and 64-dim rope embeddings, pads heads to 128 for hardware alignment, validates tensor shapes and dtypes (fp16 or bf16), and uses a paged KV cache with a configurable number of KV splits.
- cutlass_mla_get_workspace_size computes the workspace memory that MLA decode requires, given sequence length, batch size, SM count, and KV split configuration.
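The merge that merge_state and merge_state_v2 perform can be illustrated with a small pure-Python reference for a single query and head. This is a minimal sketch, not the kernel itself: the helper names `partial_attention` and `merge_pair` are hypothetical, and the sketch assumes the log-sum-exp states use the natural logarithm, which the description above does not specify.

```python
import math

def partial_attention(scores, values):
    """Softmax-weighted average over one KV chunk, plus its log-sum-exp.

    scores: one logit per key in the chunk
    values: one value vector (list of floats) per key
    """
    m = max(scores)                                  # subtract the max for stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(dim)]
    return out, m + math.log(total)

def merge_pair(v_a, s_a, v_b, s_b):
    """Hypothetical reference merge (assumes natural-log LSE states)."""
    m = max(s_a, s_b)
    w_a = math.exp(s_a - m)                          # relative weight of split A
    w_b = math.exp(s_b - m)                          # relative weight of split B
    denom = w_a + w_b
    v = [(w_a * x + w_b * y) / denom for x, y in zip(v_a, v_b)]
    return v, m + math.log(denom)

# Split four keys into two chunks, attend to each, then merge.
scores = [0.3, -1.2, 2.0, 0.7]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]
v_a, s_a = partial_attention(scores[:2], values[:2])
v_b, s_b = partial_attention(scores[2:], values[2:])
v_merged, s_merged = merge_pair(v_a, s_a, v_b, s_b)

# Attending to all keys at once should give the same answer.
v_full, s_full = partial_attention(scores, values)
```

Each split's output is rescaled by how much softmax mass that split carried, so the merged result matches attention computed over all keys in one pass. Working with differences of log-sum-exp values rather than raw exponentials is also why precision of the states matters.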
Usage
Use merge_state or merge_state_v2 when combining attention results from split-KV computations (e.g., prefix caching, tensor parallelism). Use cutlass_mla_decode for high-throughput decoding in DeepSeek V2/V3 models with MLA attention.
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/python/sgl_kernel/attention.py
- Lines: 1-133
Signature
def merge_state(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ...

def merge_state_v2(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ...

def cutlass_mla_decode(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    workspace: torch.Tensor,
    sm_scale: float,
    num_kv_splits: int = 1,
) -> torch.Tensor: ...

def cutlass_mla_get_workspace_size(
    max_seq_len: int,
    num_batches: int,
    sm_count: int = 0,
    num_kv_splits: int = 1,
) -> int: ...
Import
from sgl_kernel.attention import (
    cutlass_mla_decode,
    cutlass_mla_get_workspace_size,
    merge_state,
    merge_state_v2,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| v_a | torch.Tensor | Yes | Attention values from first split |
| s_a | torch.Tensor | Yes | Log-sum-exp states from first split |
| v_b | torch.Tensor | Yes | Attention values from second split |
| s_b | torch.Tensor | Yes | Log-sum-exp states from second split |
| v_merged | Optional[torch.Tensor] | No | Pre-allocated output for merged values |
| s_merged | Optional[torch.Tensor] | No | Pre-allocated output for merged states |
| q_nope | torch.Tensor | Yes (MLA) | Query without positional encoding, shape [B, H, 512] |
| q_pe | torch.Tensor | Yes (MLA) | Query positional encoding, shape [B, H, 64] |
| kv_c_and_k_pe_cache | torch.Tensor | Yes (MLA) | Paged KV cache, shape [num_blocks, page_size, 576] |
| seq_lens | torch.Tensor | Yes (MLA) | Sequence lengths per batch, int32 |
| page_table | torch.Tensor | Yes (MLA) | Block table mapping, shape [B, max_blocks], int32 |
| workspace | torch.Tensor | Yes (MLA) | Pre-allocated workspace buffer |
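| sm_scale | float | Yes (MLA) | Softmax scaling factor |
| num_kv_splits | int | No (MLA, workspace size) | Number of KV splits, default 1 |
| max_seq_len | int | Yes (workspace size) | Maximum sequence length used to size the workspace |
| num_batches | int | Yes (workspace size) | Batch size used to size the workspace |
| sm_count | int | No (workspace size) | SM count, default 0 |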
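The shape relationships in the table can be checked before calling the kernel. The sketch below is a hypothetical pre-flight helper (`check_mla_shapes` is not part of sgl_kernel) that works on plain shape tuples. The 512, 64, and 576 dimensions come from the table. The upper bound of 128 heads is an assumption inferred from the description's "head padding to 128".

```python
D_LATENT = 512    # latent dimension of q_nope, from the Inputs table
D_ROPE = 64       # rope dimension of q_pe, from the Inputs table
MAX_HEADS = 128   # assumed upper bound, inferred from "head padding to 128"

def check_mla_shapes(q_nope_shape, q_pe_shape, kv_cache_shape, page_table_shape):
    """Hypothetical check; returns the expected output shape [B, H, 512]."""
    if len(q_nope_shape) != 3 or q_nope_shape[2] != D_LATENT:
        raise ValueError(f"q_nope must be [B, H, {D_LATENT}], got {q_nope_shape}")
    batch, heads, _ = q_nope_shape
    if tuple(q_pe_shape) != (batch, heads, D_ROPE):
        raise ValueError(f"q_pe must be [{batch}, {heads}, {D_ROPE}], got {q_pe_shape}")
    if heads > MAX_HEADS:
        raise ValueError(f"at most {MAX_HEADS} heads are assumed, got {heads}")
    # The cache stores latent and rope parts together: 512 + 64 = 576.
    if len(kv_cache_shape) != 3 or kv_cache_shape[2] != D_LATENT + D_ROPE:
        raise ValueError(f"KV cache last dim must be {D_LATENT + D_ROPE}")
    if len(page_table_shape) != 2 or page_table_shape[0] != batch:
        raise ValueError(f"page_table must be [{batch}, max_blocks]")
    return (batch, heads, D_LATENT)

expected_out = check_mla_shapes(
    (32, 16, 512), (32, 16, 64), (1024, 128, 576), (32, 16)
)
```

With real tensors, the same helper could be called as `check_mla_shapes(q_nope.shape, q_pe.shape, kv_cache.shape, page_table.shape)`. It does not cover dtype or device checks.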
Outputs
| Name | Type | Description |
|---|---|---|
| v_merged | torch.Tensor | Merged attention values (merge_state/merge_state_v2) |
| s_merged | torch.Tensor | Merged log-sum-exp states (merge_state/merge_state_v2) |
| out | torch.Tensor | MLA decode output, shape [B, H, 512] (cutlass_mla_decode) |
| workspace_size | int | Required workspace bytes (cutlass_mla_get_workspace_size) |
Usage Examples
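import torch
from sgl_kernel.attention import merge_state, cutlass_mla_decode, cutlass_mla_get_workspace_size

# Merge split-KV attention results.
# v_a, s_a, v_b, s_b are assumed to be existing tensors from two partial attention passes.
v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)

# MLA decode: size and allocate the workspace first, then run the kernel.
# q_nope, q_pe, kv_cache, seq_lens, and page_table are assumed to be prepared as in the Inputs table.
ws_size = cutlass_mla_get_workspace_size(max_seq_len=2048, num_batches=32)
workspace = torch.empty(ws_size, dtype=torch.uint8, device="cuda")
out = cutlass_mla_decode(
    q_nope, q_pe, kv_cache, seq_lens,
    page_table, workspace, sm_scale=0.125
)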
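The example above takes `page_table` and `seq_lens` as given. As a rough illustration of what a block table of shape [B, max_blocks] encodes, the sketch below builds one from per-sequence lengths using plain Python lists. The helpers `build_page_table` and `locate_token` are hypothetical. Assigning block ids consecutively and padding unused slots with 0 are assumptions made for illustration only; a real allocator may hand out blocks in any order, and the kernel may impose alignment requirements on max_blocks that are not covered here.

```python
def build_page_table(seq_lens, page_size, pad_value=0):
    """Hypothetical builder: one row of block ids per sequence, padded to equal width."""
    blocks_per_seq = [-(-n // page_size) for n in seq_lens]   # ceiling division
    max_blocks = max(blocks_per_seq)
    table, next_block = [], 0
    for n_blocks in blocks_per_seq:
        row = list(range(next_block, next_block + n_blocks))  # consecutive ids (assumption)
        next_block += n_blocks
        row += [pad_value] * (max_blocks - n_blocks)          # pad unused slots (assumption)
        table.append(row)
    return table, next_block                                  # table and total blocks used

def locate_token(table, page_size, seq_idx, pos):
    """Map a token position in a sequence to (block id, offset within the block)."""
    return table[seq_idx][pos // page_size], pos % page_size

page_table, num_blocks = build_page_table(seq_lens=[5, 130, 64], page_size=64)
block_id, offset = locate_token(page_table, 64, seq_idx=1, pos=129)
```

Here the second sequence needs three blocks of 64 tokens to hold 130 tokens, so every row is padded to three entries. In practice the nested list would be converted to an int32 tensor, matching the dtype listed in the Inputs table.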