Implementation:Mlc ai Mlc llm Dispatch Triton Kernel

From Leeroopedia


Overview

DispatchTritonKernel is a TVM compiler pass that dispatches generic Triton kernel calls in the IRModule to specific TIR implementations. It identifies calls to mlc.triton.* external functions and rewrites them into concrete TIR PrimFunc calls with target-specific optimizations. Currently, the pass handles two FP8 block-scaled matrix multiplication kernels used for quantized inference.

File: python/mlc_llm/compiler_pass/dispatch_triton_kernel.py

Architecture

The pass follows the mutator pattern and consists of two classes:

  • DispatchTritonKernel -- The outer TVM module pass that gates execution on CUDA targets
  • _Rewriter -- An inner PyExprMutator that traverses the IR and rewrites matching calls

Class: DispatchTritonKernel

@tvm.transform.module_pass(opt_level=0, name="DispatchTritonKernel")
class DispatchTritonKernel:
    def __init__(self, target: tvm.target.Target) -> None:
        self.target = target

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        if self.target.kind.name != "cuda":
            return mod
        return _Rewriter(mod, self.target).transform()

The pass only operates when the compilation target is CUDA. For all other targets, the module is returned unmodified.

Class: _Rewriter

The _Rewriter is a PyExprMutator that walks all Relax functions in the module and rewrites Triton kernel calls.

@mutator
class _Rewriter(PyExprMutator):
    def __init__(self, mod: IRModule, target: tvm.target.Target) -> None:
        super().__init__(mod)
        self.mod = mod
        self.target = target
        self.extern_mods: List[tvm.runtime.Module] = []

transform

The entry point iterates over all Relax functions, visits and rewrites each, then attaches any accumulated external modules to the final IRModule:

def transform(self) -> tvm.IRModule:
    for g_var, func in self.mod.functions_items():
        if not isinstance(func, relax.Function):
            continue
        new_func = self.visit_expr(func)
        self.builder_.update_func(g_var, new_func)

    mod = self.builder_.finalize()
    mod_attrs = dict(mod.attrs) if mod.attrs else {}
    mod = mod.with_attr(
        "external_mods", list(mod_attrs.get("external_mods", [])) + self.extern_mods
    )
    return mod

visit_call_

visit_call_ contains the core pattern-matching logic. A call is rewritten only if all three of the following conditions hold:

  1. The operation is relax.call_dps_packed
  2. The first argument is a relax.ExternFunc
  3. The function's global symbol starts with "mlc.triton."

def visit_call_(self, call: relax.Call) -> relax.Expr:
    call = super().visit_call_(call)
    if (
        call.op != tvm.ir.Op.get("relax.call_dps_packed")
        or not isinstance(call.args[0], relax.ExternFunc)
        or not str(call.args[0].global_symbol).startswith("mlc.triton.")
    ):
        return call

    global_symbol = str(call.args[0].global_symbol)
    if global_symbol == "mlc.triton.w8a8_block_fp8_matmul":
        return self.w8a8_block_fp8_matmul(call.args[1].fields, call.struct_info)
    if global_symbol == "mlc.triton.w8a8_block_fp8_group_matmul":
        return self.w8a8_block_fp8_group_matmul(call.args[1].fields, call.struct_info)
    raise ValueError(f"Unknown mlc.triton kernel identifier: {global_symbol}")
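The three-way outcome of visit_call_ (pass through, rewrite, or reject) can be illustrated without TVM. The following is a minimal sketch only: it replaces the if-chain of the original with a lookup table, and `dispatch`, `HANDLERS`, and the two `rewrite_*` functions are hypothetical names that return strings instead of Relax expressions.

```python
from typing import Callable, Dict, List

PREFIX = "mlc.triton."


def rewrite_matmul(args: List[object]) -> str:
    # Hypothetical stand-in for self.w8a8_block_fp8_matmul(...).
    return f"call_tir(w8a8_block_fp8_matmul, {len(args)} args)"


def rewrite_group_matmul(args: List[object]) -> str:
    # Hypothetical stand-in for self.w8a8_block_fp8_group_matmul(...).
    return f"call_tir(w8a8_block_fp8_group_matmul, {len(args)} args)"


# Assumption: a dict replaces the if-chain used in the real pass.
HANDLERS: Dict[str, Callable[[List[object]], str]] = {
    "mlc.triton.w8a8_block_fp8_matmul": rewrite_matmul,
    "mlc.triton.w8a8_block_fp8_group_matmul": rewrite_group_matmul,
}


def dispatch(global_symbol: str, args: List[object]) -> str:
    # Calls outside the mlc.triton namespace are returned untouched.
    if not global_symbol.startswith(PREFIX):
        return f"unchanged({global_symbol})"
    handler = HANDLERS.get(global_symbol)
    if handler is None:
        # A symbol in the namespace with no handler is an error, not a no-op.
        raise ValueError(f"Unknown mlc.triton kernel identifier: {global_symbol}")
    return handler(args)
```

Note that a symbol with the "mlc.triton." prefix but no registered handler raises rather than passing through, which matches the ValueError at the end of the original method.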

Supported Triton Kernels

w8a8_block_fp8_matmul

Handles the mlc.triton.w8a8_block_fp8_matmul kernel, a W8A8 block-scaled FP8 matrix multiplication. The call expects 16 arguments:

Arguments   | Description
args[0:4]   | x, weight, x_scale, weight_scale -- data tensors
args[4:14]  | N, K, block_n, block_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps, num_stages -- kernel parameters (integer constants)
args[14:16] | in_dtype, out_dtype -- data type specifications

def w8a8_block_fp8_matmul(self, args, out_sinfo) -> relax.Expr:
    assert len(args) == 16
    x, weight, x_scale, weight_scale = args[:4]
    N, K, block_n, block_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, \
        GROUP_SIZE_M, num_warps, num_stages = [arg.value.value for arg in args[4:14]]
    in_dtype, out_dtype = str(args[14].value), str(args[15].value)

    prim_func, func_name = get_tir_w8a8_block_fp8_matmul(
        N, K, block_n, block_k, in_dtype, out_dtype,
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M,
        num_warps, num_stages, self.extern_mods,
    )
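    if prim_func is None:
        # The TIR function is already in the IRModule; reuse its global var.
        gv = self.builder_.get().get_global_var(func_name)
    else:
        # First use of this configuration: add the TIR function to the IRModule.
        gv = self.builder_.add_func(prim_func, func_name)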
    return relax.call_tir(gv, [x, weight, x_scale, weight_scale], out_sinfo=out_sinfo)
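The positional layout of the 16 arguments can be checked in isolation. The sketch below is an illustration in plain Python, not the pass itself: `split_matmul_args` is a hypothetical helper, the arguments are ordinary Python values rather than Relax constants (so there is no `.value.value` unwrapping), and the sizes and dtype strings in the example are placeholders.

```python
from typing import Dict, List, Sequence, Tuple

# Order taken from the argument table above (args[4:14]).
INT_PARAM_NAMES = [
    "N", "K", "block_n", "block_k",
    "BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K", "GROUP_SIZE_M",
    "num_warps", "num_stages",
]


def split_matmul_args(
    args: Sequence[object],
) -> Tuple[List[object], Dict[str, int], Tuple[str, str]]:
    """Split a 16-element argument list into tensors, int params, and dtypes."""
    if len(args) != 16:
        raise ValueError(f"expected 16 arguments, got {len(args)}")
    tensors = list(args[:4])
    params = {name: int(value) for name, value in zip(INT_PARAM_NAMES, args[4:14])}
    dtypes = (str(args[14]), str(args[15]))
    return tensors, params, dtypes


# Placeholder values chosen only to exercise the layout.
example_args = [
    "x", "weight", "x_scale", "weight_scale",
    4096, 4096, 128, 128, 64, 64, 128, 8, 4, 3,
    "float8_e4m3fn", "bfloat16",
]
example_tensors, example_params, example_dtypes = split_matmul_args(example_args)
```

Only the four tensors are forwarded to relax.call_tir in the original method; the integer parameters and dtypes are consumed when the TIR function is generated.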

w8a8_block_fp8_group_matmul

Handles the mlc.triton.w8a8_block_fp8_group_matmul kernel, a grouped variant for Mixture-of-Experts (MoE) workloads. The call expects 19 arguments:

Arguments   | Description
args[0:6]   | x, weight, x_scale, weight_scale, expert_ids, indptr -- data tensors
args[6:17]  | N, K, num_experts, block_n, block_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps, num_stages -- kernel parameters
args[17:19] | in_dtype, out_dtype -- data type specifications

The group variant includes additional expert_ids and indptr tensors and a num_experts parameter to route inputs to different expert weights.
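The page does not show the body of the group method, so the following is only a sketch of how its 19 arguments might be split, assuming it mirrors the non-group method. `split_group_matmul_args` is a hypothetical helper operating on plain Python values, and the example sizes and dtype strings are placeholders.

```python
from typing import Dict, Sequence, Tuple

# Names and order taken from the argument table above.
TENSOR_NAMES = ["x", "weight", "x_scale", "weight_scale", "expert_ids", "indptr"]
PARAM_NAMES = [
    "N", "K", "num_experts", "block_n", "block_k",
    "BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K", "GROUP_SIZE_M",
    "num_warps", "num_stages",
]


def split_group_matmul_args(
    args: Sequence[object],
) -> Tuple[Dict[str, object], Dict[str, int], Tuple[str, str]]:
    """Split a 19-element argument list by the args[0:6] / [6:17] / [17:19] layout."""
    expected = len(TENSOR_NAMES) + len(PARAM_NAMES) + 2  # 6 + 11 + 2 = 19
    if len(args) != expected:
        raise ValueError(f"expected {expected} arguments, got {len(args)}")
    tensors = dict(zip(TENSOR_NAMES, args[0:6]))
    params = {name: int(value) for name, value in zip(PARAM_NAMES, args[6:17])}
    dtypes = (str(args[17]), str(args[18]))
    return tensors, params, dtypes


# Placeholder values chosen only to exercise the layout.
group_args = TENSOR_NAMES + [4096, 4096, 8, 128, 128, 64, 64, 128, 8, 4, 3] + [
    "float8_e4m3fn",
    "bfloat16",
]
group_tensors, group_params, group_dtypes = split_group_matmul_args(group_args)
```

Compared with the 16-argument layout, the two extra tensors and the extra num_experts parameter account for the three additional positions.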

TIR Function Management

Both kernel methods use a deduplication strategy. The helper functions (get_tir_w8a8_block_fp8_matmul and get_tir_w8a8_block_fp8_group_matmul) return a (prim_func, func_name) pair in which prim_func is either a new PrimFunc or None:

if prim_func is None:
    # The TIR function is already in the IRModule
    gv = self.builder_.get().get_global_var(func_name)
else:
    # Add the TIR function to the IRModule
    gv = self.builder_.add_func(prim_func, func_name)

When prim_func is None, the function was already added by a previous call with the same parameters, so the existing global variable is looked up by func_name and reused. This avoids duplicating TIR functions for repeated kernel calls with identical configurations.
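The reuse-or-add contract can be modeled with a small toy. How the real helpers detect an already generated kernel is not described on this page, so the sketch below assumes a set of generated names and a name derived from the parameters. `FakeBuilder`, `get_kernel`, `resolve`, and the naming scheme are all hypothetical.

```python
from typing import Dict, Optional, Set, Tuple


class FakeBuilder:
    """Hypothetical stand-in for the block builder: maps names to functions."""

    def __init__(self) -> None:
        self.funcs: Dict[str, str] = {}

    def add_func(self, func: str, name: str) -> str:
        self.funcs[name] = func
        return name  # The name doubles as the "global var" in this toy.

    def get_global_var(self, name: str) -> str:
        if name not in self.funcs:
            raise KeyError(name)
        return name


def get_kernel(N: int, K: int, generated: Set[str]) -> Tuple[Optional[str], str]:
    # Assumption: the function name encodes the configuration.
    name = f"w8a8_matmul_N{N}_K{K}"
    if name in generated:
        return None, name  # Already generated: signal reuse with None.
    generated.add(name)
    return f"<PrimFunc {name}>", name


def resolve(builder: FakeBuilder, generated: Set[str], N: int, K: int) -> str:
    prim_func, func_name = get_kernel(N, K, generated)
    if prim_func is None:
        return builder.get_global_var(func_name)
    return builder.add_func(prim_func, func_name)


builder, generated = FakeBuilder(), set()
gv_first = resolve(builder, generated, 4096, 4096)
gv_again = resolve(builder, generated, 4096, 4096)  # Same config: reused.
gv_other = resolve(builder, generated, 4096, 1024)  # New config: added.
```

In this toy, two calls with the same configuration resolve to the same global variable and leave a single function in the builder, while a different configuration adds a second one.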

External Modules

External modules accumulated during rewriting (e.g., pre-compiled Triton kernel binaries) are appended to the IRModule's "external_mods" attribute at the end of the transformation.
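The merge performed at the end of transform can be sketched with a plain dict in place of the module attributes. `attach_external_mods` is a hypothetical helper; like with_attr in the original, it returns a new mapping rather than mutating its input.

```python
from typing import Dict, List, Optional


def attach_external_mods(
    attrs: Optional[Dict[str, object]], new_mods: List[object]
) -> Dict[str, object]:
    """Return a copy of attrs whose "external_mods" has new_mods appended."""
    merged: Dict[str, object] = dict(attrs) if attrs else {}
    existing = list(merged.get("external_mods", []))
    # Existing modules keep their order; newly accumulated ones go last.
    merged["external_mods"] = existing + list(new_mods)
    return merged


original_attrs = {"external_mods": ["prior_mod"], "other_attr": 1}
merged_attrs = attach_external_mods(original_attrs, ["triton_mod"])
empty_case = attach_external_mods(None, ["triton_mod"])
```

The empty-attributes case matters because the original code guards with `if mod.attrs else {}` before reading "external_mods".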

Dependencies

  • tvm -- Core TVM framework
  • tvm.relax.expr_functor -- PyExprMutator and @mutator decorator
  • mlc_llm.op.triton -- get_tir_w8a8_block_fp8_matmul, get_tir_w8a8_block_fp8_group_matmul
  • mlc_llm.support.logging -- MLC-LLM logging
