Implementation:Mlc ai Mlc llm Attention Op

From Leeroopedia


Overview

The Attention Op module provides the main attention operator used during model compilation, with support for FlashInfer-accelerated execution and a TIR-based fallback. It is located at python/mlc_llm/op/attention.py (179 lines).

The attention function handles causal masked attention for both prefill (sequence-level) and decode (single-token) phases. When FlashInfer is available and the model configuration meets its requirements, the function dispatches to highly optimized external FlashInfer kernels. Otherwise, it falls back to a TVM TIR-based implementation.

Source File

  • File: python/mlc_llm/op/attention.py
  • Lines: 179
  • Module: mlc_llm.op.attention

Dependencies

  • tvm: Access to TVM target configuration
  • tvm.relax.frontend.nn: Neural network tensor types and operations
  • mlc_llm.support.logging: Logger for FlashInfer compatibility warnings
  • mlc_llm.op.extern: External module store (tracks whether FlashInfer is available)

Global Warning Flags

WARN_FLASHINFER_GROUP_SIZE = False
WARN_FLASHINFER_HEAD_DIM = False

These module-level flags ensure that each FlashInfer compatibility warning (one for an unsupported group size, one for an unsupported head dimension) is logged at most once per process, preventing repetitive log messages.

Function: attention

from tvm.relax.frontend import nn

def attention(
    q: nn.Tensor,
    k: nn.Tensor,
    v: nn.Tensor,
    casual_mask: nn.Tensor,
    attn_score_scaling_factor: float = 1.0,
    qk_dtype: str = None,
) -> nn.Tensor:
    ...  # body: FlashInfer dispatch or TIR fallback, described below

Tensor Shapes

  • q: [b, s, h_q, d]. Query tensor (batch, query sequence length, query heads, head dim).
  • k: [t, h_kv, d] or [b, t, h_kv, d]. Key tensor (total sequence length t, KV heads, head dim, optionally with a leading batch dimension).
  • v: [t, h_kv, d] or [b, t, h_kv, d]. Value tensor (same shape as k).
  • casual_mask: shape varies. Causal attention mask (can be None for full attention); the parameter is spelled casual_mask in the source.
  • output: [b, s, h_q * d], i.e. [1, s, h_q * d] for batch size 1. Attention output with the heads merged into the last dimension.
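The symbols in this list can be recovered from concrete shapes, which makes the 3D/4D K layouts and the output shape easy to check by hand. The sketch below is a hypothetical helper (`attention_shapes` is not part of mlc_llm). The divisibility check on h_q and h_kv is an assumption added for illustration.

```python
from typing import Dict, Sequence


def attention_shapes(q_shape: Sequence[int], k_shape: Sequence[int]) -> Dict[str, object]:
    """Hypothetical helper: derive b, s, t, h_q, h_kv, d and the output shape.

    q_shape is [b, s, h_q, d]; k_shape is [t, h_kv, d] or [b, t, h_kv, d].
    """
    b, s, h_q, d = q_shape
    if len(k_shape) == 3:
        t, h_kv, d_k = k_shape
    elif len(k_shape) == 4:
        _, t, h_kv, d_k = k_shape  # leading batch dimension
    else:
        raise ValueError(f"K must be 3D or 4D, got {len(k_shape)}D")
    if d_k != d:
        raise ValueError("Q and K must share the head dimension d")
    if h_q % h_kv != 0:  # assumption: GQA needs h_q to be a multiple of h_kv
        raise ValueError("h_q must be a multiple of h_kv")
    return {
        "b": b, "s": s, "t": t, "h_q": h_q, "h_kv": h_kv, "d": d,
        "group_size": h_q // h_kv,
        "output": (b, s, h_q * d),  # heads merged into the last dimension
    }
```

For example, `attention_shapes((1, 7, 32, 128), (20, 8, 128))` yields a group size of 4 and an output shape of (1, 7, 4096).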

Computation

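The computation is standard scaled dot-product attention, written here as shape-annotated pseudocode:

if h_kv != h_q:                         # GQA: expand KV heads to match Q
    k = k.repeat(h_q // h_kv, axis=1)   # [t, h_kv, d] -> [t, h_q, d]
    v = v.repeat(h_q // h_kv, axis=1)
q -> [b, h_q, s, d]                     # move the head axis before the sequence axis
k, v -> [b, h_q, t, d]
attn = q @ k^T / sqrt(d) * attn_score_scaling_factor   # [b, h_q, s, t]
attn = softmax_with_mask(attn, casual_mask, axis=-1)
output = attn @ v                       # [b, h_q, s, d]
output -> [b, s, h_q * d]               # transpose back and merge heads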
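A slow, pure-Python reference of the same math can help when checking either backend on tiny inputs. This is a sketch under stated assumptions, not mlc_llm code. It assumes batch size 1 and replaces the mask tensor with a `causal` flag. It also assumes the s queries are the last s of the t key positions, as during KV-cache decoding. Masked keys are skipped, which is equivalent to giving them a score of minus infinity before the softmax.

```python
import math
from typing import List

Tensor3 = List[List[List[float]]]  # [positions][heads][head_dim], batch size 1


def reference_attention(q: Tensor3, k: Tensor3, v: Tensor3,
                        attn_score_scaling_factor: float = 1.0,
                        causal: bool = True) -> List[List[float]]:
    """Hypothetical reference: q is [s][h_q][d]; k, v are [t][h_kv][d].

    Returns [s][h_q * d]. With causal=True, query i sees keys 0 .. i + (t - s).
    """
    s, h_q, d = len(q), len(q[0]), len(q[0][0])
    t, h_kv = len(k), len(k[0])
    if s > t or h_q % h_kv != 0:
        raise ValueError("expected s <= t and h_q divisible by h_kv")
    group_size = h_q // h_kv
    out = [[0.0] * (h_q * d) for _ in range(s)]
    for i in range(s):
        # Index of the last visible key; later keys are masked out.
        last = i + (t - s) if causal else t - 1
        for h in range(h_q):
            kv_h = h // group_size  # GQA: query head h reads KV head h // group_size
            scores = [
                sum(q[i][h][x] * k[j][kv_h][x] for x in range(d))
                / math.sqrt(d) * attn_score_scaling_factor
                for j in range(last + 1)
            ]
            m = max(scores)  # subtract the max for a numerically stable softmax
            weights = [math.exp(sc - m) for sc in scores]
            total = sum(weights)
            for x in range(d):
                out[i][h * d + x] = sum(w / total * v[j][kv_h][x]
                                        for j, w in enumerate(weights))
    return out
```

With two positions, equal scores, and values 2.0 and 4.0, the causal version returns 2.0 for the first query, which sees only itself, and 3.0 for the second, which averages both values.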

FlashInfer Path

FlashInfer is used when all of the following conditions are met:

  1. FlashInfer is available (_extern.get_store().flashinfer is truthy)
  2. attn_score_scaling_factor == 1.0
  3. All of Q, K, V have dtype "float16"
  4. The group size (h_q // h_kv) is in [1, 4, 6, 8]
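  5. The head dimension d is 128

The outer check below tests conditions 1-3. Conditions 4 and 5 are checked inside that branch; if either fails, the function logs a one-time warning and falls back to the TIR path.
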
if (
    _extern.get_store().flashinfer
    and attn_score_scaling_factor == 1.0
    and q.dtype == "float16"
    and k.dtype == "float16"
    and v.dtype == "float16"
):
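The snippet above tests only conditions 1-3. For reference, all five conditions can be folded into a single predicate. This is a hypothetical helper (`flashinfer_eligible` and its reason strings are illustrative, not mlc_llm API), sketched to make the dispatch rules testable in isolation.

```python
from typing import Optional

SUPPORTED_GROUP_SIZES = (1, 4, 6, 8)  # condition 4
SUPPORTED_HEAD_DIMS = (128,)          # condition 5


def flashinfer_eligible(flashinfer_available: bool, attn_score_scaling_factor: float,
                        q_dtype: str, k_dtype: str, v_dtype: str,
                        h_q: int, h_kv: int, d: int) -> Optional[str]:
    """Hypothetical helper: None if FlashInfer can be used, else a reason string."""
    if not flashinfer_available:
        return "flashinfer not available"
    if attn_score_scaling_factor != 1.0:
        return "attn_score_scaling_factor != 1.0"
    if not (q_dtype == k_dtype == v_dtype == "float16"):
        return "Q/K/V are not all float16"
    # In the source, the two checks below run inside the outer `if`
    # and trigger a one-time warning before falling back.
    if h_q // h_kv not in SUPPORTED_GROUP_SIZES:
        return f"group size {h_q // h_kv} unsupported"
    if d not in SUPPORTED_HEAD_DIMS:
        return f"head dim {d} unsupported"
    return None
```

For instance, 32 query heads over 8 KV heads (group size 4) with d = 128 and float16 inputs qualify, while 32 over 16 (group size 2) does not.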

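When all FlashInfer conditions are met, the function chooses decode or prefill mode from the query length s. Because s may be a symbolic dimension rather than a Python int, decode is chosen only when s is statically known to be 1: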

if isinstance(s, int) and s == 1:
    func = "decode"     # single-token decode
else:
    func = "prefill"    # multi-token prefill
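The effect of the `isinstance` guard can be shown without TVM. In the sketch below, `SymbolicDim` is a hypothetical stand-in for a symbolic shape variable, and `select_mode` simply restates the two branches above.

```python
from typing import Union


class SymbolicDim:
    """Hypothetical stand-in for a symbolic shape variable (e.g. a TIR var)."""

    def __init__(self, name: str) -> None:
        self.name = name


def select_mode(s: Union[int, SymbolicDim]) -> str:
    # Only a length statically known to be 1 selects decode; a symbolic
    # length could be larger at run time, so it must take the prefill path.
    if isinstance(s, int) and s == 1:
        return "decode"
    return "prefill"
```

A graph compiled with a symbolic sequence length therefore always uses the prefill kernel, even if some calls happen to pass a single token at run time.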

Decode mode calls flashinfer.single_decode:

return op.extern(
    name="flashinfer.single_decode",
    args=[q, k, v, scratch, qkv_layout, rotary_mode, rope_scale, rope_theta],
    out=nn.Tensor.placeholder((b, s, h_q * d), dtype="float16"),
)

Prefill mode calls flashinfer.single_prefill with additional parameters for causal masking and FP16 QK reduction:

return op.extern(
    name="flashinfer.single_prefill",
    args=[q, k, v, scratch, casual, qkv_layout, rotary_mode, fp16_qk, rope_scale, rope_theta],
    out=nn.Tensor.placeholder((b, s, h_q * d), dtype="float16"),
)

Both modes allocate a 32 MiB scratchpad buffer (8192 * 1024 float32 elements at 4 bytes each) for FlashInfer's internal workspace.
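The buffer size follows directly from the element count and the 4-byte width of float32:

```python
FLOAT32_BYTES = 4
scratch_elements = 8192 * 1024                  # float32 elements in the workspace
scratch_bytes = scratch_elements * FLOAT32_BYTES
print(scratch_elements, scratch_bytes, scratch_bytes / 2**20)  # 8388608 33554432 32.0
```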

FlashInfer Configuration Constants

  • qkv_layout = 0: "NHD" layout (sequence length, num_heads, head_dim)
  • rotary_mode = 0: "kNone", no rotary embedding applied inside attention
  • casual = 1: causal masking enabled (the identifier is spelled casual in the source)
  • fp16_qk = 1 (or 0): enables FP16 QK reduction; set to 0 when qk_dtype == "float32"
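The only constant that depends on an argument is fp16_qk. The sketch below derives it from qk_dtype and groups the integer flags in the order they appear in the flashinfer.single_prefill argument list, between scratch and rope_scale. The Python names (`QKV_LAYOUT_NHD`, `fp16_qk_flag`, `prefill_flags`) are hypothetical; the values and their order come from the table and the call shown earlier.

```python
from typing import Optional, Tuple

QKV_LAYOUT_NHD = 0    # qkv_layout: [seq_len, num_heads, head_dim]
ROTARY_MODE_NONE = 0  # rotary_mode "kNone": no rotary embedding inside the kernel
CASUAL = 1            # causal masking on ("casual" is the source's spelling)


def fp16_qk_flag(qk_dtype: Optional[str]) -> int:
    """1 enables FP16 QK reduction; qk_dtype == "float32" disables it."""
    return 0 if qk_dtype == "float32" else 1


def prefill_flags(qk_dtype: Optional[str]) -> Tuple[int, int, int, int]:
    """Hypothetical helper: (casual, qkv_layout, rotary_mode, fp16_qk)."""
    return (CASUAL, QKV_LAYOUT_NHD, ROTARY_MODE_NONE, fp16_qk_flag(qk_dtype))
```

With the default qk_dtype=None, FP16 QK reduction stays on. Passing "float32" trades some speed for accumulation accuracy in the QK product.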

Fallback Path

When FlashInfer cannot be used, the function falls back to TVM's TIR-based attention:

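def _fallback():
    # Nested inside attention(): b, s, t, h_q, h_kv and d come from the enclosing
    # scope. `nonlocal` is required because k and v are reassigned below.
    nonlocal k, v
    from tvm.relax.frontend.nn.llm.kv_cache import _attention_sequence_prefill

    # Reshape K, V from 3D [t, h_kv, d] to 4D [b, t, h_kv, d] if needed
    if k.ndim == 3:
        k = op.reshape(k, [b, t, h_kv, d])
    if v.ndim == 3:
        v = op.reshape(v, [b, t, h_kv, d])
    # Expand KV heads for GQA (the head axis is axis 2 in the 4D layout)
    if h_kv != h_q:
        k = k.repeat(h_q // h_kv, axis=2)
        v = v.repeat(h_q // h_kv, axis=2)

    # The TIR kernel is generated for a fixed CUDA target
    target = tvm.target.Target("cuda")
    attn_output, _ = op.tensor_ir_op(
        _attention_sequence_prefill(h_kv=h_kv, h_q=h_q, d=d, dtype=q.dtype, target=target),
        "sequence_prefill",
        [q, k, v],
        [
            nn.Tensor.placeholder([b, s, h_q, d], q.dtype),  # attention output
            nn.Tensor.placeholder([b, s, h_q], q.dtype),     # second output, discarded
        ],
    )
    # Merge heads: [b, s, h_q, d] -> [b, s, h_q * d]
    output = op.reshape(attn_output, shape=(b, s, h_q * d))
    return output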

The fallback:

  1. Reshapes 3D K/V tensors to 4D.
  2. Repeats KV heads for grouped-query attention (GQA) support.
  3. Calls _attention_sequence_prefill as a TIR operator targeting CUDA.
  4. Reshapes the output from [b, s, h_q, d] to [b, s, h_q * d].
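Steps 1 and 2 can be mirrored on nested lists to show which KV head each query head reads after expansion. This sketch assumes `Tensor.repeat` follows numpy.repeat semantics, repeating each head consecutively. Under that assumption, expanded head h holds original head h // group_size, the same mapping used in the GQA description. The helper names are hypothetical.

```python
from typing import List

Heads = List[List[float]]  # one position: [heads][head_dim]


def add_batch_dim(kv: List[Heads]) -> List[List[Heads]]:
    """Hypothetical helper: [t][h_kv][d] -> [1][t][h_kv][d] (the 3D-to-4D reshape with b = 1)."""
    return [kv]


def repeat_heads(kv4: List[List[Heads]], group_size: int) -> List[List[Heads]]:
    """Repeat each KV head group_size times along the head axis (axis 2 in 4D).

    Consecutive repetition, as in numpy.repeat: heads [A, B] become [A, A, B, B].
    """
    return [[[head for head in pos for _ in range(group_size)] for pos in batch]
            for batch in kv4]
```

With two KV heads and a group size of 2, query heads 0 and 1 share KV head 0, and query heads 2 and 3 share KV head 1.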

Design Notes

  • The function supports grouped-query attention (GQA) where h_q != h_kv, with the group size being h_q // h_kv.
  • The qk_dtype parameter controls numerical precision for the QK matmul: when set to "float32", FlashInfer disables FP16 QK reduction for improved accuracy.
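  • When FlashInfer is available but the group size or head dimension is unsupported, the function falls back to the TIR implementation and logs a warning only the first time each case occurs. Configurations that fail the dtype or scaling-factor check fall back without a warning.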

Categories

  • Attention Operator
  • FlashInfer Integration
  • TVM TIR
  • GPU Kernels
  • Grouped-Query Attention
