
Implementation:Sgl project Sglang Attention Ops

From Leeroopedia


Knowledge Sources
Domains: Kernel, Attention, MLA
Last Updated: 2026-02-10 00:00 GMT

Overview

Python interface for attention state merging and CUTLASS MLA (Multi-head Latent Attention) decode operations.

Description

The attention.py module provides four functions.

merge_state and merge_state_v2 merge two partial attention results, each consisting of attention values and log-sum-exp (LSE) states, as produced by split-KV attention. Both cast the LSE tensors to float32 for numerical stability before delegating to the compiled C++/CUDA ops.

cutlass_mla_decode performs MLA decode for DeepSeek-style models, where each query has a 512-dim latent part and a 64-dim RoPE part. Queries with fewer than 128 heads are padded up to 128 heads to match the kernel's fixed head count, and the output is sliced back to the original heads. The function validates tensor shapes and dtypes (fp16 or bf16) and reads from a paged KV cache, with a configurable number of KV splits.

cutlass_mla_get_workspace_size returns the workspace size that cutlass_mla_decode needs, given the maximum sequence length, batch size, SM count, and number of KV splits.
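The validation and head-padding rules described above can be sketched in plain Python over shape tuples. This is an illustrative sketch, not the sgl_kernel source: the helper name check_mla_decode_shapes is hypothetical, dtypes are passed as strings instead of torch dtypes, and the sketch assumes that inputs with more than 128 heads are rejected rather than padded.

```python
from typing import Tuple

D_LATENT = 512   # width of q_nope and of the latent part of a cache row
D_ROPE = 64      # width of q_pe and of the RoPE part of a cache row
MAX_HEADS = 128  # head count the queries are padded up to

ALLOWED_DTYPES = {"float16", "bfloat16"}


def check_mla_decode_shapes(
    q_nope_shape: Tuple[int, int, int],
    q_pe_shape: Tuple[int, int, int],
    kv_shape: Tuple[int, int, int],
    dtype: str,
) -> Tuple[int, int, int]:
    """Hypothetical helper: validate inputs and return (B, H, padded_H)."""
    batch, heads, d_nope = q_nope_shape
    batch_pe, heads_pe, d_pe = q_pe_shape
    if (batch, heads) != (batch_pe, heads_pe):
        raise ValueError("q_nope and q_pe must agree on batch size and head count")
    if d_nope != D_LATENT or d_pe != D_ROPE:
        raise ValueError(f"expected query widths {D_LATENT} and {D_ROPE}")
    if kv_shape[2] != D_LATENT + D_ROPE:
        raise ValueError(f"cache rows must have {D_LATENT + D_ROPE} elements")
    if heads > MAX_HEADS:
        raise ValueError(f"at most {MAX_HEADS} heads are supported")
    if dtype not in ALLOWED_DTYPES:
        raise ValueError("dtype must be fp16 or bf16")
    # Queries with fewer heads are padded up to MAX_HEADS before the kernel
    # runs; only the first `heads` output heads are kept afterwards.
    return batch, heads, MAX_HEADS


print(check_mla_decode_shapes((32, 16, 512), (32, 16, 64), (1000, 64, 576), "bfloat16"))
```

For a batch of 32 queries with 16 heads, the sketch prints (32, 16, 128): the inputs pass validation and would be padded from 16 to 128 heads.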

Usage

Use merge_state or merge_state_v2 when combining attention results from split-KV computations (e.g., prefix caching, tensor parallelism). Use cutlass_mla_decode for high-throughput decoding in DeepSeek V2/V3 models with MLA attention.
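Combining split-KV results relies on the standard log-sum-exp merge: with s = log(exp(s_a) + exp(s_b)), the merged value is exp(s_a - s) * v_a + exp(s_b - s) * v_b. The pure-Python sketch below applies this rule to a single query and head and checks that merging two KV splits reproduces attention over the whole KV range. It illustrates the math only and is not the kernel code; the helper names attend and merge_state_reference are hypothetical, and it assumes natural-log LSE values (any base works as long as both splits use the same one).

```python
import math
from typing import List, Tuple

Vector = List[float]


def attend(q: Vector, keys: List[Vector], values: List[Vector]) -> Tuple[Vector, float]:
    """Softmax attention of one query over a KV slice; returns (output, lse)."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(len(values[0]))]
    return out, m + math.log(total)  # lse = log(sum(exp(scores)))


def merge_state_reference(v_a: Vector, s_a: float, v_b: Vector, s_b: float) -> Tuple[Vector, float]:
    """Combine two partial attention results using their log-sum-exp states."""
    s_max = max(s_a, s_b)
    w_a = math.exp(s_a - s_max)  # subtract the max for numerical stability
    w_b = math.exp(s_b - s_max)
    s_merged = s_max + math.log(w_a + w_b)
    v_merged = [(w_a * a + w_b * b) / (w_a + w_b) for a, b in zip(v_a, v_b)]
    return v_merged, s_merged


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

full_v, full_s = attend(q, keys, values)
v_a, s_a = attend(q, keys[:2], values[:2])  # first KV split
v_b, s_b = attend(q, keys[2:], values[2:])  # second KV split
merged_v, merged_s = merge_state_reference(v_a, s_a, v_b, s_b)
print(all(math.isclose(x, y) for x, y in zip(full_v, merged_v)), math.isclose(full_s, merged_s))
```

The script prints True True: the merged values and LSE match single-pass attention, which is why each split only needs to keep its normalized output and its LSE.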

Code Reference

Source Location

Signature

def merge_state(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ...

def merge_state_v2(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ...

def cutlass_mla_decode(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    workspace: torch.Tensor,
    sm_scale: float,
    num_kv_splits: int = 1,
) -> torch.Tensor: ...

def cutlass_mla_get_workspace_size(
    max_seq_len: int,
    num_batches: int,
    sm_count: int = 0,
    num_kv_splits: int = 1,
) -> int: ...

Import

from sgl_kernel.attention import (
    cutlass_mla_decode,
    cutlass_mla_get_workspace_size,
    merge_state,
    merge_state_v2,
)

I/O Contract

Inputs

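In the Required column, (merge) refers to merge_state and merge_state_v2, (MLA) to cutlass_mla_decode, and (workspace size) to cutlass_mla_get_workspace_size.

Name | Type | Required | Description
v_a | torch.Tensor | Yes (merge) | Attention values from the first split
s_a | torch.Tensor | Yes (merge) | Log-sum-exp states from the first split
v_b | torch.Tensor | Yes (merge) | Attention values from the second split
s_b | torch.Tensor | Yes (merge) | Log-sum-exp states from the second split
v_merged | Optional[torch.Tensor] | No | Pre-allocated output for the merged values
s_merged | Optional[torch.Tensor] | No | Pre-allocated output for the merged states
q_nope | torch.Tensor | Yes (MLA) | Query without positional encoding, shape [B, H, 512]
q_pe | torch.Tensor | Yes (MLA) | Positional-encoding (RoPE) part of the query, shape [B, H, 64]
kv_c_and_k_pe_cache | torch.Tensor | Yes (MLA) | Paged KV cache, shape [num_blocks, page_size, 576] (512 latent + 64 RoPE)
seq_lens | torch.Tensor | Yes (MLA) | Sequence length of each batch entry, int32
page_table | torch.Tensor | Yes (MLA) | Block table mapping each sequence to its cache blocks, shape [B, max_blocks], int32
workspace | torch.Tensor | Yes (MLA) | Pre-allocated workspace buffer, sized with cutlass_mla_get_workspace_size
sm_scale | float | Yes (MLA) | Softmax scaling factor
num_kv_splits | int | No (default 1) | Number of KV splits (cutlass_mla_decode and cutlass_mla_get_workspace_size)
max_seq_len | int | Yes (workspace size) | Maximum sequence length
num_batches | int | Yes (workspace size) | Batch size
sm_count | int | No (default 0) | SM count used for workspace sizing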
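To make the cache layout in the table concrete, the following pure-Python sketch computes MLA decode for a single sequence from a paged cache. It assumes the weight-absorbed MLA formulation: the score for a cached token is q_nope · c + q_pe · k_pe, where c is the 512-element latent part of the cache row and k_pe the 64-element RoPE part, and the value is the latent part itself, which is why the output has 512 elements per head. The function name mla_decode_reference and the nested-list layout are hypothetical; the sketch ignores KV splits, head padding, and the workspace.

```python
import math
from typing import List

D_LATENT = 512  # latent (c) part of each cache row; also used as the value
D_ROPE = 64     # RoPE (k_pe) part of each cache row

Vector = List[float]


def mla_decode_reference(
    q_nope: List[Vector],          # [H][D_LATENT] for one sequence
    q_pe: List[Vector],            # [H][D_ROPE]
    kv_cache: List[List[Vector]],  # [num_blocks][page_size][D_LATENT + D_ROPE]
    seq_len: int,
    block_ids: List[int],          # this sequence's row of page_table
    sm_scale: float,
) -> List[Vector]:
    page_size = len(kv_cache[0])
    # Token t lives in physical block block_ids[t // page_size], slot t % page_size.
    rows = [kv_cache[block_ids[t // page_size]][t % page_size] for t in range(seq_len)]
    out = []
    for qn, qp in zip(q_nope, q_pe):
        scores = [
            sm_scale * (
                sum(a * b for a, b in zip(qn, row[:D_LATENT]))
                + sum(a * b for a, b in zip(qp, row[D_LATENT:]))
            )
            for row in rows
        ]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        # Each token's value is its latent vector, so the output is D_LATENT wide.
        out.append([
            sum(w * row[d] for w, row in zip(weights, rows)) / total
            for d in range(D_LATENT)
        ])
    return out


# page_size 2, two physical blocks; a 3-token sequence stored in blocks [1, 0].
# Every element of token t's cache row is set to t.
cache = [[[0.0] * (D_LATENT + D_ROPE) for _ in range(2)] for _ in range(2)]
for t, (blk, slot) in enumerate([(1, 0), (1, 1), (0, 0)]):
    cache[blk][slot] = [float(t)] * (D_LATENT + D_ROPE)

q_nope = [[0.0] * D_LATENT]  # one head
q_pe = [[0.0] * D_ROPE]
result = mla_decode_reference(q_nope, q_pe, cache, seq_len=3, block_ids=[1, 0], sm_scale=0.1)
print(len(result), len(result[0]), result[0][0])  # 1 512 1.0
```

With all-zero queries every token gets the same weight, so each output element is the mean of the three tokens' latent values, (0 + 1 + 2) / 3 = 1.0.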

Outputs

Name | Type | Description
v_merged | torch.Tensor | Merged attention values (merge_state, merge_state_v2)
s_merged | torch.Tensor | Merged log-sum-exp states (merge_state, merge_state_v2)
out | torch.Tensor | MLA decode output, shape [B, H, 512] (cutlass_mla_decode)
workspace_size | int | Required workspace size in bytes (cutlass_mla_get_workspace_size)

Usage Examples

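import torch
from sgl_kernel.attention import merge_state, cutlass_mla_decode, cutlass_mla_get_workspace_size

# Merge split-KV attention results.
# v_a, v_b: attention values of the two splits; s_a, s_b: their log-sum-exp states.
v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)

# MLA decode. q_nope, q_pe, kv_cache, seq_lens and page_table are prepared as
# described in the I/O contract. Size the workspace with the same
# num_kv_splits that is passed to cutlass_mla_decode.
num_kv_splits = 1
ws_size = cutlass_mla_get_workspace_size(
    max_seq_len=2048, num_batches=32, num_kv_splits=num_kv_splits
)
workspace = torch.empty(ws_size, dtype=torch.uint8, device="cuda")
out = cutlass_mla_decode(
    q_nope, q_pe, kv_cache, seq_lens,
    page_table, workspace,
    sm_scale=0.125,  # placeholder; the softmax scale depends on the model
    num_kv_splits=num_kv_splits,
)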
