Implementation:Mlc ai Mlc llm Dispatch Triton Kernel
Overview
DispatchTritonKernel is a TVM compiler pass that dispatches generic Triton kernel calls in the IRModule to specific TIR implementations. It identifies calls to mlc.triton.* external functions and rewrites them into concrete TIR PrimFunc calls with target-specific optimizations. Currently, the pass handles two FP8 block-scaled matrix multiplication kernels used for quantized inference.
File: python/mlc_llm/compiler_pass/dispatch_triton_kernel.py
Architecture
The pass uses the mutator pattern, consisting of:
- DispatchTritonKernel -- The outer TVM module pass that gates execution on CUDA targets
- _Rewriter -- An inner PyExprMutator that traverses the IR and rewrites matching calls
Class: DispatchTritonKernel
    @tvm.transform.module_pass(opt_level=0, name="DispatchTritonKernel")
    class DispatchTritonKernel:
        def __init__(self, target: tvm.target.Target) -> None:
            self.target = target

        def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
            # Only CUDA targets are rewritten; every other target passes through.
            if self.target.kind.name != "cuda":
                return mod
            return _Rewriter(mod, self.target).transform()
The pass only operates when the compilation target is CUDA. For all other targets, the module is returned unmodified.
Class: _Rewriter
The _Rewriter is a PyExprMutator that walks all Relax functions in the module and rewrites Triton kernel calls.
    @mutator
    class _Rewriter(PyExprMutator):
        def __init__(self, mod: IRModule, target: tvm.target.Target) -> None:
            super().__init__(mod)
            self.mod = mod
            self.target = target
            # External runtime modules collected while rewriting calls.
            self.extern_mods: List[tvm.runtime.Module] = []
transform
The entry point iterates over all Relax functions, visits and rewrites each, then attaches any accumulated external modules to the final IRModule:
    def transform(self) -> tvm.IRModule:
        for g_var, func in self.mod.functions_items():
            # Skip TIR PrimFuncs and anything else that is not a Relax function.
            if not isinstance(func, relax.Function):
                continue
            new_func = self.visit_expr(func)
            self.builder_.update_func(g_var, new_func)
        mod = self.builder_.finalize()
        mod_attrs = dict(mod.attrs) if mod.attrs else {}
        # Append the collected modules to any "external_mods" already present.
        mod = mod.with_attr(
            "external_mods", list(mod_attrs.get("external_mods", [])) + self.extern_mods
        )
        return mod
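The attribute merge at the end of transform can be illustrated without TVM. The following is a minimal sketch that models the module attributes as a plain dict; merge_external_mods and the string placeholders are hypothetical names used only for illustration, not part of the pass.

```python
from typing import Any, Dict, List, Optional


def merge_external_mods(
    attrs: Optional[Dict[str, Any]], new_mods: List[Any]
) -> Dict[str, Any]:
    """Return a copy of attrs with new_mods appended to "external_mods".

    Hypothetical helper: a dict stands in for IRModule attributes.
    """
    merged = dict(attrs) if attrs else {}
    # Keep modules that were attached earlier, then add the new ones.
    merged["external_mods"] = list(merged.get("external_mods", [])) + list(new_mods)
    return merged


# Placeholder strings stand in for tvm.runtime.Module objects.
existing = {"external_mods": ["earlier_mod"], "other_attr": 1}
print(merge_external_mods(existing, ["triton_kernel_mod"]))
print(merge_external_mods(None, ["triton_kernel_mod"]))
```

Copying the existing list before appending keeps modules attached by earlier passes, which appears to be the intent of the `mod_attrs.get("external_mods", [])` lookup in the pass.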
visit_call_
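This method holds the core pattern-matching logic. A call is rewritten only if all of the following hold:
- The operation is relax.call_dps_packed
- The first argument is a relax.ExternFunc
- The function's global symbol starts with "mlc.triton."
A call that matches all three conditions but names an unsupported kernel raises a ValueError.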
    def visit_call_(self, call: relax.Call) -> relax.Expr:
        call = super().visit_call_(call)
        # Leave the call untouched unless it is a call_dps_packed to mlc.triton.*
        if (
            call.op != tvm.ir.Op.get("relax.call_dps_packed")
            or not isinstance(call.args[0], relax.ExternFunc)
            or not str(call.args[0].global_symbol).startswith("mlc.triton.")
        ):
            return call
        global_symbol = str(call.args[0].global_symbol)
        if global_symbol == "mlc.triton.w8a8_block_fp8_matmul":
            return self.w8a8_block_fp8_matmul(call.args[1].fields, call.struct_info)
        if global_symbol == "mlc.triton.w8a8_block_fp8_group_matmul":
            return self.w8a8_block_fp8_group_matmul(call.args[1].fields, call.struct_info)
        raise ValueError(f"Unknown mlc.triton kernel identifier: {global_symbol}")
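The matching rules above can be sketched in plain Python, with strings standing in for Relax operators and extern symbols. The dispatch function, the HANDLERS table, and the handler return values are assumptions made for illustration; the actual pass uses if-chains on the Relax call node rather than a dictionary.

```python
from typing import Callable, Dict, List, Optional

TRITON_PREFIX = "mlc.triton."


def dispatch(
    op_name: str,
    symbol: Optional[str],
    args: List[object],
    handlers: Dict[str, Callable[[List[object]], str]],
) -> Optional[str]:
    """Return the handler result, or None when the call should stay as-is."""
    if (
        op_name != "relax.call_dps_packed"
        or symbol is None  # models "first argument is not an ExternFunc"
        or not symbol.startswith(TRITON_PREFIX)
    ):
        return None
    if symbol not in handlers:
        # Mirrors the pass: an unknown mlc.triton.* symbol is an error.
        raise ValueError(f"Unknown mlc.triton kernel identifier: {symbol}")
    return handlers[symbol](args)


# Hypothetical handlers that only report which branch was taken.
HANDLERS = {
    "mlc.triton.w8a8_block_fp8_matmul": lambda a: f"dense({len(a)} args)",
    "mlc.triton.w8a8_block_fp8_group_matmul": lambda a: f"grouped({len(a)} args)",
}

print(dispatch("relax.call_dps_packed", "mlc.triton.w8a8_block_fp8_matmul", [0] * 16, HANDLERS))
print(dispatch("relax.call_dps_packed", "some.other.kernel", [], HANDLERS))
```

Note the asymmetry: calls outside the mlc.triton. namespace pass through silently, while an unrecognized name inside the namespace fails loudly, so a misspelled kernel name is caught at compile time.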
Supported Triton Kernels
w8a8_block_fp8_matmul
Handles the mlc.triton.w8a8_block_fp8_matmul kernel, a W8A8 block-scaled FP8 matrix multiplication. The call expects 16 arguments:
| Arguments | Description |
|---|---|
| args[0:4] | x, weight, x_scale, weight_scale -- data tensors |
| args[4:14] | N, K, block_n, block_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps, num_stages -- kernel parameters (integer constants) |
| args[14:16] | in_dtype, out_dtype -- data type specifications |
    def w8a8_block_fp8_matmul(self, args, out_sinfo) -> relax.Expr:
        assert len(args) == 16
        x, weight, x_scale, weight_scale = args[:4]
        # Kernel parameters arrive as constants; unwrap them to Python ints.
        N, K, block_n, block_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, \
            GROUP_SIZE_M, num_warps, num_stages = [arg.value.value for arg in args[4:14]]
        in_dtype, out_dtype = str(args[14].value), str(args[15].value)
        prim_func, func_name = get_tir_w8a8_block_fp8_matmul(
            N, K, block_n, block_k, in_dtype, out_dtype,
            BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M,
            num_warps, num_stages, self.extern_mods,
        )
        # ... add to module or get existing global var ...
        return relax.call_tir(gv, [x, weight, x_scale, weight_scale], out_sinfo=out_sinfo)
w8a8_block_fp8_group_matmul
Handles the mlc.triton.w8a8_block_fp8_group_matmul kernel, a grouped variant for Mixture-of-Experts (MoE) workloads. The call expects 19 arguments:
| Arguments | Description |
|---|---|
| args[0:6] | x, weight, x_scale, weight_scale, expert_ids, indptr -- data tensors |
| args[6:17] | N, K, num_experts, block_n, block_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps, num_stages -- kernel parameters |
| args[17:19] | in_dtype, out_dtype -- data type specifications |
The group variant includes additional expert_ids and indptr tensors and a num_experts parameter to route inputs to different expert weights.
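The two argument layouts differ only in the number of leading tensors and in the extra num_experts parameter. A table-driven splitter makes this explicit; the sketch below uses plain Python values in place of Relax expressions. split_args, DENSE_PARAMS, and GROUP_PARAMS are hypothetical names, and the dtype strings and integer values in the example are illustrative only.

```python
from typing import Dict, List, Sequence, Tuple

DENSE_PARAMS = (
    "N", "K", "block_n", "block_k",
    "BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K", "GROUP_SIZE_M",
    "num_warps", "num_stages",
)
# The grouped layout inserts num_experts after K.
GROUP_PARAMS = DENSE_PARAMS[:2] + ("num_experts",) + DENSE_PARAMS[2:]


def split_args(
    args: Sequence[object], num_tensors: int, param_names: Sequence[str]
) -> Tuple[List[object], Dict[str, object], Tuple[object, object]]:
    """Split a flat argument list into (tensors, params, (in_dtype, out_dtype))."""
    expected = num_tensors + len(param_names) + 2
    if len(args) != expected:
        raise ValueError(f"expected {expected} arguments, got {len(args)}")
    tensors = list(args[:num_tensors])
    params = dict(zip(param_names, args[num_tensors:num_tensors + len(param_names)]))
    in_dtype, out_dtype = args[-2], args[-1]
    return tensors, params, (in_dtype, out_dtype)


# Dense layout: 4 tensors + 10 parameters + 2 dtypes = 16 arguments.
dense = ["x", "w", "xs", "ws"] + list(range(10)) + ["float8_e4m3fn", "bfloat16"]
print(split_args(dense, 4, DENSE_PARAMS)[1])

# Grouped layout: 6 tensors + 11 parameters + 2 dtypes = 19 arguments.
group = ["x", "w", "xs", "ws", "expert_ids", "indptr"] + list(range(11)) + ["float8_e4m3fn", "bfloat16"]
print(split_args(group, 6, GROUP_PARAMS)[1])
```

The length check plays the same role as the `assert len(args) == 16` in the dense handler: a mismatch between the caller's layout and the handler's expectation is reported before any slicing happens.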
TIR Function Management
Both kernel methods use a deduplication strategy. The helper functions (get_tir_w8a8_block_fp8_matmul and get_tir_w8a8_block_fp8_group_matmul) return a (prim_func, func_name) pair, where prim_func is either a new PrimFunc or None:
    if prim_func is None:
        # The TIR function is already in the IRModule
        gv = self.builder_.get().get_global_var(func_name)
    else:
        # Add the TIR function to the IRModule
        gv = self.builder_.add_func(prim_func, func_name)
When None is returned, the function was already added (from a previous call with the same parameters), and the existing global variable is reused. This avoids duplicating TIR functions for repeated kernel calls with identical configurations.
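The create-once, reuse-afterwards protocol can be sketched in plain Python. This is not how the mlc_llm.op.triton helpers are implemented; KernelRegistry, lookup_or_add, and the name scheme derived from the configuration tuple are assumptions chosen to reproduce the observable behavior (a new function the first time, None afterwards).

```python
from typing import Dict, Optional, Set, Tuple


class KernelRegistry:
    """Hypothetical stand-in for a get_tir_* helper with memory of past calls."""

    def __init__(self) -> None:
        self._seen: Set[str] = set()

    def get(self, prefix: str, config: Tuple[object, ...]) -> Tuple[Optional[dict], str]:
        # Assumed naming scheme: the configuration is encoded in the name.
        func_name = prefix + "_" + "_".join(str(v) for v in config)
        if func_name in self._seen:
            return None, func_name  # already generated earlier
        self._seen.add(func_name)
        return {"name": func_name, "config": config}, func_name


def lookup_or_add(module: Dict[str, dict], registry: KernelRegistry,
                  config: Tuple[object, ...]) -> str:
    """Return the name of the function to call, adding it to module if new."""
    prim_func, func_name = registry.get("w8a8_block_fp8_matmul", config)
    if prim_func is not None:
        module[func_name] = prim_func  # first use: add to the module
    return func_name  # repeat use: the existing entry is reused


module: Dict[str, dict] = {}
registry = KernelRegistry()
first = lookup_or_add(module, registry, (4096, 4096, 128, 128))
second = lookup_or_add(module, registry, (4096, 4096, 128, 128))
third = lookup_or_add(module, registry, (8192, 4096, 128, 128))
print(first == second, len(module))
```

Two calls with the same configuration resolve to one module entry, while a different configuration adds a second one.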
External Modules
External modules accumulated during rewriting (e.g., pre-compiled Triton kernel binaries) are appended to the IRModule's "external_mods" attribute at the end of the transformation.
Dependencies
- tvm -- Core TVM framework
- tvm.relax.expr_functor -- PyExprMutator and @mutator decorator
- mlc_llm.op.triton -- get_tir_w8a8_block_fp8_matmul, get_tir_w8a8_block_fp8_group_matmul
- mlc_llm.support.logging -- MLC-LLM logging