Implementation:Mlc ai Mlc llm Fuse FT Dequantize Matmul Epilogue
| Knowledge Sources | |
|---|---|
| Domains | Compiler Pass, Operator Fusion, Quantization |
| Last Updated | 2026-02-09 19:00 GMT |
Overview
A TVM compiler pass that fuses FasterTransformer dequantize-matmul operations with their epilogue operations (bias addition, activation functions, residual binary operations, and residual unary operations) into single fused kernels.
Description
The FuseFTDequantizeEpilogue pass is a TVM module-level transformation (opt_level=0) that performs operator fusion on FasterTransformer's quantized GEMM operations. Quantized models using FasterTransformer kernels typically decompose the computation into separate dequantize-matmul and epilogue steps, each of which writes an intermediate tensor to memory and launches its own kernel. This pass fuses them together to reduce memory traffic and kernel launch overhead.
The pass applies four sequential fusion stages to each Relax function in the module:
fuse_bias fuses a subsequent relax.add operation into fastertransformer.gemm_fp16_int to produce fastertransformer.gemm_fp16_int_bias. The fusion adds the bias tensor and a computed bias_stride parameter to the kernel arguments. The fusion is only applied when no activation has been set yet (activation is "identity"). The bias_stride is computed based on whether the bias tensor is 1-dimensional or multi-dimensional.
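The text above does not spell out the exact bias_stride rule, so the following is only a minimal sketch under a stated assumption: a bias whose elements all lie along its last axis (effectively 1-dimensional) gets a stride of 0, and any other bias gets a stride equal to the length of its last axis. The helper name compute_bias_stride and the use of plain integer tuples for shapes are hypothetical; the actual pass works on Relax struct info.

```python
from functools import reduce
import operator


def compute_bias_stride(bias_shape):
    """Return the assumed bias_stride for a bias with a static shape.

    Assumption (not confirmed by the description above): a bias that is
    effectively 1-dimensional is broadcast over all rows and uses stride 0;
    a multi-dimensional bias uses the length of its last axis as the stride.
    """
    num_elements = reduce(operator.mul, bias_shape, 1)
    last_dim = bias_shape[-1]
    return 0 if num_elements == last_dim else last_dim


print(compute_bias_stride((4096,)))    # 1-D bias
print(compute_bias_stride((1, 4096)))  # effectively 1-D bias
print(compute_bias_stride((8, 4096)))  # one bias row per output row
```

Under this assumption, a stride of 0 marks the broadcast case, which is also the only case the later residual fusion accepts.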
fuse_activation fuses a subsequent activation function (relax.nn.silu, relax.nn.gelu, or relax.nn.relu) into either fastertransformer.gemm_fp16_int or fastertransformer.gemm_fp16_int_bias. The activation name (with the "relax.nn." prefix stripped) replaces the "identity" activation parameter.
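The prefix stripping described above can be illustrated in plain Python. This is a sketch only: the helper activation_param and the constant names are hypothetical and do not come from the pass itself.

```python
SUPPORTED_ACTIVATIONS = ("relax.nn.silu", "relax.nn.gelu", "relax.nn.relu")
PREFIX = "relax.nn."


def activation_param(op_name):
    """Map a Relax activation op name to the string passed to the kernel."""
    if op_name not in SUPPORTED_ACTIVATIONS:
        raise ValueError(f"unsupported activation op: {op_name}")
    # Drop the "relax.nn." prefix, e.g. "relax.nn.silu" becomes "silu".
    return op_name[len(PREFIX):]


for name in SUPPORTED_ACTIVATIONS:
    print(name, "->", activation_param(name))
```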
fuse_residual_binary fuses a subsequent binary operation (relax.add or relax.multiply) into fastertransformer.gemm_fp16_int_bias to produce fastertransformer.gemm_fp16_int_bias_residual. The fused kernel includes the residual tensor, a binary operation type ("plus" or "multiply"), and an initial unary operation of "identity". This fusion is only applied when bias_stride is 0.
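As a rough illustration of the rule above, the sketch below maps the matched Relax op to the binary operation string and applies the bias_stride guard. The function residual_params and its return convention (None meaning "do not fuse") are hypothetical simplifications, not the pass's own code.

```python
BINARY_OP_NAMES = {"relax.add": "plus", "relax.multiply": "multiply"}


def residual_params(op_name, bias_stride):
    """Return (binary_op, unary_op) for the fused kernel, or None to skip fusion."""
    if bias_stride != 0:
        # The fusion is only applied when bias_stride is 0.
        return None
    if op_name not in BINARY_OP_NAMES:
        return None
    # The unary operation starts as "identity"; a later stage may replace it.
    return BINARY_OP_NAMES[op_name], "identity"


print(residual_params("relax.add", 0))
print(residual_params("relax.multiply", 0))
print(residual_params("relax.add", 4096))
```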
fuse_residual_unary fuses a subsequent activation function into fastertransformer.gemm_fp16_int_bias_residual, replacing the "identity" unary operation parameter with the matched activation name.
All four fusion functions use TVM's Relax dataflow pattern language (DPL) with rewrite_call for pattern matching and rewriting. The patterns use is_op to match specific operations and wildcard for variable operands.
Usage
This pass is used during the MLC LLM model compilation pipeline when compiling quantized models that leverage FasterTransformer GEMM kernels. It is applied as part of the optimization pass sequence to reduce the number of kernel launches and intermediate memory allocations in the quantized inference path.
Code Reference
Source Location
Signature
@tvm.transform.module_pass(opt_level=0, name="FuseDequantizeEpilogue")
class FuseFTDequantizeEpilogue:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        ...

# Individual fusion functions (applied sequentially)
def fuse_bias(func: relax.Function) -> relax.Function: ...
def fuse_activation(func: relax.Function) -> relax.Function: ...
def fuse_residual_binary(func: relax.Function) -> relax.Function: ...
def fuse_residual_unary(func: relax.Function) -> relax.Function: ...
Import
from mlc_llm.compiler_pass.fuse_ft_dequantize_matmul_epilogue import FuseFTDequantizeEpilogue
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| mod | IRModule | Yes | The TVM IRModule containing Relax functions with FasterTransformer GEMM calls. |
Outputs
| Name | Type | Description |
|---|---|---|
| mod | IRModule | The transformed IRModule with fused FasterTransformer operations. |
The following kernel progression illustrates the fusion chain:
| Stage | Before | After |
|---|---|---|
| fuse_bias | gemm_fp16_int + add | gemm_fp16_int_bias |
| fuse_activation | gemm_fp16_int[_bias] + silu/gelu/relu | gemm_fp16_int[_bias] (with activation param) |
| fuse_residual_binary | gemm_fp16_int_bias + add/multiply | gemm_fp16_int_bias_residual |
| fuse_residual_unary | gemm_fp16_int_bias_residual + silu/gelu/relu | gemm_fp16_int_bias_residual (with unary param) |
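To make the stage ordering in the table concrete, here is a toy model in plain Python that does not use TVM. It assumes a single GEMM followed by a linear chain of ops, represents the kernel as a dictionary, and lets each stage consume at most the next op. All names (run_pipeline, the dictionary keys, the 1-D bias assumption) are hypothetical and only mirror the rules described in this document; the real pass rewrites Relax functions through pattern matching.

```python
ACTIVATIONS = ("relax.nn.silu", "relax.nn.gelu", "relax.nn.relu")
BINARY_OPS = {"relax.add": "plus", "relax.multiply": "multiply"}
GEMM = "fastertransformer.gemm_fp16_int"


def fuse_bias(kernel, ops):
    # Only a plain GEMM with no activation yet can absorb a following add as bias.
    if (ops and kernel["name"] == GEMM
            and kernel["activation"] == "identity" and ops[0] == "relax.add"):
        # Toy assumption: the bias is 1-D, so bias_stride is 0.
        return {**kernel, "name": GEMM + "_bias", "bias_stride": 0}, ops[1:]
    return kernel, ops


def fuse_activation(kernel, ops):
    if ops and kernel["name"] in (GEMM, GEMM + "_bias") and ops[0] in ACTIVATIONS:
        return {**kernel, "activation": ops[0][len("relax.nn."):]}, ops[1:]
    return kernel, ops


def fuse_residual_binary(kernel, ops):
    if (ops and kernel["name"] == GEMM + "_bias"
            and kernel.get("bias_stride") == 0 and ops[0] in BINARY_OPS):
        fused = {**kernel, "name": GEMM + "_bias_residual",
                 "binary_op": BINARY_OPS[ops[0]], "unary_op": "identity"}
        return fused, ops[1:]
    return kernel, ops


def fuse_residual_unary(kernel, ops):
    if ops and kernel["name"] == GEMM + "_bias_residual" and ops[0] in ACTIVATIONS:
        return {**kernel, "unary_op": ops[0][len("relax.nn."):]}, ops[1:]
    return kernel, ops


def run_pipeline(ops):
    """Apply the four stages once, in order, to a GEMM followed by `ops`."""
    kernel = {"name": GEMM, "activation": "identity"}
    for stage in (fuse_bias, fuse_activation, fuse_residual_binary, fuse_residual_unary):
        kernel, ops = stage(kernel, ops)
    return kernel, ops


kernel, remaining = run_pipeline(["relax.add", "relax.nn.silu", "relax.add"])
print(kernel["name"], kernel["activation"], kernel["binary_op"], remaining)
```

In this model the order of the epilogue ops matters: a chain of relu followed by add fuses only the activation, because the bias stage runs first and the residual stage requires a kernel that already carries a bias.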
Usage Examples
import tvm
from mlc_llm.compiler_pass.fuse_ft_dequantize_matmul_epilogue import FuseFTDequantizeEpilogue
# Apply the pass to an IRModule containing FasterTransformer GEMM operations.
# `mod` is assumed to be an existing tvm.IRModule built earlier in the pipeline.
fuse_pass = FuseFTDequantizeEpilogue()
with tvm.transform.PassContext():
    optimized_mod = fuse_pass(mod)
# The optimized module will have patterns like:
# Before: gemm_fp16_int -> add(bias) -> silu -> add(residual)
# After: gemm_fp16_int_bias_residual (with activation="silu", binary_op="plus")