Implementation:Mlc ai Mlc llm Triton Ops
| Knowledge Sources | |
|---|---|
| Domains | Triton Kernels, FP8 Quantization, MoE GEMM, LLM Operators |
| Last Updated | 2026-02-09 19:00 GMT |
Overview
Triton-based GPU kernel operators for FP8 block-scale quantized matrix multiplication, including both standard GEMM and group GEMM for Mixture of Experts (MoE) architectures.
Description
The triton module provides high-performance Triton GPU kernels integrated with TVM's compilation framework for FP8 quantized inference. The module contains:
Triton Kernel Functions:
- _get_triton_w8a8_block_fp8_gemm: A Triton kernel for block-scale FP8 matrix multiplication (A * B^T with per-block scales). Adapted from the SGLang project. Implements tiled GEMM with per-block FP8 dequantization using activation scales (As) and weight scales (Bs). Uses L2 cache-friendly grouped ordering of thread blocks.
- _get_triton_w8a8_block_fp8_group_gemm: A Triton kernel for MoE group GEMM, where different token groups are multiplied by different expert weight matrices. Uses expert_ids and indptr arrays to map thread blocks to experts. Supports variable-length token groups per expert.
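The per-block dequantization described above can be illustrated with a small pure-Python reference. This is only a sketch of the arithmetic, not the Triton kernel: plain floats stand in for FP8 values, `block_scaled_gemm_ref` is a hypothetical name, and it assumes `k` is divisible by `block_k` and `n` by `block_n`.

```python
from typing import List

Matrix = List[List[float]]


def block_scaled_gemm_ref(
    a: Matrix,
    a_scale: Matrix,
    b: Matrix,
    b_scale: Matrix,
    block_n: int,
    block_k: int,
) -> Matrix:
    """Reference for out = dequant(A) @ dequant(B)^T with per-block scales.

    a:       [m, k] quantized activations (floats stand in for FP8)
    a_scale: [m, k // block_k], one scale per row and K-block
    b:       [n, k] quantized weights
    b_scale: [n // block_n, k // block_k], one scale per (N-block, K-block)
    """
    m, k, n = len(a), len(a[0]), len(b)
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for kb in range(k // block_k):
                # Raw dot product restricted to one K-block.
                partial = sum(
                    a[i][kk] * b[j][kk]
                    for kk in range(kb * block_k, (kb + 1) * block_k)
                )
                # Apply the activation and weight scales once per block.
                acc += partial * a_scale[i][kb] * b_scale[j // block_n][kb]
            out[i][j] = acc
    return out


result = block_scaled_gemm_ref(
    a=[[1.0, 2.0, 3.0, 4.0]],
    a_scale=[[0.5, 2.0]],
    b=[[1.0, 1.0, 1.0, 1.0], [2.0, 0.0, 0.0, 1.0]],
    b_scale=[[1.0, 0.25]],
    block_n=2,
    block_k=2,
)
```

A reference like this is mainly useful for checking a kernel's output on small shapes; the group GEMM variant would apply the same arithmetic per expert, using that expert's weight and scale slices.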
TIR Wrapper Functions:
- get_tir_w8a8_block_fp8_matmul: Wraps the standard FP8 GEMM Triton kernel into a TVM TIR PrimFunc. Manages external module caching to avoid duplicate compilation.
- get_tir_w8a8_block_fp8_group_matmul: Wraps the MoE group GEMM Triton kernel into a TVM TIR PrimFunc. Also manages external module caching.
Helper Functions:
- _compute_expert_id_per_block: A TIR kernel that assigns expert IDs to each thread block based on the indptr tensor for group GEMM. Handles padding with -1 sentinel values for unused blocks.
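The block-to-expert assignment can be sketched on the host in plain Python. The actual helper is a TIR kernel, so the function below (`compute_expert_id_per_block_ref`, a hypothetical name) only illustrates the assumed logic: each expert receives enough row blocks of size `block_size_m` to cover its tokens, and any remaining blocks are filled with the -1 sentinel.

```python
from typing import List


def compute_expert_id_per_block_ref(
    indptr: List[int], block_size_m: int, num_blocks: int
) -> List[int]:
    """Assign an expert ID to each row block; unused blocks get -1."""
    expert_ids: List[int] = []
    for expert in range(len(indptr) - 1):
        tokens = indptr[expert + 1] - indptr[expert]
        # Ceiling division: blocks needed to cover this expert's tokens.
        blocks = -(-tokens // block_size_m)
        expert_ids.extend([expert] * blocks)
    if len(expert_ids) > num_blocks:
        raise ValueError("num_blocks is too small for the given indptr")
    # Pad the tail with the sentinel so surplus thread blocks can exit early.
    expert_ids.extend([-1] * (num_blocks - len(expert_ids)))
    return expert_ids


ids = compute_expert_id_per_block_ref([0, 3, 3, 10], block_size_m=4, num_blocks=5)
```

Here expert 0 owns 3 tokens (one block), expert 1 owns none, and expert 2 owns 7 tokens (two blocks), leaving two padded blocks.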
High-Level Operators:
- fp8_groupwise_scaled_gemm: The user-facing operator for block-scale FP8 GEMM. Accepts input tensors with scales and dispatches to the Triton kernel via nn.extern.
- fp8_groupwise_scaled_group_gemm: The user-facing operator for MoE block-scale FP8 group GEMM. Computes expert IDs per block and dispatches to the Triton group GEMM kernel.
The Triton import is optional; the module gracefully handles missing Triton installations by setting the module references to None.
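A common way to make such an import optional is a guarded import that falls back to None. The snippet below is a minimal sketch of that pattern, not a copy of the module's code; the `TRITON_AVAILABLE` flag is a hypothetical convenience added for illustration.

```python
# Guarded import: if Triton is not installed, keep the names defined as None
# so that importing this module never fails.
try:
    import triton
    import triton.language as tl
except ImportError:
    triton = None
    tl = None

# Hypothetical flag that callers could check before requesting a Triton kernel.
TRITON_AVAILABLE = triton is not None
```

Code that builds kernels can then check the flag (or compare against None) and raise a clear error or choose another backend when Triton is absent.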
Usage
Use these operators for FP8 block-scale quantized inference in MLC LLM, particularly when the CUTLASS backend is unavailable. They serve as the Triton-based alternative for both standard linear layers and MoE expert layers that use FP8 quantization. The high-level functions (fp8_groupwise_scaled_gemm and fp8_groupwise_scaled_group_gemm) are called from the block-scale quantization module as a fallback path.
Code Reference
Source Location
- Repository: Mlc_ai_Mlc_llm
- File: python/mlc_llm/op/triton.py
Signature
def fp8_groupwise_scaled_gemm(
    x: nn.Tensor,
    x_scale: nn.Tensor,
    weight: nn.Tensor,
    weight_scale: nn.Tensor,
    block_size: Tuple[int, int],
    out_dtype: str,
) -> nn.Tensor

def fp8_groupwise_scaled_group_gemm(
    x: nn.Tensor,
    x_scale: nn.Tensor,
    weight: nn.Tensor,
    weight_scale: nn.Tensor,
    indptr: nn.Tensor,
    block_size: Tuple[int, int],
    out_dtype: str,
) -> nn.Tensor

def get_tir_w8a8_block_fp8_matmul(
    N: int, K: int, block_n: int, block_k: int,
    in_dtype: Literal["float8_e4m3fn"],
    out_dtype: Literal["float16", "bfloat16"],
    BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int,
    GROUP_SIZE_M: int, num_warps: int, num_stages: int,
    extern_mods: List[tvm.runtime.Module],
) -> Tuple[Optional[tvm.tir.PrimFunc], str]

def get_tir_w8a8_block_fp8_group_matmul(
    N: int, K: int, num_experts: int, block_n: int, block_k: int,
    in_dtype: Literal["float8_e4m3fn"],
    out_dtype: Literal["float16", "bfloat16"],
    BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int,
    GROUP_SIZE_M: int, num_warps: int, num_stages: int,
    extern_mods: List[tvm.runtime.Module],
) -> Tuple[Optional[tvm.tir.PrimFunc], str]
Import
from mlc_llm.op.triton import (
    fp8_groupwise_scaled_gemm,
    fp8_groupwise_scaled_group_gemm,
)
I/O Contract
fp8_groupwise_scaled_gemm
| Parameter | Type | Description |
|---|---|---|
| x | nn.Tensor | Input tensor [m, k], dtype float8_e4m3fn |
| x_scale | nn.Tensor | Activation scale [m, k // block_size[1]], dtype float32 |
| weight | nn.Tensor | Weight tensor [n, k], dtype float8_e4m3fn |
| weight_scale | nn.Tensor | Weight scale [n // block_size[0], k // block_size[1]], dtype float32 |
| block_size | Tuple[int, int] | Block dimensions for quantization (block_n, block_k) |
| out_dtype | str | Output data type ("float16" or "bfloat16") |
| Return | Type | Description |
|---|---|---|
| out | nn.Tensor | Output tensor [m, n], dtype out_dtype |
fp8_groupwise_scaled_group_gemm
| Parameter | Type | Description |
|---|---|---|
| x | nn.Tensor | Input tensor [m, k], dtype float8_e4m3fn |
| x_scale | nn.Tensor | Activation scale [m, k // block_size[1]], dtype float32 |
| weight | nn.Tensor | Expert weight tensor [num_experts, n, k], dtype float8_e4m3fn |
| weight_scale | nn.Tensor | Weight scale [num_experts, n // block_size[0], k // block_size[1]], dtype float32 |
| indptr | nn.Tensor | Expert boundary pointers [num_experts + 1], int32 |
| block_size | Tuple[int, int] | Block dimensions for quantization (block_n, block_k) |
| out_dtype | str | Output data type ("float16" or "bfloat16") |
| Return | Type | Description |
|---|---|---|
| out | nn.Tensor | Output tensor [m, n], dtype out_dtype |
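The shape relationships in the tables above can be collected into a small checker. This is an illustrative sketch: `expected_shapes` is a hypothetical helper that is not part of the module, and rejecting dimensions that are not multiples of the block size is an assumption made here to keep the floor divisions exact.

```python
from typing import Dict, Optional, Tuple


def expected_shapes(
    m: int,
    n: int,
    k: int,
    block_size: Tuple[int, int],
    num_experts: Optional[int] = None,
) -> Dict[str, Tuple[int, ...]]:
    """Return the tensor shapes implied by the I/O contract.

    Pass num_experts to get the group GEMM (MoE) variant.
    """
    block_n, block_k = block_size
    if n % block_n != 0 or k % block_k != 0:
        # Assumption of this sketch: n and k are multiples of the block size.
        raise ValueError("n and k must be divisible by the block size")
    lead: Tuple[int, ...] = () if num_experts is None else (num_experts,)
    shapes: Dict[str, Tuple[int, ...]] = {
        "x": (m, k),
        "x_scale": (m, k // block_k),
        "weight": lead + (n, k),
        "weight_scale": lead + (n // block_n, k // block_k),
        "out": (m, n),
    }
    if num_experts is not None:
        shapes["indptr"] = (num_experts + 1,)
    return shapes


dense = expected_shapes(4, 256, 512, (128, 128))
moe = expected_shapes(4, 256, 512, (128, 128), num_experts=8)
```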
Triton Kernel Configuration
| Parameter | Default Value | Description |
|---|---|---|
| BLOCK_SIZE_M | 64 | Tile size along the M (row/token) dimension of x |
| BLOCK_SIZE_N | block_size[0] | Tile size along output dimension |
| BLOCK_SIZE_K | block_size[1] | Tile size along reduction dimension |
| GROUP_SIZE_M | 32 | Thread block grouping for L2 cache optimization |
| num_warps | 4 | Number of warps per thread block |
| num_stages | 3 | Number of pipeline stages |
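The role of GROUP_SIZE_M can be illustrated with the grouped program ordering used in the Triton matrix multiplication tutorial. The sketch below assumes the kernels follow that scheme; `grouped_pid_to_tile` is a hypothetical helper that maps a linear program ID to an output tile so that consecutive programs work on neighbouring tiles and reuse data held in L2.

```python
from typing import Tuple


def grouped_pid_to_tile(
    pid: int, num_pid_m: int, num_pid_n: int, group_size_m: int
) -> Tuple[int, int]:
    """Map a linear program ID to (pid_m, pid_n) using grouped ordering."""
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    # The last group may contain fewer than group_size_m rows of tiles.
    rows_in_group = min(num_pid_m - first_pid_m, group_size_m)
    pid_in_group = pid % num_pid_in_group
    # Within a group, walk down the rows first, then move to the next column.
    pid_m = first_pid_m + pid_in_group % rows_in_group
    pid_n = pid_in_group // rows_in_group
    return pid_m, pid_n


tiles = [grouped_pid_to_tile(p, num_pid_m=4, num_pid_n=3, group_size_m=2) for p in range(12)]
```

With a 4 x 3 tile grid and a group size of 2, the first four programs cover a 2 x 2 patch of tiles instead of a full row, which is what keeps the operands they share resident in cache.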
Usage Examples
from mlc_llm.op.triton import fp8_groupwise_scaled_gemm, fp8_groupwise_scaled_group_gemm

# Standard FP8 block-scale GEMM
output = fp8_groupwise_scaled_gemm(
    x=x_fp8,               # [m, k] float8_e4m3fn
    x_scale=x_scale,       # [m, k // 128] float32
    weight=weight_fp8,     # [n, k] float8_e4m3fn
    weight_scale=w_scale,  # [n // 128, k // 128] float32
    block_size=(128, 128),
    out_dtype="bfloat16",
)

# MoE group GEMM with expert routing
output = fp8_groupwise_scaled_group_gemm(
    x=x_fp8,                # [m, k] float8_e4m3fn
    x_scale=x_scale,        # [m, k // 128] float32
    weight=expert_weights,  # [num_experts, n, k] float8_e4m3fn
    weight_scale=w_scale,   # [num_experts, n // 128, k // 128] float32
    indptr=indptr,          # [num_experts + 1] int32
    block_size=(128, 128),
    out_dtype="bfloat16",
)
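For the group GEMM call, the values of indptr can be derived from per-expert token counts as a prefix sum. The sketch below works on plain Python lists and assumes the rows of x are already sorted by expert, so that expert e owns rows indptr[e] through indptr[e + 1] - 1; `build_indptr` is a hypothetical helper, and converting the result to an int32 nn.Tensor is left out.

```python
from itertools import accumulate
from typing import List


def build_indptr(tokens_per_expert: List[int]) -> List[int]:
    """Prefix sum with a leading zero: length is num_experts + 1."""
    return [0] + list(accumulate(tokens_per_expert))


# Three experts receiving 3, 0 and 7 tokens respectively.
indptr_values = build_indptr([3, 0, 7])
```

An expert with no tokens simply yields two equal consecutive entries, which is how variable-length groups (including empty ones) are represented.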