Implementation:Mlc ai Mlc llm Fuse Add Norm

From Leeroopedia
Revision as of 15:49, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Mlc_ai_Mlc_llm_Fuse_Add_Norm.md)

Overview

FuseAddRMSNorm is a TVM compiler pass that fuses element-wise addition and RMS normalization into a single TIR kernel. In transformer architectures, the pattern rms_norm(add(x1, x2), weight) appears wherever a residual connection is followed by RMS normalization. Fusing the two operations lets the normalization consume the sum directly from registers instead of re-reading it from global memory, which improves inference performance.

File: python/mlc_llm/compiler_pass/fuse_add_norm.py

Purpose

In a typical transformer block, after attention or feed-forward computation, the residual is added back (x1 + x2) and then normalized (rms_norm). Without fusion, this requires:

  1. Writing the add result to global memory
  2. Reading it back for normalization

The fused kernel performs the add in local registers, computes the RMS norm on the local values, and writes both the normalized output and the raw addition result in a single pass.
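The semantics of the fused operation can be sketched in plain Python for a single hidden vector. This is only a reference model of what the kernel computes, not code from the pass; the function name add_rms_norm_reference and the default eps are assumptions made for illustration.

```python
import math


def add_rms_norm_reference(x1, x2, weight, eps=1e-5):
    """Return (normalized, added) for one hidden vector given as lists of floats.

    Hypothetical reference model: it mirrors the two outputs of the fused
    kernel but says nothing about how the kernel is scheduled.
    """
    # Residual add, kept around so it can be returned as the second output.
    added = [a + b for a, b in zip(x1, x2)]
    # Mean of squares over the hidden dimension.
    mean_sq = sum(v * v for v in added) / len(added)
    # RMS normalization followed by the element-wise weight.
    scale = 1.0 / math.sqrt(mean_sq + eps)
    normalized = [v * scale * w for v, w in zip(added, weight)]
    return normalized, added


normalized, added = add_rms_norm_reference([1.0, 2.0], [2.0, 2.0], [1.0, 1.0], eps=0.0)
```

Both values come out of a single call, which is the property the pass relies on: later consumers of x1 + x2 can read the second output rather than recomputing the sum.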

Class: FuseAddRMSNorm

import tvm


@tvm.transform.module_pass(opt_level=0, name="FuseAddRMSNorm")
class FuseAddRMSNorm:
    def __init__(self, target: tvm.target.Target) -> None:
        self.target = target

    def transform_module(self, mod: tvm.IRModule, _ctx: tvm.transform.PassContext) -> tvm.IRModule:
        return _FuseAddRMSNormRewriter(mod.clone(), self.target).transform()

The pass clones the module before modification and delegates to the _FuseAddRMSNormRewriter mutator.

Class: _FuseAddRMSNormRewriter

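from typing import Optional

import tvm
from tvm.relax.expr_functor import PyExprMutator, mutator

from mlc_llm.support.max_thread_check import get_max_num_threads_per_block


@mutator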
class _FuseAddRMSNormRewriter(PyExprMutator):
    def __init__(self, mod: tvm.IRModule, target: tvm.target.Target):
        super().__init__(mod)
        self.mod = mod
        self.prefill_norm_gv: Optional[tvm.ir.GlobalVar] = None
        self.decode_norm_gv: Optional[tvm.ir.GlobalVar] = None
        self.TX = min(1024, get_max_num_threads_per_block(target))

Key attributes:

  • prefill_norm_gv / decode_norm_gv -- Cached global variables for the prefill and decode TIR functions, created lazily on first match
  • TX -- Thread block size, computed as min(1024, max_threads_per_block) for the target device

Pattern Matching

The visit_call_ method matches the pattern rms_norm(add(x1, x2), weight):

def visit_call_(self, call: relax.Call) -> relax.Expr:
    call = super().visit_call_(call)

    # Match rms_norm call with float16 or bfloat16 output
    if call.op != tvm.ir.Op.get("relax.nn.rms_norm") or call.struct_info.dtype not in [
        "bfloat16", "float16",
    ]:
        return call

    weight = call.args[1]
    eps = call.attrs.epsilon
    y = self.lookup_binding(call.args[0])

    # Check inner call is add
    if not isinstance(y, relax.Call) or y.op != tvm.ir.Op.get("relax.add"):
        return call

    x1 = y.args[0]
    x2 = y.args[1]
    n, _, h = x1.struct_info.shape
    h = int(h)
    if h % self.TX != 0:
        return call

Match conditions:

  • The outer call must be relax.nn.rms_norm
  • The output dtype must be float16 or bfloat16
  • The first argument must be bound to a relax.add call
  • The hidden size h must be evenly divisible by the thread count TX
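The four conditions above can be condensed into a single predicate. The sketch below uses plain strings and integers in place of Relax operators and struct info, so can_fuse and its parameters are hypothetical stand-ins rather than part of the pass.

```python
def can_fuse(outer_op, out_dtype, inner_op, hidden_size, tx):
    """Return True when all four match conditions hold.

    Hypothetical helper: operators are represented by their names and the
    shape by a plain integer hidden size.
    """
    return (
        outer_op == "relax.nn.rms_norm"             # outer call is rms_norm
        and out_dtype in ("float16", "bfloat16")    # half-precision output only
        and inner_op == "relax.add"                 # first argument bound to add
        and hidden_size % tx == 0                   # TX divides the hidden size
    )


fusable = can_fuse("relax.nn.rms_norm", "float16", "relax.add", 4096, 1024)
not_divisible = can_fuse("relax.nn.rms_norm", "float16", "relax.add", 4000, 1024)
wrong_dtype = can_fuse("relax.nn.rms_norm", "float32", "relax.add", 4096, 1024)
```

When any condition fails, the pass returns the original call unchanged, so unsupported shapes and dtypes fall back to the unfused operators.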

Prefill vs. Decode Dispatch

The pass generates separate TIR kernels for prefill and decode modes based on the first dimension n of the input shape:

is_prefill = n == 1
func_gv = self.prefill_norm_gv if is_prefill else self.decode_norm_gv

  • Prefill mode (n == 1): The batch dimension is 1 and the second dimension is seq_len (variable). Each thread block handles one sequence position.
  • Decode mode (n != 1): The first dimension is batch_size (variable) and the sequence dimension is 1. Each thread block handles one batch element.
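Combined with the lazily created prefill_norm_gv and decode_norm_gv attributes, the dispatch amounts to a two-slot cache keyed on whether n equals 1. The sketch below models that idea with a hypothetical KernelCache class; the kernel names it stores are placeholders, and a string stands in for a symbolic batch dimension.

```python
class KernelCache:
    """Hypothetical model of the two lazily created kernel handles."""

    def __init__(self):
        self.prefill = None
        self.decode = None
        self.builds = 0  # counts how many kernels were actually created

    def get(self, n):
        # n is the first shape dimension: the literal 1 or anything else.
        is_prefill = n == 1
        cached = self.prefill if is_prefill else self.decode
        if cached is None:
            # First match for this mode: create the kernel and remember it.
            self.builds += 1
            cached = "add_rms_norm_prefill" if is_prefill else "add_rms_norm_decode"
            if is_prefill:
                self.prefill = cached
            else:
                self.decode = cached
        return cached


cache = KernelCache()
first = cache.get(1)              # builds the prefill kernel
second = cache.get(1)             # reuses it
third = cache.get("batch_size")   # symbolic first dimension, so decode
```

Every later match of the same mode reuses the cached handle, so at most two fused kernels are created in this model, however many add/rms_norm pairs are rewritten.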

Rewrite Output

The fused function produces two outputs via relax.call_tir:

tuple_output = self.builder_.emit(
    relax.call_tir(func_gv, [x1, x2, weight], out_sinfo=[x1.struct_info, x2.struct_info])
)
new_o = relax.TupleGetItem(tuple_output, 0)       # normalized output
new_y = self.builder_.emit(relax.TupleGetItem(tuple_output, 1))  # raw add result
self.set_var_remap(call.args[0].vid, new_y)
return new_o

The first output element is the RMS-normalized result. The second output element is the raw addition result (x1 + x2), which is remapped to replace the original add binding. This allows downstream consumers of the addition result to use the fused kernel's output.
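The effect of the remap can be illustrated on a toy binding list. The sketch below is not the Relax API: bindings are (variable, (op, args...)) tuples, and fuse_bindings, fused_add_rms_norm, and get_item are hypothetical names. It replaces the rms_norm binding with a fused call, redirects later uses of the add result to the second tuple element, and then drops the unused add binding.

```python
def fuse_bindings(bindings, result):
    """Rewrite a toy list of (var, (op, *args)) bindings; `result` is the output var."""
    defs = {}    # var -> original (op, *args)
    remap = {}   # old add var -> var holding the fused kernel's add output
    out = []
    for var, (op, *args) in bindings:
        defs[var] = (op, *args)
        if op == "rms_norm" and defs.get(args[0], ("",))[0] == "add":
            k = len(remap)
            _, x1, x2 = defs[args[0]]
            fused, add_var = f"fused{k}", f"add{k}"
            out.append((fused, ("fused_add_rms_norm", x1, x2, args[1])))
            out.append((var, ("get_item", fused, 0)))       # normalized output
            out.append((add_var, ("get_item", fused, 1)))   # raw add result
            remap[args[0]] = add_var
            continue
        # Later consumers of the old add variable now read the fused output.
        out.append((var, (op, *[remap.get(a, a) for a in args])))
    # Single sweep of dead-code elimination: drop bindings nobody reads.
    used = {a for _, (_, *args) in out for a in args}
    return [(v, e) for v, e in out if v in used or v == result]


rewritten = fuse_bindings(
    [
        ("y", ("add", "x1", "x2")),
        ("o", ("rms_norm", "y", "w")),
        ("z", ("mul", "y", "o")),
    ],
    result="z",
)
```

In the rewritten list the original add binding is gone and z reads add0, which is the role that set_var_remap together with unused-binding removal plays in the pass.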

TIR Kernel: Decode

The _get_add_rms_norm_decode function generates a scheduled TIR PrimFunc for the decode case:

def _get_add_rms_norm_decode(hidden_size: int, eps: float, TX: int, in_dtype: str):
    inv_hidden_size = T.float32(1.0 / float(hidden_size))
    add_local_size = hidden_size // TX
    # ...

Buffer layout:

  • Input A, B: (batch_size, 1, hidden_size)
  • Weight C: (hidden_size,)
  • Output O: (batch_size, 1, hidden_size)
  • Add result: (batch_size, 1, hidden_size)

Kernel structure:

  1. Element-wise add -- Each thread computes add_local_size additions, storing results in local registers
  2. Write back -- The addition results are written to global memory
  3. Partial sum reduction -- Each thread computes the sum of squares of its local elements
  4. Cross-thread reduction -- Thread-level partial sums are reduced via shared memory
  5. Normalize and output -- The final RMS normalization is computed as rsqrt(mean_sq + eps) * add_value * weight

The kernel uses blockIdx.x for batch elements and threadIdx.x (up to TX threads) for hidden dimension parallelism. Unroll annotations (pragma_auto_unroll_max_step: 256) enable aggressive loop unrolling.
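The five steps can be simulated sequentially in Python for one thread block. This is a sketch under stated assumptions: the strided element-to-thread mapping (index i * TX + t) is assumed for illustration and may differ from the real kernel, the shared-memory reduction is modeled as a plain sum, and all arithmetic uses Python floats rather than float16 or float32.

```python
import math


def simulate_fused_block(a, b, weight, eps, tx):
    """Simulate one thread block of the fused kernel on a single hidden vector."""
    hidden_size = len(a)
    assert hidden_size % tx == 0, "the pass only fuses when TX divides hidden_size"
    add_local_size = hidden_size // tx

    add_out = [0.0] * hidden_size
    out = [0.0] * hidden_size
    # One small "register file" per simulated thread.
    add_local = [[0.0] * add_local_size for _ in range(tx)]

    for t in range(tx):
        for i in range(add_local_size):
            idx = i * tx + t                    # assumed strided mapping
            add_local[t][i] = a[idx] + b[idx]   # step 1: add into registers
            add_out[idx] = add_local[t][i]      # step 2: write the add result back

    # Step 3: per-thread partial sums of squares.
    partial = [sum(v * v for v in add_local[t]) for t in range(tx)]
    # Step 4: cross-thread reduction (shared memory in the real kernel).
    total = sum(partial)
    # Step 5: normalize from the values still held in registers.
    scale = 1.0 / math.sqrt(total / hidden_size + eps)
    for t in range(tx):
        for i in range(add_local_size):
            idx = i * tx + t
            out[idx] = scale * add_local[t][i] * weight[idx]
    return out, add_out


sim_out, sim_add = simulate_fused_block(
    [float(i) for i in range(8)], [1.0] * 8, [2.0] * 8, eps=1e-5, tx=4
)
```

Step 5 reads add_local rather than add_out, which is the point of the fusion: the sum never has to be fetched back from global memory.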

TIR Kernel: Prefill

The _get_add_rms_norm_prefill function generates an analogous kernel for the prefill case.

Buffer layout:

  • Input A, B: (1, seq_len, hidden_size)
  • Weight C: (hidden_size,)
  • Output O: (1, seq_len, hidden_size)
  • Add result: (1, seq_len, hidden_size)

The structure is identical to the decode kernel but iterates over seq_len positions instead of batch elements along blockIdx.x.
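The difference between the two layouts comes down to which axis blockIdx.x walks. A minimal sketch, assuming concrete integer shapes and a hypothetical launch_plan helper that is not part of the pass:

```python
def launch_plan(shape):
    """Return (mode, num_blocks, rows) for an input of shape (n, s, hidden_size).

    Hypothetical helper: `rows` lists the (batch, position) pair that each
    block index would own.
    """
    n, s, _hidden = shape
    if n == 1:
        # Prefill: one block per sequence position, batch index fixed at 0.
        return "prefill", s, [(0, bx) for bx in range(s)]
    # Decode: one block per batch element, sequence index fixed at 0.
    return "decode", n, [(bx, 0) for bx in range(n)]


prefill_plan = launch_plan((1, 5, 4096))
decode_plan = launch_plan((3, 1, 4096))
```

In both cases each block owns exactly one hidden vector of length hidden_size, so the per-block work is the same and only the indexing changes.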

Supported Data Types

Both kernels only support float16 and bfloat16 input types. Internal computation for the reduction (sum of squares) is performed in float32 to maintain numerical stability. The final result is cast back to the input dtype.
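The reason for accumulating in float32 can be shown with a small experiment that emulates rounding using the standard struct module ("e" is IEEE half precision, "f" is single precision). The helper names are hypothetical, and the experiment only models the rounding behavior, not the kernel itself.

```python
import struct


def to_f16(x):
    """Round a Python float to the nearest float16 value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_f32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sum_squares(values, round_fn):
    """Accumulate the sum of squares, rounding after every operation."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + round_fn(v * v))
    return acc


values = [1.0] * 4096
f16_sum = sum_squares(values, to_f16)   # stalls once the accumulator reaches 2048
f32_sum = sum_squares(values, to_f32)   # exact: 4096.0
```

Above 2048, float16 can no longer represent consecutive integers, so adding 1.0 leaves the accumulator unchanged and the mean of squares would come out as 0.5 instead of 1.0. With hidden sizes in the thousands, a float16 accumulator would distort the normalization scale in this way.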

Dependencies

  • tvm -- Core TVM framework
  • tvm.relax.analysis.remove_all_unused -- Dead code elimination for unused bindings
  • tvm.relax.expr_functor -- PyExprMutator and @mutator decorator
  • tvm.script.tir -- TIR script DSL for kernel definitions
  • mlc_llm.support.max_thread_check.get_max_num_threads_per_block -- Target-aware thread limit query
