
Implementation:Mlc ai Mlc llm Fuse FT Dequantize Matmul Epilogue

From Leeroopedia


Knowledge Sources
Domains: Compiler Pass, Operator Fusion, Quantization
Last Updated: 2026-02-09 19:00 GMT

Overview

A TVM compiler pass that fuses FasterTransformer dequantize-matmul operations with their epilogue operations (bias addition, activation functions, residual binary operations, and residual unary operations) into single fused kernels.

Description

The FuseFTDequantizeEpilogue pass is a TVM module-level transformation (opt_level=0) that performs operator fusion on FasterTransformer's quantized GEMM operations. Quantized models using FasterTransformer kernels typically decompose the computation into separate dequantize-matmul and epilogue steps. This pass folds the epilogue steps into the dequantize-matmul kernel call, which avoids writing intermediate tensors to memory and reading them back, and reduces the number of kernel launches.
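As a rough, back-of-envelope illustration of the memory-traffic argument, assume an fp16 GEMM output of shape m x n, purely elementwise epilogue ops, one kernel per unfused op, and no cache reuse between kernels. Under those assumptions, every unfused epilogue op costs one extra write and one extra read of an m x n intermediate, plus one extra kernel launch. The helper name below is hypothetical:

```python
def intermediate_bytes_saved(m: int, n: int, num_epilogue_ops: int,
                             bytes_per_elem: int = 2) -> int:
    """Estimate memory traffic avoided by fusing elementwise epilogues into a GEMM.

    Unfused: the GEMM writes an m x n result, and each epilogue op reads the
    previous m x n result and writes a new one, so every intermediate tensor is
    written once and read once. Fused: only the final output is written.
    Side inputs such as bias or residual are read in both cases, so they cancel.
    """
    intermediate_bytes = m * n * bytes_per_elem
    return 2 * num_epilogue_ops * intermediate_bytes


# gemm -> add(bias) -> silu -> add(residual): three epilogue ops, fp16 output.
saved = intermediate_bytes_saved(m=1, n=4096, num_epilogue_ops=3)
print(saved)  # 49152 bytes avoided, plus three fewer kernel launches
```

The estimate scales linearly with the number of fused ops, which is why chaining all four stages matters most for long epilogues.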

The pass applies four sequential fusion stages to each Relax function in the module:

fuse_bias fuses a subsequent relax.add operation into fastertransformer.gemm_fp16_int to produce fastertransformer.gemm_fp16_int_bias. The fusion adds the bias tensor and a computed bias_stride parameter to the kernel arguments. The fusion is only applied when no activation has been set yet (activation is "identity"). The bias_stride is computed based on whether the bias tensor is 1-dimensional or multi-dimensional.
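The fuse_bias conditions can be modeled in plain Python. This is a toy sketch, not the pass's code: FTCall, toy_bias_stride and toy_fuse_bias are hypothetical names, and the stride rule used here (0 when only the last bias dimension is non-trivial, otherwise the last-dimension size) is an assumption about how the 1-dimensional versus multi-dimensional distinction is made.

```python
from dataclasses import dataclass
from math import prod
from typing import Optional, Sequence, Tuple

GEMM = "fastertransformer.gemm_fp16_int"


@dataclass(frozen=True)
class FTCall:
    """Toy stand-in for a call to a FasterTransformer kernel (hypothetical)."""
    kernel: str
    activation: str = "identity"
    bias_shape: Optional[Tuple[int, ...]] = None
    bias_stride: Optional[int] = None


def toy_bias_stride(bias_shape: Sequence[int]) -> int:
    # ASSUMPTION: a bias whose only non-trivial dimension is the last one
    # (e.g. (n,) or (1, n)) is broadcast across rows and gets stride 0;
    # any other shape is treated as a per-row bias with stride n.
    n = bias_shape[-1]
    return 0 if prod(bias_shape) == n else n


def toy_fuse_bias(call: FTCall, next_op: str, bias_shape: Sequence[int]) -> FTCall:
    """Absorb a following relax.add into gemm_fp16_int if the rules allow it."""
    if next_op != "relax.add" or call.kernel != GEMM:
        return call  # only gemm_fp16_int followed by add is a candidate
    if call.activation != "identity":
        return call  # the rule only applies while activation is still "identity"
    return FTCall(
        kernel=GEMM + "_bias",
        activation=call.activation,
        bias_shape=tuple(bias_shape),
        bias_stride=toy_bias_stride(bias_shape),
    )


base = FTCall(GEMM)
print(toy_fuse_bias(base, "relax.add", (4096,)).bias_stride)    # 0
print(toy_fuse_bias(base, "relax.add", (8, 4096)).bias_stride)  # 4096
```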

fuse_activation fuses a subsequent activation function (relax.nn.silu, relax.nn.gelu, or relax.nn.relu) into either fastertransformer.gemm_fp16_int or fastertransformer.gemm_fp16_int_bias. The activation name (with the "relax.nn." prefix stripped) replaces the "identity" activation parameter.
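A toy model of the fuse_activation rule follows. The helper name is hypothetical, and it assumes, conservatively, that an activation already set to something other than "identity" is left alone, since a single activation parameter cannot express two stacked activations:

```python
ACTIVATIONS = {"relax.nn.silu", "relax.nn.gelu", "relax.nn.relu"}
CANDIDATES = {"fastertransformer.gemm_fp16_int", "fastertransformer.gemm_fp16_int_bias"}


def toy_fuse_activation(kernel: str, activation: str, next_op: str) -> str:
    """Return the activation parameter after trying to absorb `next_op`."""
    if kernel in CANDIDATES and next_op in ACTIVATIONS and activation == "identity":
        return next_op.removeprefix("relax.nn.")  # e.g. "relax.nn.gelu" -> "gelu"
    return activation  # no fusion: parameter unchanged


print(toy_fuse_activation("fastertransformer.gemm_fp16_int_bias", "identity", "relax.nn.gelu"))  # gelu
print(toy_fuse_activation("fastertransformer.gemm_fp16_int", "identity", "relax.nn.softmax"))    # identity
```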

fuse_residual_binary fuses a subsequent binary operation (relax.add or relax.multiply) into fastertransformer.gemm_fp16_int_bias to produce fastertransformer.gemm_fp16_int_bias_residual. The fused kernel includes the residual tensor, a binary operation type ("plus" or "multiply"), and an initial unary operation of "identity". This fusion is only applied when bias_stride is 0.
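A similar toy model for fuse_residual_binary shows the op-to-string mapping and the bias_stride guard. The helper name is hypothetical, and the returned tuple is a simplification of the kernel's real argument list:

```python
from typing import Optional, Tuple

BIAS_KERNEL = "fastertransformer.gemm_fp16_int_bias"
BINARY_OPS = {"relax.add": "plus", "relax.multiply": "multiply"}


def toy_fuse_residual_binary(kernel: str, bias_stride: int,
                             next_op: str) -> Optional[Tuple[str, str, str]]:
    """Return (new_kernel, binary_op, unary_op), or None if fusion is skipped."""
    if kernel != BIAS_KERNEL or next_op not in BINARY_OPS:
        return None
    if bias_stride != 0:
        return None  # residual fusion is only applied when bias_stride is 0
    return (BIAS_KERNEL + "_residual", BINARY_OPS[next_op], "identity")


print(toy_fuse_residual_binary(BIAS_KERNEL, 0, "relax.multiply"))
# ('fastertransformer.gemm_fp16_int_bias_residual', 'multiply', 'identity')
print(toy_fuse_residual_binary(BIAS_KERNEL, 4096, "relax.add"))  # None
```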

fuse_residual_unary fuses a subsequent activation function into fastertransformer.gemm_fp16_int_bias_residual, replacing the "identity" unary operation parameter with the matched activation name.
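For fuse_residual_unary, a toy model looks like this. It assumes the stage accepts the same three activations as fuse_activation and only overwrites a unary_op that is still "identity"; the helper name is hypothetical:

```python
ACTIVATIONS = {"relax.nn.silu", "relax.nn.gelu", "relax.nn.relu"}  # assumed same set
RESIDUAL_KERNEL = "fastertransformer.gemm_fp16_int_bias_residual"


def toy_fuse_residual_unary(kernel: str, unary_op: str, next_op: str) -> str:
    """Return the unary_op parameter after trying to absorb `next_op`."""
    if kernel == RESIDUAL_KERNEL and unary_op == "identity" and next_op in ACTIVATIONS:
        return next_op.removeprefix("relax.nn.")
    return unary_op  # no fusion: parameter unchanged


print(toy_fuse_residual_unary(RESIDUAL_KERNEL, "identity", "relax.nn.relu"))  # relu
```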

All four fusion functions use TVM's Relax dataflow pattern language (DPL) with rewrite_call for pattern matching and rewriting. The patterns use is_op to match specific operations and wildcard for variable operands.
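To illustrate what "is_op plus wildcard plus rewrite_call" means without depending on a TVM build, here is a self-contained toy matcher in the same spirit. Everything in it (Var, Call, OpPattern, match, toy_rewrite) is hypothetical and far simpler than tvm.relax.dpl. In particular, the toy uses the kernel name directly as the op, whereas the FasterTransformer calls in the real IR are packed calls identified by their global symbol, and toy_rewrite's bottom-up traversal is not meant to mirror rewrite_call's internals.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple, Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str  # toy: operator or kernel name, e.g. "relax.add"
    args: Tuple["Expr", ...] = ()


Expr = Union[Var, Call]


class Wildcard:
    """Matches any expression (toy analogue of DPL's wildcard())."""


class OpPattern:
    """Matches a Call to a specific op whose arguments match sub-patterns."""

    def __init__(self, op: str, args: tuple):
        self.op, self.args = op, args


Pattern = Union[Wildcard, OpPattern]


def is_op(op: str) -> Callable[..., OpPattern]:
    return lambda *args: OpPattern(op, args)


def wildcard() -> Wildcard:
    return Wildcard()


def match(pat: Pattern, expr: Expr, env: Dict[Pattern, Expr]) -> bool:
    if pat in env:  # a pattern used twice must bind the same expression
        return env[pat] == expr
    if isinstance(pat, OpPattern):
        if not (isinstance(expr, Call) and expr.op == pat.op
                and len(expr.args) == len(pat.args)):
            return False
        if not all(match(p, e, env) for p, e in zip(pat.args, expr.args)):
            return False
    env[pat] = expr  # record what each sub-pattern matched
    return True


def toy_rewrite(pat: Pattern, rewriter: Callable[[Expr, Dict[Pattern, Expr]], Expr],
                expr: Expr) -> Expr:
    """Rewrite children first, then try the pattern at this node."""
    if isinstance(expr, Call):
        expr = Call(expr.op, tuple(toy_rewrite(pat, rewriter, a) for a in expr.args))
    env: Dict[Pattern, Expr] = {}
    return rewriter(expr, env) if match(pat, expr, env) else expr


# Pattern in the style of fuse_bias: add(gemm_fp16_int(*, *), bias)
gemm = is_op("fastertransformer.gemm_fp16_int")(wildcard(), wildcard())
bias = wildcard()
pattern = is_op("relax.add")(gemm, bias)


def to_bias_kernel(expr: Expr, m: Dict[Pattern, Expr]) -> Expr:
    return Call("fastertransformer.gemm_fp16_int_bias", m[gemm].args + (m[bias],))


x, w, b = Var("x"), Var("w"), Var("b")
before = Call("relax.add", (Call("fastertransformer.gemm_fp16_int", (x, w)), b))
print(toy_rewrite(pattern, to_bias_kernel, before))
# Call(op='fastertransformer.gemm_fp16_int_bias', args=(Var(name='x'), Var(name='w'), Var(name='b')))
```

The rewriter receives the bindings of every sub-pattern, which is how a fusion can pull the original GEMM arguments and the matched bias into a single new call.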

Usage

This pass is used during the MLC LLM model compilation pipeline when compiling quantized models that leverage FasterTransformer GEMM kernels. It is applied as part of the optimization pass sequence to reduce the number of kernel launches and intermediate memory allocations in the quantized inference path.

Code Reference

Source Location

Signature

@tvm.transform.module_pass(opt_level=0, name="FuseDequantizeEpilogue")
class FuseFTDequantizeEpilogue:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        ...

# Individual fusion functions (applied sequentially)
def fuse_bias(func: relax.Function) -> relax.Function: ...
def fuse_activation(func: relax.Function) -> relax.Function: ...
def fuse_residual_binary(func: relax.Function) -> relax.Function: ...
def fuse_residual_unary(func: relax.Function) -> relax.Function: ...
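Because the stages run in a fixed order and each looks only at the op immediately consuming the current FasterTransformer call, some chains fuse only partially. The toy simulation below makes that ordering visible. Its names are hypothetical, it models each function as one linear chain after the GEMM, and it takes bias_stride as an input rather than deriving it from the bias shape.

```python
from dataclasses import dataclass, replace
from typing import List, Optional, Tuple

GEMM = "fastertransformer.gemm_fp16_int"
ACTS = {"relax.nn.silu", "relax.nn.gelu", "relax.nn.relu"}


@dataclass(frozen=True)
class FTState:
    """Toy summary of the fused FasterTransformer call (hypothetical)."""
    kernel: str = GEMM
    activation: str = "identity"
    binary_op: str = ""  # set only once the residual kernel is formed
    unary_op: str = ""


def toy_transform(epilogue: List[str], bias_stride: int = 0) -> Tuple[FTState, List[str]]:
    """Apply the four stages, in order, to gemm -> epilogue[0] -> epilogue[1] -> ...

    Returns the fused call and the epilogue ops that were left unfused.
    """
    state, rest = FTState(), list(epilogue)

    def head() -> Optional[str]:
        return rest[0] if rest else None

    # 1. fuse_bias: add right after gemm_fp16_int, activation still "identity".
    if head() == "relax.add" and state.kernel == GEMM and state.activation == "identity":
        rest.pop(0)
        state = replace(state, kernel=GEMM + "_bias")
    # 2. fuse_activation: silu/gelu/relu after gemm_fp16_int[_bias].
    if head() in ACTS and state.kernel in (GEMM, GEMM + "_bias") and state.activation == "identity":
        state = replace(state, activation=rest.pop(0).removeprefix("relax.nn."))
    # 3. fuse_residual_binary: add/multiply after gemm_fp16_int_bias, bias_stride == 0.
    if head() in ("relax.add", "relax.multiply") and state.kernel == GEMM + "_bias" and bias_stride == 0:
        binary_op = "plus" if rest.pop(0) == "relax.add" else "multiply"
        state = replace(state, kernel=GEMM + "_bias_residual",
                        binary_op=binary_op, unary_op="identity")
    # 4. fuse_residual_unary: activation after gemm_fp16_int_bias_residual.
    if head() in ACTS and state.kernel == GEMM + "_bias_residual" and state.unary_op == "identity":
        state = replace(state, unary_op=rest.pop(0).removeprefix("relax.nn."))
    return state, rest


print(toy_transform(["relax.add", "relax.nn.silu", "relax.add"]))
print(toy_transform(["relax.nn.silu", "relax.add"]))
```

In this model, gemm_fp16_int -> silu -> add leaves the trailing add unfused: fuse_bias has already run by the time silu is absorbed, and fuse_residual_binary only matches gemm_fp16_int_bias.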

Import

from mlc_llm.compiler_pass.fuse_ft_dequantize_matmul_epilogue import FuseFTDequantizeEpilogue

I/O Contract

Inputs

Name | Type | Required | Description
mod | IRModule | Yes | The TVM IRModule containing Relax functions with FasterTransformer GEMM calls.

Outputs

Name | Type | Description
mod | IRModule | The transformed IRModule with fused FasterTransformer operations.

The following kernel progression illustrates the fusion chain:

Stage | Before | After
fuse_bias | gemm_fp16_int + add (activation still "identity") | gemm_fp16_int_bias
fuse_activation | gemm_fp16_int[_bias] + silu/gelu/relu | gemm_fp16_int[_bias] with the activation parameter set to silu/gelu/relu
fuse_residual_binary | gemm_fp16_int_bias + add/multiply (only when bias_stride is 0) | gemm_fp16_int_bias_residual with binary operation "plus"/"multiply"
fuse_residual_unary | gemm_fp16_int_bias_residual + silu/gelu/relu | gemm_fp16_int_bias_residual with the unary operation set to silu/gelu/relu

Usage Examples

import tvm
from tvm import IRModule
from mlc_llm.compiler_pass.fuse_ft_dequantize_matmul_epilogue import FuseFTDequantizeEpilogue


def fuse_ft_epilogues(mod: IRModule) -> IRModule:
    """Apply the epilogue-fusion pass to an IRModule that already contains
    fastertransformer.gemm_fp16_int calls (produced earlier in the pipeline)."""
    fuse_pass = FuseFTDequantizeEpilogue()  # instantiating the class yields a module pass
    with tvm.transform.PassContext():
        return fuse_pass(mod)


# Typical effect on a chain whose bias broadcasts across rows (bias_stride == 0):
# Before: gemm_fp16_int -> add(bias) -> silu -> add(residual)
# After:  gemm_fp16_int_bias_residual (activation="silu", binary_op="plus", unary_op="identity")
