Implementation:Mlc ai Mlc llm Fuse Add Norm
Overview
FuseAddRMSNorm is a TVM compiler pass that fuses element-wise addition and RMS normalization into a single TIR kernel. In transformer architectures, the pattern rms_norm(add(x1, x2), weight) appears wherever a residual connection is followed by RMS normalization. By fusing the two operations, the pass avoids reading the intermediate sum back from global memory, which reduces memory traffic during inference.
File: python/mlc_llm/compiler_pass/fuse_add_norm.py
Purpose
In a typical transformer block, after attention or feed-forward computation, the residual is added back (x1 + x2) and then normalized (rms_norm). Without fusion, this requires:
- Writing the add result to global memory
- Reading it back for normalization
The fused kernel performs the add in local registers, computes the RMS norm on the local values, and writes both the normalized output and the raw addition result in a single pass.
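As a reference for what the fused operation computes, the following is a minimal pure-Python sketch of the semantics for a single row. It is not the TIR kernel; the function name add_rms_norm and the default eps value are illustrative assumptions.

```python
import math


def add_rms_norm(x1, x2, weight, eps=1e-5):
    """Reference semantics for one row of the fused op (plain Python floats)."""
    # Residual add; the list stands in for values held in registers.
    added = [a + b for a, b in zip(x1, x2)]
    # Mean of squares over the hidden dimension.
    mean_sq = sum(v * v for v in added) / len(added)
    # Scale every element by the reciprocal RMS and by its weight.
    scale = 1.0 / math.sqrt(mean_sq + eps)
    normed = [v * scale * w for v, w in zip(added, weight)]
    # Two results, mirroring the two outputs of the fused kernel.
    return normed, added


normed, added = add_rms_norm([1.0, 2.0], [2.0, 2.0], [1.0, 1.0], eps=0.0)
```

With unit weights and eps set to zero, the normalized row has a mean square of 1, which is a quick sanity check for any implementation of the same semantics.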
Class: FuseAddRMSNorm
@tvm.transform.module_pass(opt_level=0, name="FuseAddRMSNorm")
class FuseAddRMSNorm:
    def __init__(self, target: tvm.target.Target) -> None:
        self.target = target

    def transform_module(self, mod: tvm.IRModule, _ctx: tvm.transform.PassContext) -> tvm.IRModule:
        return _FuseAddRMSNormRewriter(mod.clone(), self.target).transform()
The pass clones the module before modification and delegates to the _FuseAddRMSNormRewriter mutator.
Class: _FuseAddRMSNormRewriter
@mutator
class _FuseAddRMSNormRewriter(PyExprMutator):
    def __init__(self, mod: tvm.IRModule, target: tvm.target.Target):
        super().__init__(mod)
        self.mod = mod
        self.prefill_norm_gv: Optional[tvm.ir.GlobalVar] = None
        self.decode_norm_gv: Optional[tvm.ir.GlobalVar] = None
        self.TX = min(1024, get_max_num_threads_per_block(target))
Key attributes:
- prefill_norm_gv / decode_norm_gv -- Cached global variables for the prefill and decode TIR functions, created lazily on first match
- TX -- Thread block size, computed as min(1024, max_threads_per_block) for the target device
Pattern Matching
The visit_call_ method matches the pattern rms_norm(add(x1, x2), weight):
def visit_call_(self, call: relax.Call) -> relax.Expr:
    call = super().visit_call_(call)
    # Match rms_norm call with float16 or bfloat16 output
    if call.op != tvm.ir.Op.get("relax.nn.rms_norm") or call.struct_info.dtype not in [
        "bfloat16", "float16",
    ]:
        return call
    weight = call.args[1]
    eps = call.attrs.epsilon
    y = self.lookup_binding(call.args[0])
    # Check inner call is add
    if not isinstance(y, relax.Call) or y.op != tvm.ir.Op.get("relax.add"):
        return call
    x1 = y.args[0]
    x2 = y.args[1]
    n, _, h = x1.struct_info.shape
    h = int(h)
    if h % self.TX != 0:
        return call
Match conditions:
- The outer call must be relax.nn.rms_norm
- The output dtype must be float16 or bfloat16
- The first argument must be bound to a relax.add call
- The hidden size h must be evenly divisible by the thread count TX
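The four conditions can be read as one predicate. The sketch below restates them over a hypothetical Node dataclass that stands in for a Relax call; Node, can_fuse, and the string op names are illustrative assumptions and not part of the TVM API.

```python
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical stand-in for a Relax expression."""
    op: str
    dtype: str = "float16"
    args: tuple = ()
    shape: tuple = ()


def can_fuse(call: Node, tx: int) -> bool:
    # Conditions 1 and 2: outer op and output dtype.
    if call.op != "relax.nn.rms_norm" or call.dtype not in ("float16", "bfloat16"):
        return False
    # Condition 3: the normalized value must come from an add.
    inner = call.args[0]
    if not isinstance(inner, Node) or inner.op != "relax.add":
        return False
    # Condition 4: each thread must receive a whole number of elements.
    hidden = inner.args[0].shape[-1]
    return hidden % tx == 0


x = Node("input", shape=("n", 1, 4096))
w = Node("weight", shape=(4096,))
norm = Node("relax.nn.rms_norm", args=(Node("relax.add", args=(x, x)), w))
fusable = can_fuse(norm, tx=1024)
```

When any condition fails, the real pass returns the call unchanged, so the unfused rms_norm stays in the module.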
Prefill vs. Decode Dispatch
The pass generates separate TIR kernels for prefill and decode modes based on the first dimension n of the input shape:
is_prefill = n == 1
func_gv = self.prefill_norm_gv if is_prefill else self.decode_norm_gv
- Prefill mode (n == 1): The batch dimension is 1 and the second dimension is seq_len (variable). Each thread block handles one sequence position.
- Decode mode (n != 1): The first dimension is batch_size (variable) and the sequence dimension is 1. Each thread block handles one batch element.
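The dispatch and the lazy creation of the two kernels can be sketched as a small cache. This is a simplified model: KernelCache, the builds counter, and the returned kernel names are hypothetical, and a symbolic dimension is represented by a plain string rather than a TIR variable.

```python
class KernelCache:
    """Hypothetical stand-in for the rewriter's two lazily created kernel handles."""

    def __init__(self):
        self.prefill = None
        self.decode = None
        self.builds = 0  # counts how often a kernel had to be generated

    def get(self, n):
        # A first dimension equal to 1 selects prefill; anything else selects decode.
        is_prefill = n == 1
        cached = self.prefill if is_prefill else self.decode
        if cached is None:
            self.builds += 1
            cached = "prefill_kernel" if is_prefill else "decode_kernel"
            if is_prefill:
                self.prefill = cached
            else:
                self.decode = cached
        return cached


cache = KernelCache()
first = cache.get(1)              # shape (1, seq_len, h)
second = cache.get("batch_size")  # shape (batch_size, 1, h)
third = cache.get(1)              # reuses the cached prefill kernel
```

Each kernel is generated at most once per mode, no matter how many rms_norm(add(...)) sites are rewritten.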
Rewrite Output
The fused function produces two outputs via relax.call_tir:
tuple_output = self.builder_.emit(
    relax.call_tir(func_gv, [x1, x2, weight], out_sinfo=[x1.struct_info, x2.struct_info])
)
new_o = relax.TupleGetItem(tuple_output, 0)  # normalized output
new_y = self.builder_.emit(relax.TupleGetItem(tuple_output, 1))  # raw add result
self.set_var_remap(call.args[0].vid, new_y)
return new_o
The first output element is the RMS-normalized result. The second output element is the raw addition result (x1 + x2), which is remapped to replace the original add binding. This allows downstream consumers of the addition result to use the fused kernel's output.
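The effect of the remap on downstream consumers can be illustrated with a toy model in which a program is a list of (var, op, args) bindings. The functions rewrite and remove_unused, the op names, and the generated variable names are illustrative assumptions; the real pass works on Relax expressions through PyExprMutator and relies on remove_all_unused for cleanup.

```python
def rewrite(bindings):
    """Replace rms_norm(add(x1, x2), w) with a fused call and remap the add result."""
    defs = {var: (op, args) for var, op, args in bindings}
    remap = {}  # old variable -> replacement variable
    out = []
    for var, op, args in bindings:
        args = tuple(remap.get(a, a) for a in args)
        inner = defs.get(args[0]) if op == "rms_norm" else None
        if inner is not None and inner[0] == "add":
            x1, x2 = (remap.get(a, a) for a in inner[1])
            tup = var + "_tuple"
            out.append((tup, "fused_add_rms_norm", (x1, x2, args[1])))
            out.append((var, "get_item_0", (tup,)))    # normalized output
            new_y = args[0] + "_fused"
            out.append((new_y, "get_item_1", (tup,)))  # raw add result
            remap[args[0]] = new_y                     # later uses read the fused sum
        else:
            out.append((var, op, args))
    return out


def remove_unused(bindings, outputs):
    """Drop bindings whose variables are never needed for the outputs."""
    live, kept = set(outputs), []
    for var, op, args in reversed(bindings):
        if var in live:
            kept.append((var, op, args))
            live.update(args)
    return kept[::-1]


program = [
    ("y", "add", ("x1", "x2")),
    ("o", "rms_norm", ("y", "w")),
    ("z", "matmul", ("o", "w2")),
    ("r", "add", ("y", "z")),  # a later consumer of the residual sum
]
result = remove_unused(rewrite(program), outputs=["r"])
```

After the rewrite, the later consumer reads y_fused instead of y, so the original add has no remaining users and is removed as dead code.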
TIR Kernel: Decode
The _get_add_rms_norm_decode function generates a scheduled TIR PrimFunc for the decode case:
def _get_add_rms_norm_decode(hidden_size: int, eps: float, TX: int, in_dtype: str):
    inv_hidden_size = T.float32(1.0 / float(hidden_size))
    add_local_size = hidden_size // TX
    # ...
Buffer layout:
- Input A, B: (batch_size, 1, hidden_size)
- Weight C: (hidden_size,)
- Output O: (batch_size, 1, hidden_size)
- Add result: (batch_size, 1, hidden_size)
Kernel structure:
- Element-wise add -- Each thread computes add_local_size additions, storing results in local registers
- Write back -- The addition results are written to global memory
- Partial sum reduction -- Each thread computes the sum of squares of its local elements
- Cross-thread reduction -- Thread-level partial sums are reduced via shared memory
- Normalize and output -- The final RMS normalization is computed as rsqrt(mean_sq + eps) * add_value * weight
The kernel uses blockIdx.x for batch elements and threadIdx.x (up to TX threads) for hidden dimension parallelism. Unroll annotations (pragma_auto_unroll_max_step: 256) enable aggressive loop unrolling.
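The per-block work can be simulated sequentially in plain Python to make the five steps concrete. This is a sketch, not the scheduled TIR: the strided mapping of hidden indices to threads (i * tx + t) is an assumption the text above does not specify, and a plain sum stands in for the shared-memory reduction. The function name decode_row is hypothetical.

```python
import math


def decode_row(a, b, weight, eps, tx):
    """Simulate one thread block processing one row of length hidden_size."""
    hidden = len(a)
    assert hidden % tx == 0, "hidden size must be divisible by the thread count"
    local_size = hidden // tx                        # add_local_size
    local = [[0.0] * local_size for _ in range(tx)]  # per-thread registers
    partial = [0.0] * tx                             # per-thread sums of squares
    add_out = [0.0] * hidden
    for t in range(tx):                              # every "thread" runs this body
        for i in range(local_size):
            h = i * tx + t                           # assumed strided layout
            local[t][i] = a[h] + b[h]                # element-wise add
            add_out[h] = local[t][i]                 # write back
            partial[t] += local[t][i] * local[t][i]  # partial sum reduction
    total = sum(partial)                             # cross-thread reduction
    scale = 1.0 / math.sqrt(total / hidden + eps)    # rsqrt(mean_sq + eps)
    out = [0.0] * hidden
    for t in range(tx):
        for i in range(local_size):
            h = i * tx + t
            out[h] = scale * local[t][i] * weight[h]  # normalize and output
    return out, add_out


row_out, row_add = decode_row(
    [float(i) for i in range(8)], [1.0] * 8, [0.5] * 8, eps=1e-5, tx=4
)
```

Because each thread keeps its add results in local storage, the normalization step reuses them without reading the sum back from global memory.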
TIR Kernel: Prefill
The _get_add_rms_norm_prefill function generates an analogous kernel for the prefill case:
Buffer layout:
- Input A, B: (1, seq_len, hidden_size)
- Weight C: (hidden_size,)
- Output O: (1, seq_len, hidden_size)
- Add result: (1, seq_len, hidden_size)
The structure is identical to the decode kernel but iterates over seq_len positions instead of batch elements along blockIdx.x.
Supported Data Types
Both kernels only support float16 and bfloat16 input types. Internal computation for the reduction (sum of squares) is performed in float32 to maintain numerical stability. The final result is cast back to the input dtype.
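A small experiment shows why a wider accumulator matters for the sum of squares. The sketch below emulates half- and single-precision rounding with Python's struct module; it models sequential accumulation only and is not the reduction order a GPU would use. The helper names are illustrative.

```python
import struct


def to_f16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_f32(x):
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def sum_squares(values, rnd):
    """Accumulate the sum of squares, rounding after every operation."""
    acc = 0.0
    for v in values:
        acc = rnd(acc + rnd(v * v))
    return acc


values = [1.0] * 4096                      # a hidden dimension of 4096 ones
half_sum = sum_squares(values, to_f16)     # accumulator emulated in float16
single_sum = sum_squares(values, to_f32)   # accumulator emulated in float32
```

In this emulation the half-precision accumulator stops growing at 2048, because adding 1 to 2048 is no longer representable in float16, while the single-precision accumulator reaches the exact value 4096. The resulting mean square would be off by a factor of two, which would distort the normalization scale.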
Dependencies
- tvm -- Core TVM framework
- tvm.relax.analysis.remove_all_unused -- Dead code elimination for unused bindings
- tvm.relax.expr_functor -- PyExprMutator and @mutator decorator
- tvm.script.tir -- TIR script DSL for kernel definitions
- mlc_llm.support.max_thread_check.get_max_num_threads_per_block -- Target-aware thread limit query