Implementation:Mlc ai Mlc llm Attention Op
Overview
The Attention Op module provides the main attention operator used during model compilation, with support for FlashInfer-accelerated execution and a TIR-based fallback. It is located at python/mlc_llm/op/attention.py (179 lines).
The attention function handles causal masked attention for both prefill (sequence-level) and decode (single-token) phases. When FlashInfer is available and the model configuration meets its requirements, the function dispatches to highly optimized external FlashInfer kernels. Otherwise, it falls back to a TVM TIR-based implementation.
Source File
- File: `python/mlc_llm/op/attention.py`
- Lines: 179
- Module: `mlc_llm.op.attention`
Dependencies
| Import | Purpose |
|---|---|
| `tvm` | Access to TVM target configuration |
| `tvm.relax.frontend.nn` | Neural network tensor types and operations |
| `mlc_llm.support.logging` | Logger for FlashInfer compatibility warnings |
| `mlc_llm.op.extern` | External module store (tracks whether FlashInfer is available) |
Global Warning Flags
```python
WARN_FLASHINFER_GROUP_SIZE = False
WARN_FLASHINFER_HEAD_DIM = False
```
These module-level flags ensure that FlashInfer compatibility warnings are emitted only once during a session, preventing repetitive log messages.
Function: attention
```python
def attention(
    q: nn.Tensor,
    k: nn.Tensor,
    v: nn.Tensor,
    casual_mask: nn.Tensor,
    attn_score_scaling_factor: float = 1.0,
    qk_dtype: str = None,
) -> nn.Tensor:
    ...
```
Tensor Shapes
| Tensor | Shape | Description |
|---|---|---|
| `q` | `[b, s, h_q, d]` | Query tensor (batch, sequence length, query heads, head dim) |
| `k` | `[t, h_kv, d]` or `[b, t, h_kv, d]` | Key tensor (total sequence length, KV heads, head dim) |
| `v` | `[t, h_kv, d]` or `[b, t, h_kv, d]` | Value tensor (same shape as `k`) |
| `casual_mask` | varies | Causal attention mask (can be `None` for full attention); `casual` is the identifier's spelling in the source |
| output | `[1, s, h_q * d]` | Reshaped attention output, i.e. `[b, s, h_q * d]` with `b = 1` |
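The dimension names in the table (`b`, `s`, `t`, `h_q`, `h_kv`, `d`) are read off the input shapes. A minimal pure-Python sketch of that bookkeeping, using a hypothetical helper `unpack_attention_shapes` that works on plain shape tuples rather than `nn.Tensor` objects; the consistency checks are assumptions about what a caller would want, not checks confirmed in `attention.py`.

```python
from typing import Sequence, Tuple


def unpack_attention_shapes(
    q_shape: Sequence[int], k_shape: Sequence[int], v_shape: Sequence[int]
) -> Tuple[int, int, int, int, int, int]:
    """Return (b, s, t, h_q, h_kv, d) from q/k/v shapes (hypothetical helper)."""
    b, s, h_q, d = q_shape
    if tuple(k_shape) != tuple(v_shape):
        raise ValueError("k and v must have the same shape")
    if len(k_shape) == 3:  # [t, h_kv, d]
        t, h_kv, d_kv = k_shape
    elif len(k_shape) == 4:  # [b, t, h_kv, d]
        b_kv, t, h_kv, d_kv = k_shape
        if b_kv != b:
            raise ValueError("q and k/v batch sizes differ")
    else:
        raise ValueError("k/v must be 3D or 4D")
    if d_kv != d:
        raise ValueError("q and k/v head dims differ")
    if h_q % h_kv != 0:
        raise ValueError("h_q must be a multiple of h_kv")
    return b, s, t, h_q, h_kv, d
```

For example, a prefill of 5 new tokens over a 12-token cache with 32 query heads and 8 KV heads yields `(1, 5, 12, 32, 8, 128)`.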
Computation
The attention computation follows the standard scaled dot-product attention:
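```text
if h_kv != h_q:
    k = k.repeat(h_q // h_kv, axis=1)  # expand KV heads for GQA
    v = v.repeat(h_q // h_kv, axis=1)
q -> [b, h_q, s, d]                    # move the head axis before the sequence axis
k, v -> [b, h_q, t, d]
attn = q @ k^T / sqrt(d) * attn_score_scaling_factor  # [b, h_q, s, t]
attn = softmax_with_mask(attn, casual_mask, axis=-1)
output = attn @ v                      # [b, h_q, s, d]
output -> [b, s, h_q * d]              # transpose back, then reshape
```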
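To make the semantics of the pseudocode concrete, here is a pure-Python reference sketch. It assumes batch size 1, that the `s` queries occupy the last `s` of the `t` key positions (the usual KV-cache layout), and that the causal mask is expressed as a boolean flag rather than a `casual_mask` tensor. `reference_attention` is a hypothetical name; this is a readability aid, not the TIR or FlashInfer kernel.

```python
import math
from typing import List

Tensor3 = List[List[List[float]]]  # nested lists indexed [position][head][dim]


def reference_attention(
    q: Tensor3,
    k: Tensor3,
    v: Tensor3,
    causal: bool = True,
    attn_score_scaling_factor: float = 1.0,
) -> List[List[float]]:
    """Scaled dot-product attention with GQA and a causal mask (batch size 1).

    q: [s][h_q][d], k and v: [t][h_kv][d]. Returns [s][h_q * d].
    """
    s, h_q, d = len(q), len(q[0]), len(q[0][0])
    t, h_kv = len(k), len(k[0])
    group = h_q // h_kv  # GQA: query head h reads KV head h // group
    scale = attn_score_scaling_factor / math.sqrt(d)
    out = []
    for i in range(s):
        # Assumption: query i sits at absolute position t - s + i,
        # so it may attend to keys 0 .. t - s + i.
        last_visible = t - s + i
        row: List[float] = []
        for h in range(h_q):
            kv_h = h // group
            scores = []
            for j in range(t):
                if causal and j > last_visible:
                    scores.append(-math.inf)  # masked position
                else:
                    dot = sum(q[i][h][x] * k[j][kv_h][x] for x in range(d))
                    scores.append(dot * scale)
            # Numerically stable softmax; math.exp(-inf) evaluates to 0.0.
            m = max(scores)
            exps = [math.exp(sc - m) for sc in scores]
            z = sum(exps)
            probs = [e / z for e in exps]
            # Weighted sum of values, appended head by head -> [h_q * d] layout.
            for x in range(d):
                row.append(sum(probs[j] * v[j][kv_h][x] for j in range(t)))
        out.append(row)
    return out
```

Appending each head's `d` values in head order reproduces the final `[s, h_q, d] -> [s, h_q * d]` reshape.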
FlashInfer Path
FlashInfer is used when all of the following conditions are met:
- FlashInfer is available (`_extern.get_store().flashinfer` is truthy)
- `attn_score_scaling_factor == 1.0`
- All of Q, K, V have dtype `"float16"`
- The group size (`h_q // h_kv`) is in `[1, 4, 6, 8]`
- The head dimension `d` is 128
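```python
if (
    _extern.get_store().flashinfer
    and attn_score_scaling_factor == 1.0
    and q.dtype == "float16"
    and k.dtype == "float16"
    and v.dtype == "float16"
):
    # The group-size and head-dim conditions are not part of this expression;
    # they are checked separately, and an unsupported value logs a one-time
    # warning before falling back to the TIR implementation.
    ...
```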
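When the FlashInfer conditions are met, the function selects between decode and prefill modes. Decode is chosen only when the query length `s` is a Python `int` equal to 1; a symbolic sequence length always takes the prefill path:

```python
if isinstance(s, int) and s == 1:
    func = "decode"  # single-token decode
else:
    func = "prefill"  # multi-token prefill
```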
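The eligibility checks and the mode selection can be condensed into one pure function. The sketch below uses a hypothetical helper, `select_attention_path`, that works on plain Python values instead of `nn.Tensor`s; a string stands in for a symbolic sequence length, and the one-time warnings are omitted.

```python
from typing import Tuple, Union

SUPPORTED_GROUP_SIZES = (1, 4, 6, 8)
SUPPORTED_HEAD_DIM = 128


def select_attention_path(
    flashinfer_available: bool,
    attn_score_scaling_factor: float,
    dtypes: Tuple[str, str, str],  # (q.dtype, k.dtype, v.dtype)
    h_q: int,
    h_kv: int,
    d: int,
    s: Union[int, str],  # a str stands in for a symbolic sequence length
) -> str:
    """Return "decode", "prefill" or "fallback" (hypothetical helper)."""
    if not (
        flashinfer_available
        and attn_score_scaling_factor == 1.0
        and all(dt == "float16" for dt in dtypes)
    ):
        return "fallback"
    if h_q // h_kv not in SUPPORTED_GROUP_SIZES or d != SUPPORTED_HEAD_DIM:
        return "fallback"
    # Decode only when the sequence length is statically known to be 1.
    if isinstance(s, int) and s == 1:
        return "decode"
    return "prefill"
```

Keeping the predicate free of tensors makes it easy to check which configurations reach FlashInfer without compiling a model.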
Decode mode calls `flashinfer.single_decode`:

```python
# Arguments other than q, k and v come from the enclosing attention() scope.
return op.extern(
    name="flashinfer.single_decode",
    args=[q, k, v, scratch, qkv_layout, rotary_mode, rope_scale, rope_theta],
    out=nn.Tensor.placeholder((b, s, h_q * d), dtype="float16"),
)
```
Prefill mode calls `flashinfer.single_prefill` with additional parameters for causal masking and FP16 QK reduction:

```python
# casual enables causal masking; fp16_qk toggles FP16 QK reduction.
return op.extern(
    name="flashinfer.single_prefill",
    args=[q, k, v, scratch, casual, qkv_layout, rotary_mode, fp16_qk, rope_scale, rope_theta],
    out=nn.Tensor.placeholder((b, s, h_q * d), dtype="float16"),
)
```
Both modes use a 32 MiB scratch buffer (`8192 * 1024` float32 elements at 4 bytes each) as FlashInfer's internal workspace.
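The 32 MiB figure follows from the element count and the 4-byte size of a float32 value; a quick check:

```python
FLOAT32_BYTES = 4  # size of one float32 element

scratch_elems = 8192 * 1024  # element count of the scratch buffer
scratch_bytes = scratch_elems * FLOAT32_BYTES
scratch_mib = scratch_bytes / (1024 * 1024)

print(scratch_elems, scratch_bytes, scratch_mib)  # 8388608 33554432 32.0
```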
FlashInfer Configuration Constants
| Constant | Value | Meaning |
|---|---|---|
| `qkv_layout` | `0` | `"NHD"` layout: sequence length, num_heads, head_dim |
| `rotary_mode` | `0` | `"kNone"`: no rotary embedding applied inside attention |
| `casual` | `1` | Causal masking enabled |
| `fp16_qk` | `1` (or `0`) | Enable FP16 QK reduction; disabled when `qk_dtype == "float32"` |
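The flag values in the table can be gathered in one place. The helper `flashinfer_prefill_flags` and its dictionary return type are illustrative assumptions; `rope_scale` and `rope_theta` are omitted because their values are not given on this page.

```python
from typing import Dict, Optional

QKV_LAYOUT_NHD = 0  # "NHD": sequence length, num_heads, head_dim
ROTARY_MODE_NONE = 0  # "kNone": no rotary embedding inside the kernel


def flashinfer_prefill_flags(qk_dtype: Optional[str] = None) -> Dict[str, int]:
    """Collect the integer flags passed to the prefill kernel (hypothetical helper)."""
    return {
        "qkv_layout": QKV_LAYOUT_NHD,
        "rotary_mode": ROTARY_MODE_NONE,
        "casual": 1,  # causal masking on (identifier spelled as in the source)
        # FP16 QK reduction stays on unless float32 QK precision is requested.
        "fp16_qk": 0 if qk_dtype == "float32" else 1,
    }
```

Only `qk_dtype == "float32"` turns FP16 QK reduction off; `None` and any other dtype string leave it enabled.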
Fallback Path
When FlashInfer cannot be used, the function falls back to TVM's TIR-based attention:
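```python
def _fallback():
    from tvm.relax.frontend.nn.llm.kv_cache import _attention_sequence_prefill

    # k and v are reassigned below, so they must be declared nonlocal;
    # otherwise Python treats them as locals and raises UnboundLocalError.
    nonlocal k, v
    # b, s, t, h_q, h_kv and d are unpacked from the input shapes in attention().
    # Reshape K, V from 3D [t, h_kv, d] to 4D [b, t, h_kv, d] if needed
    if k.ndim == 3:
        k = op.reshape(k, [b, t, h_kv, d])
    if v.ndim == 3:
        v = op.reshape(v, [b, t, h_kv, d])
    # Expand KV heads for GQA (the head axis is 2 in the 4D layout)
    if h_kv != h_q:
        k = k.repeat(h_q // h_kv, axis=2)
        v = v.repeat(h_q // h_kv, axis=2)
    target = tvm.target.Target("cuda")  # the fallback kernel is built for CUDA
    attn_output, _ = op.tensor_ir_op(
        _attention_sequence_prefill(h_kv=h_kv, h_q=h_q, d=d, dtype=q.dtype, target=target),
        "sequence_prefill",
        [q, k, v],
        [
            nn.Tensor.placeholder([b, s, h_q, d], q.dtype),  # attention output
            nn.Tensor.placeholder([b, s, h_q], q.dtype),  # auxiliary output, discarded
        ],
    )
    output = op.reshape(attn_output, shape=(b, s, h_q * d))
    return output
```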
The fallback:
- Reshapes 3D K/V tensors to 4D.
- Repeats KV heads for grouped-query attention (GQA) support.
- Calls `_attention_sequence_prefill` as a TIR operator targeting CUDA.
- Reshapes the output from `[b, s, h_q, d]` to `[b, s, h_q * d]`.
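The shape changes along the fallback path can be traced without TVM. The sketch below uses a hypothetical helper, `fallback_shapes`, that tracks shapes only; the explicit `b == 1` check reflects the fact that reshaping `[t, h_kv, d]` to `[b, t, h_kv, d]` preserves the element count only when `b` is 1.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def fallback_shapes(
    b: int, s: int, t: int, h_q: int, h_kv: int, d: int, kv_ndim: int = 3
) -> Dict[str, Shape]:
    """Track tensor shapes through the fallback path (hypothetical helper)."""
    if h_q % h_kv != 0:
        raise ValueError("h_q must be a multiple of h_kv for GQA")
    if kv_ndim == 3 and b != 1:
        # [t, h_kv, d] -> [b, t, h_kv, d] keeps the element count only if b == 1.
        raise ValueError("3D K/V can only be reshaped to 4D when b == 1")
    kv: Shape = (b, t, h_kv, d)
    if h_kv != h_q:
        kv = (b, t, h_q, d)  # repeat(h_q // h_kv, axis=2) on the head axis
    return {
        "kv": kv,
        "attn_output": (b, s, h_q, d),
        "aux_output": (b, s, h_q),  # second kernel output, discarded
        "output": (b, s, h_q * d),
    }
```

For a 5-token prefill over 12 cached positions with 32 query heads, 8 KV heads and head dim 128, K/V end up as `(1, 12, 32, 128)` and the output as `(1, 5, 4096)`.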
Design Notes
- The function supports grouped-query attention (GQA), where `h_q != h_kv`, with the group size being `h_q // h_kv`.
- The `qk_dtype` parameter controls numerical precision for the QK matmul: when set to `"float32"`, FlashInfer disables FP16 QK reduction for improved accuracy.
- Configurations that FlashInfer does not support fall back to the TIR implementation; an unsupported group size or head dimension logs a warning only the first time it occurs.
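For GQA, query head `h` must read KV head `h // (h_q // h_kv)`. Repeating each KV head in place (as `numpy.repeat` does) preserves that mapping, whereas tiling the whole head list would pair query heads with the wrong KV heads. This sketch assumes `Tensor.repeat` follows `numpy.repeat` semantics and uses a hypothetical `repeat_heads` helper on plain lists.

```python
from typing import List, TypeVar

T = TypeVar("T")


def repeat_heads(kv_heads: List[T], group: int) -> List[T]:
    """Repeat each KV head `group` times in place (numpy.repeat-style), not tile."""
    return [head for head in kv_heads for _ in range(group)]
```

With two KV heads and a group size of 2, the expanded list is `["kv0", "kv0", "kv1", "kv1"]`, so query heads 0 and 1 share `kv0` and query heads 2 and 3 share `kv1`.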
Categories
- Attention Operator
- FlashInfer Integration
- TVM TIR
- GPU Kernels
- Grouped-Query Attention