

Implementation:Mlc ai Mlc llm Triton Ops

From Leeroopedia


Knowledge Sources
Domains: Triton Kernels, FP8 Quantization, MoE GEMM, LLM Operators
Last Updated: 2026-02-09 19:00 GMT

Overview

Triton-based GPU kernel operators for FP8 block-scale quantized matrix multiplication, including both standard GEMM and group GEMM for Mixture of Experts (MoE) architectures.

Description

The mlc_llm.op.triton module integrates Triton GPU kernels for FP8 quantized inference into TVM's compilation framework. It contains:

Triton Kernel Functions:

  • _get_triton_w8a8_block_fp8_gemm: A Triton kernel for block-scale FP8 matrix multiplication, C = A · B^T with per-block scales. Adapted from the SGLang project, it computes the output tile by tile and dequantizes per block, applying the activation scales (As) and weight scales (Bs) as it accumulates along K. Thread blocks are launched in an L2-cache-friendly grouped order.
  • _get_triton_w8a8_block_fp8_group_gemm: A Triton kernel for MoE group GEMM, in which each contiguous group of tokens is multiplied by its own expert's weight matrix. It uses the expert_ids and indptr arrays to map each thread block to an expert and supports a different number of tokens per expert.
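The arithmetic of the block-scale kernel can be checked against a small NumPy reference. This is a minimal sketch, not the Triton kernel: it assumes the FP8 values are already held as float32 (FP8 rounding is not modeled), assumes N and K are exact multiples of block_n and block_k, and the name block_fp8_gemm_reference is hypothetical. It walks K one block at a time and scales each partial dot product by the row's activation scale and the tile's weight scale.

```python
import numpy as np


def block_fp8_gemm_reference(a, a_scale, b, b_scale, block_n, block_k):
    """Reference for C = (A * As) @ (B * Bs)^T with per-block scales.

    a:       [M, K] quantized activations (float32 stand-in for FP8 values)
    a_scale: [M, K // block_k], one scale per row per K block
    b:       [N, K] quantized weights (float32 stand-in for FP8 values)
    b_scale: [N // block_n, K // block_k], one scale per (N block, K block) tile
    """
    M, K = a.shape
    N = b.shape[0]
    acc = np.zeros((M, N), dtype=np.float32)
    # Walk the reduction dimension one K block at a time.
    for kb in range(K // block_k):
        ks = slice(kb * block_k, (kb + 1) * block_k)
        partial = a[:, ks] @ b[:, ks].T                          # [M, N] raw product
        row_scale = a_scale[:, kb][:, None]                      # [M, 1]
        col_scale = np.repeat(b_scale[:, kb], block_n)[None, :N]  # [1, N]
        acc += partial * row_scale * col_scale
    return acc
```

Because every scale is constant within one K block, it factors out of that block's dot product, so the inner product can run on raw FP8 values with the scales applied once per block.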

TIR Wrapper Functions:

  • get_tir_w8a8_block_fp8_matmul: Wraps the standard FP8 GEMM Triton kernel into a TVM TIR PrimFunc. Manages external module caching to avoid duplicate compilation.
  • get_tir_w8a8_block_fp8_group_matmul: Wraps the MoE group GEMM Triton kernel into a TVM TIR PrimFunc. Also manages external module caching.
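The caching idea can be illustrated with a hypothetical helper that registers a compiled kernel only once per name. CompiledKernel, kernel_name, and register_kernel are illustrative stand-ins, not the MLC LLM API; the sketch only assumes that each kernel specialization is identified by its configuration and that extern_mods collects the compiled modules.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class CompiledKernel:
    """Hypothetical stand-in for a compiled tvm.runtime.Module."""
    name: str


def kernel_name(prefix: str, **config) -> str:
    # Encode every specialization parameter so distinct configs get distinct names.
    return prefix + "_" + "_".join(f"{key}{value}" for key, value in sorted(config.items()))


def register_kernel(
    name: str,
    compile_fn: Callable[[str], CompiledKernel],
    extern_mods: List[CompiledKernel],
) -> str:
    # Skip compilation when a module with this kernel name is already registered.
    if any(mod.name == name for mod in extern_mods):
        return name
    extern_mods.append(compile_fn(name))
    return name
```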

Helper Functions:

  • _compute_expert_id_per_block: A TIR kernel that assigns expert IDs to each thread block based on the indptr tensor for group GEMM. Handles padding with -1 sentinel values for unused blocks.
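The block-to-expert assignment can be sketched in plain Python. This illustrates the idea under stated assumptions and is not the TIR kernel: it assumes the rows of x are grouped by expert according to indptr, that each expert's rows are tiled separately in chunks of BLOCK_SIZE_M, and that the launch grid uses the conservative bound ceil(m / BLOCK_SIZE_M) + num_experts. All function names are hypothetical.

```python
from typing import List


def cdiv(a: int, b: int) -> int:
    """Ceiling division for non-negative integers."""
    return (a + b - 1) // b


def max_num_blocks(m: int, block_m: int, num_experts: int) -> int:
    # Tiling each expert separately adds at most one partial tile per expert,
    # so this bound never undercounts, whatever the routing.
    return cdiv(m, block_m) + num_experts


def compute_expert_id_per_block(indptr: List[int], block_m: int, num_blocks: int) -> List[int]:
    """Assign an expert ID to each M tile and pad unused tiles with -1.

    Rows indptr[e] .. indptr[e + 1] - 1 of x are assumed to belong to expert e.
    """
    expert_ids: List[int] = []
    for e in range(len(indptr) - 1):
        rows = indptr[e + 1] - indptr[e]
        # An expert with no rows gets no tiles; a tile never spans two experts.
        expert_ids.extend([e] * cdiv(rows, block_m))
    if len(expert_ids) > num_blocks:
        raise ValueError("num_blocks is too small for this routing")
    return expert_ids + [-1] * (num_blocks - len(expert_ids))
```

Padding with -1 lets the grid size depend only on m and the number of experts rather than on the routing; a block whose entry is -1 has no expert to process.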

High-Level Operators:

  • fp8_groupwise_scaled_gemm: The user-facing operator for block-scale FP8 GEMM. Accepts input tensors with scales and dispatches to the Triton kernel via nn.extern.
  • fp8_groupwise_scaled_group_gemm: The user-facing operator for MoE block-scale FP8 group GEMM. Computes expert IDs per block and dispatches to the Triton group GEMM kernel.

The Triton import is optional: if Triton is not installed, the module sets its Triton module references to None instead of failing at import time.
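An optional-import pattern of this kind might look like the following sketch. The exact names the module binds are not given here, so triton, tl (the conventional alias for triton.language), and the helper triton_available are assumptions.

```python
# Hypothetical sketch of an optional Triton dependency.
try:
    import triton
    import triton.language as tl
except ImportError:
    # Triton is not installed: keep the names defined so the module still imports.
    triton = None
    tl = None


def triton_available() -> bool:
    """Return True when Triton-backed operators can be built."""
    return triton is not None
```

Callers can check the reference before taking the Triton path, so importing the module does not fail on machines without Triton.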

Usage

Use these operators for FP8 block-scale quantized inference in MLC LLM, particularly when the CUTLASS backend is unavailable. They serve as the Triton-based alternative for both standard linear layers and MoE expert layers that use FP8 quantization. The high-level functions (fp8_groupwise_scaled_gemm and fp8_groupwise_scaled_group_gemm) are called from the block-scale quantization module as a fallback path.

Code Reference

Source Location

Signature

def fp8_groupwise_scaled_gemm(
    x: nn.Tensor,
    x_scale: nn.Tensor,
    weight: nn.Tensor,
    weight_scale: nn.Tensor,
    block_size: Tuple[int, int],
    out_dtype: str,
) -> nn.Tensor

def fp8_groupwise_scaled_group_gemm(
    x: nn.Tensor,
    x_scale: nn.Tensor,
    weight: nn.Tensor,
    weight_scale: nn.Tensor,
    indptr: nn.Tensor,
    block_size: Tuple[int, int],
    out_dtype: str,
) -> nn.Tensor

def get_tir_w8a8_block_fp8_matmul(
    N: int, K: int, block_n: int, block_k: int,
    in_dtype: Literal["float8_e4m3fn"],
    out_dtype: Literal["float16", "bfloat16"],
    BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int,
    GROUP_SIZE_M: int, num_warps: int, num_stages: int,
    extern_mods: List[tvm.runtime.Module],
) -> Tuple[Optional[tvm.tir.PrimFunc], str]

def get_tir_w8a8_block_fp8_group_matmul(
    N: int, K: int, num_experts: int, block_n: int, block_k: int,
    in_dtype: Literal["float8_e4m3fn"],
    out_dtype: Literal["float16", "bfloat16"],
    BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int,
    GROUP_SIZE_M: int, num_warps: int, num_stages: int,
    extern_mods: List[tvm.runtime.Module],
) -> Tuple[Optional[tvm.tir.PrimFunc], str]

Import

from mlc_llm.op.triton import (
    fp8_groupwise_scaled_gemm,
    fp8_groupwise_scaled_group_gemm,
)

I/O Contract

fp8_groupwise_scaled_gemm

Parameter | Type | Description
x | nn.Tensor | Input tensor [m, k], dtype float8_e4m3fn
x_scale | nn.Tensor | Activation scale [m, k // block_size[1]], dtype float32
weight | nn.Tensor | Weight tensor [n, k], dtype float8_e4m3fn
weight_scale | nn.Tensor | Weight scale [n // block_size[0], k // block_size[1]], dtype float32
block_size | Tuple[int, int] | Block dimensions for quantization (block_n, block_k)
out_dtype | str | Output data type ("float16" or "bfloat16")

Return | Type | Description
out | nn.Tensor | Output tensor [m, n], dtype out_dtype
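The scale shapes in this table follow from how the tensors are quantized. The sketch below assumes the common convention of dividing each group by amax / 448, where 448 is the largest finite float8_e4m3fn value; the function names are hypothetical, and values stay in float32 because FP8 rounding is not modeled.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def quantize_per_token_group(x: np.ndarray, block_k: int):
    """Per-row, per-K-block scales: returns (q [m, k], scale [m, k // block_k])."""
    m, k = x.shape
    groups = x.reshape(m, k // block_k, block_k)
    amax = np.abs(groups).max(axis=-1)
    scale = np.maximum(amax, 1e-12) / FP8_E4M3_MAX  # guard against all-zero groups
    q = np.clip(groups / scale[..., None], -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.reshape(m, k).astype(np.float32), scale.astype(np.float32)


def quantize_per_block(w: np.ndarray, block_n: int, block_k: int):
    """Per-tile scales: returns (q [n, k], scale [n // block_n, k // block_k])."""
    n, k = w.shape
    # Axes: (N block, row in block, K block, column in block).
    tiles = w.reshape(n // block_n, block_n, k // block_k, block_k)
    amax = np.abs(tiles).max(axis=(1, 3))
    scale = np.maximum(amax, 1e-12) / FP8_E4M3_MAX
    q = np.clip(tiles / scale[:, None, :, None], -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.reshape(n, k).astype(np.float32), scale.astype(np.float32)
```

With block_size = (128, 128), a [2, 256] activation gets a [2, 2] x_scale and a [256, 384] weight gets a [2, 3] weight_scale, matching the shapes above.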

fp8_groupwise_scaled_group_gemm

Parameter | Type | Description
x | nn.Tensor | Input tensor [m, k], dtype float8_e4m3fn, with rows grouped by expert
x_scale | nn.Tensor | Activation scale [m, k // block_size[1]], dtype float32
weight | nn.Tensor | Expert weight tensor [num_experts, n, k], dtype float8_e4m3fn
weight_scale | nn.Tensor | Weight scale [num_experts, n // block_size[0], k // block_size[1]], dtype float32
indptr | nn.Tensor | Expert boundary offsets [num_experts + 1], dtype int32; rows indptr[e] to indptr[e + 1] - 1 of x belong to expert e
block_size | Tuple[int, int] | Block dimensions for quantization (block_n, block_k)
out_dtype | str | Output data type ("float16" or "bfloat16")

Return | Type | Description
out | nn.Tensor | Output tensor [m, n], dtype out_dtype
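The group GEMM contract can be expressed as a per-expert loop over a dense dequantized product. This NumPy sketch assumes float32 stand-ins for FP8 values, dimensions that are exact multiples of the block size, and rows of x already grouped by expert; the function name is hypothetical.

```python
import numpy as np


def block_fp8_group_gemm_reference(x, x_scale, weight, weight_scale, indptr, block_size):
    """Per-expert dense reference for the block-scale FP8 group GEMM.

    x: [m, k], x_scale: [m, k // block_k], weight: [E, n, k],
    weight_scale: [E, n // block_n, k // block_k], indptr: [E + 1].
    """
    block_n, block_k = block_size
    m, _ = x.shape
    num_experts, n, _ = weight.shape
    # Dequantize activations: each element takes its row's K-block scale.
    x_deq = x * np.repeat(x_scale, block_k, axis=1)
    out = np.zeros((m, n), dtype=np.float32)
    for e in range(num_experts):
        start, end = int(indptr[e]), int(indptr[e + 1])
        if start == end:
            continue  # this expert received no tokens
        # Expand the [n // block_n, k // block_k] tile scales to [n, k].
        w_scale = np.repeat(np.repeat(weight_scale[e], block_n, axis=0), block_k, axis=1)
        out[start:end] = x_deq[start:end] @ (weight[e] * w_scale).T
    return out
```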

Triton Kernel Configuration

Parameter | Default | Description
BLOCK_SIZE_M | 64 | Tile size along the M (token) dimension
BLOCK_SIZE_N | block_size[0] | Tile size along the N (output) dimension
BLOCK_SIZE_K | block_size[1] | Tile size along the K (reduction) dimension
GROUP_SIZE_M | 32 | Number of M tiles grouped together in the launch order, for L2 cache reuse
num_warps | 4 | Number of warps per thread block
num_stages | 3 | Number of software pipeline stages
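The grouped launch order controlled by GROUP_SIZE_M can be sketched in Python. This follows the scheme from the Triton matrix-multiplication tutorial; that the MLC LLM kernels use exactly this formula is an assumption, and grouped_tile_index is a hypothetical name.

```python
from typing import Tuple


def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def grouped_tile_index(
    pid: int, M: int, N: int, BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, GROUP_SIZE_M: int
) -> Tuple[int, int]:
    """Map a 1-D program ID to an (M tile, N tile) pair in grouped order."""
    num_pid_m = cdiv(M, BLOCK_SIZE_M)
    num_pid_n = cdiv(N, BLOCK_SIZE_N)
    # Each group covers GROUP_SIZE_M rows of tiles across all N tiles.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may contain fewer than GROUP_SIZE_M rows of tiles.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    local = pid % num_pid_in_group
    # Step down the rows of the group first, then move to the next N tile.
    pid_m = first_pid_m + local % group_size_m
    pid_n = local // group_size_m
    return pid_m, pid_n
```

Consecutive program IDs cover the same few rows of A while sweeping columns of B, so thread blocks running at the same time reuse the same input tiles, which raises the L2 hit rate compared with a plain row-major launch order.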

Usage Examples

from mlc_llm.op.triton import fp8_groupwise_scaled_gemm, fp8_groupwise_scaled_group_gemm

# These operators are called while defining a model with TVM's nn frontend,
# so every tensor argument below is an nn.Tensor.

# Standard FP8 block-scale GEMM
output = fp8_groupwise_scaled_gemm(
    x=x_fp8,                 # [m, k] float8_e4m3fn
    x_scale=x_scale,         # [m, k // 128] float32
    weight=weight_fp8,       # [n, k] float8_e4m3fn
    weight_scale=w_scale,    # [n // 128, k // 128] float32
    block_size=(128, 128),   # (block_n, block_k)
    out_dtype="bfloat16",
)                            # -> [m, n] bfloat16

# MoE group GEMM with expert routing; rows of x_fp8 must be grouped by expert
output = fp8_groupwise_scaled_group_gemm(
    x=x_fp8,                 # [m, k] float8_e4m3fn
    x_scale=x_scale,         # [m, k // 128] float32
    weight=expert_weights,   # [num_experts, n, k] float8_e4m3fn
    weight_scale=w_scale,    # [num_experts, n // 128, k // 128] float32
    indptr=indptr,           # [num_experts + 1] int32; expert e owns rows indptr[e]:indptr[e + 1]
    block_size=(128, 128),   # (block_n, block_k)
    out_dtype="bfloat16",
)                            # -> [m, n] bfloat16
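The group GEMM expects the rows of x to be grouped by expert. A hypothetical helper for producing that layout and the matching indptr from top-1 routing decisions is sketched below; MLC LLM's own MoE path may build these tensors differently, and top-k routing would additionally repeat each token k times.

```python
import numpy as np


def group_tokens_by_expert(expert_of_token: np.ndarray, num_experts: int):
    """Return (order, indptr) so that x[order] is grouped by expert.

    expert_of_token: [m] int array holding the expert chosen for each row (top-1 routing).
    """
    # A stable sort keeps the original token order inside each expert's group.
    order = np.argsort(expert_of_token, kind="stable")
    counts = np.bincount(expert_of_token, minlength=num_experts)
    indptr = np.zeros(num_experts + 1, dtype=np.int32)
    indptr[1:] = np.cumsum(counts)
    return order, indptr
```

Under this sketch, x and x_scale would both be permuted with the same order (x[order], x_scale[order]) before the call, and the output rows mapped back with the inverse permutation.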
