Implementation:Mlc ai Mlc llm Fuse Transpose Matmul

From Leeroopedia


Overview

FuseTransposeMatmul is a TVM compiler pass that fuses a transpose followed by a matrix multiplication into a single "NT_matmul" kernel (A used as-is, B used transposed). Instead of materializing the transposed tensor in memory and then performing the matmul, the pass generates a fused tensor expression (TE) compute that reads B in its original (non-transposed) layout while computing the equivalent of matmul(A, transpose(B)).

File: python/mlc_llm/compiler_pass/fuse_transpose_matmul.py
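The equivalence the pass relies on can be checked with a plain-Python sketch. The helpers `transpose`, `matmul`, and `nt_matmul` are hypothetical illustrations, not MLC LLM or TVM functions. Reading B in its stored [N, K] layout, with the reduction index last, gives the same result as first materializing B's transpose.

```python
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    """Materialize the transpose, as a standalone permute_dims would."""
    return [list(row) for row in zip(*m)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matmul: a is [M, K], b is [K, N]."""
    k_dim = len(b)
    n_dim = len(b[0])
    return [
        [sum(row[k] * b[k][n] for k in range(k_dim)) for n in range(n_dim)]
        for row in a
    ]


def nt_matmul(a: Matrix, b: Matrix) -> Matrix:
    """A @ B.T without a transposed copy: a is [M, K], b is stored as [N, K].

    The reduction index k is the LAST index into b, and the output column n
    selects a row of b, so b is read in its original layout.
    """
    return [
        [sum(a_row[k] * b_row[k] for k in range(len(a_row))) for b_row in b]
        for a_row in a
    ]


x = [[1, 2, 3], [4, 5, 6]]                        # activations, [M=2, K=3]
w = [[1, 0, 1], [0, 1, 0], [2, 2, 2], [1, 1, 1]]  # weight, [N=4, K=3]

print(nt_matmul(x, w))  # [[4, 2, 12, 6], [10, 5, 30, 15]]
```

Both paths perform the same multiply-adds. The fused form simply skips the intermediate [K, N] buffer.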

Architecture

The pass operates in two phases:

  1. Pattern-based fusion -- Uses relax.transform.FuseOpsByPattern to identify and group permute_dims + matmul into composite functions
  2. Code generation -- Replaces the composite functions with fused TE compute definitions via the _TransposeMatmulFuser mutator
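The two phases can be pictured on a toy, hypothetical IR. The `Binding` class and the `fuse_pattern`/`codegen` functions below are illustrative stand-ins, not TVM APIs. Phase 1 folds the transpose into a composite call, and phase 2 lowers that composite to a single NT_matmul call. As simplifying assumptions, the sketch ignores the permutation axes and only folds a transpose that has no other users.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Binding:
    var: str               # variable defined by this statement
    op: str                # operator, e.g. "permute_dims" or "matmul"
    args: Tuple[str, ...]  # input variable names


def fuse_pattern(bindings: List[Binding]) -> List[Binding]:
    """Phase 1 (toy): fold `t = permute_dims(w); o = matmul(x, t)` into one
    composite call with operands (w, x), the order the inputs are first used."""
    producers: Dict[str, Binding] = {b.var: b for b in bindings}
    uses: Dict[str, int] = {}
    for b in bindings:
        for name in b.args:
            uses[name] = uses.get(name, 0) + 1
    folded = set()
    out: List[Binding] = []
    for b in bindings:
        if b.op == "matmul":
            x, t = b.args
            src = producers.get(t)
            # Only fold a transpose that feeds this matmul and nothing else.
            if src is not None and src.op == "permute_dims" and uses[t] == 1:
                folded.add(t)
                out.append(Binding(b.var, "composite:transpose_matmul_fuse", (src.args[0], x)))
                continue
        out.append(b)
    # Drop the transposes that now live inside a composite.
    return [b for b in out if b.var not in folded]


def codegen(bindings: List[Binding]) -> List[Binding]:
    """Phase 2 (toy): lower each composite to an NT_matmul call, swapping the
    operands so that x is A (left operand) and the untransposed w is B."""
    out: List[Binding] = []
    for b in bindings:
        if b.op == "composite:transpose_matmul_fuse":
            w, x = b.args
            out.append(Binding(b.var, "call_tir:NT_matmul", (x, w)))
        else:
            out.append(b)
    return out


prog = [
    Binding("wT", "permute_dims", ("w",)),
    Binding("o", "matmul", ("x", "wT")),
]
print(codegen(fuse_pattern(prog)))
# [Binding(var='o', op='call_tir:NT_matmul', args=('x', 'w'))]
```

Splitting matching from lowering keeps the pattern reusable. Phase 2 only needs to recognize the composite's name, not re-derive the match.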

Class: FuseTransposeMatmul

import tvm
from tvm import IRModule, relax

# _pattern and _TransposeMatmulFuser are defined further down this page
@tvm.transform.module_pass(opt_level=0, name="FuseTransposeMatmul")
class FuseTransposeMatmul:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # Phase 1: group permute_dims + matmul into composite functions
        mod = relax.transform.FuseOpsByPattern(
            [("transpose_matmul_fuse", *_pattern())]
        )(mod)
        # Phase 2: rewrite calls to those composites into NT_matmul TE computes
        transpose_matmul_codegen = _TransposeMatmulFuser(mod)
        for g_var, func in mod.functions_items():
            if isinstance(func, relax.Function):
                func = transpose_matmul_codegen.visit_expr(func)
                transpose_matmul_codegen.builder_.update_func(g_var, func)
        return transpose_matmul_codegen.builder_.get()

Pattern Definition

The _pattern function defines the DPL (Dataflow Pattern Language) pattern for matching:

from tvm import relax
from tvm.relax.dpl.pattern import is_op, wildcard

def _pattern():
    w = wildcard()                        # weight, stored untransposed
    x = wildcard()                        # activation (left operand)
    wT = is_op("relax.permute_dims")(w)
    o = is_op("relax.matmul")(x, wT)
    annotations = {"o": o, "w": w, "x": x, "wT": wT}

    def _check(context: relax.transform.PatternCheckContext) -> bool:
        transpose_call = context.annotated_expr["wT"]
        ndim = transpose_call.args[0].struct_info.ndim
        if ndim == -1:                    # rank unknown at compile time
            return False
        if ndim == 2 and transpose_call.attrs.axes is None:
            return True                   # default axes transpose a 2-D tensor
        # Otherwise accept only a swap of the last two axes
        axes = list(range(ndim))
        axes[-1], axes[-2] = axes[-2], axes[-1]
        return list(transpose_call.attrs.axes) == axes

    return o, annotations, _check

Pattern: matmul(x, permute_dims(w))

Check conditions:

  • The number of dimensions must be known (not -1)
  • For 2D tensors: axes=None (the permute_dims default, which reverses all axes) is exactly a transpose, so it matches
  • For higher-dimensional tensors: the axes must swap only the last two dimensions
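These conditions can be restated on plain integers as a sketch. `swaps_last_two_axes` is a hypothetical name. It takes the rank and the axes list instead of a Relax PatternCheckContext, and it assumes axes=None should match only for 2-D tensors.

```python
from typing import Optional, Sequence


def swaps_last_two_axes(ndim: int, axes: Optional[Sequence[int]]) -> bool:
    """Plain-Python restatement of the check conditions listed above."""
    if ndim == -1:        # rank unknown at compile time
        return False
    if axes is None:      # permute_dims default: reverse all axes,
        return ndim == 2  # which is a plain transpose only in 2-D
    if ndim < 2:          # nothing to swap
        return False
    expected = list(range(ndim))
    expected[-1], expected[-2] = expected[-2], expected[-1]
    return list(axes) == expected


print(swaps_last_two_axes(3, [0, 2, 1]))  # True
print(swaps_last_two_axes(3, [1, 0, 2]))  # False
```

A permutation such as [1, 0, 2] also moves a batch axis, so it cannot be absorbed by simply reordering the reduction index. This is why only the last-two swap is accepted.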

Class: _TransposeMatmulFuser

The mutator replaces composite function calls with fused TE computations:

from tvm import relax
from tvm.relax.expr_functor import PyExprMutator, mutator

@mutator
class _TransposeMatmulFuser(PyExprMutator):
    def visit_call_(self, call: relax.Call) -> relax.Expr:
        # ... out_dtype and the inner te_transposed_matmul(a, b) are defined here ...
        if isinstance(call.op, relax.GlobalVar):
            function = self.builder_.get()[call.op]
            if (
                "Composite" in function.attrs
                and function.attrs["Composite"] == "transpose_matmul_fuse"
            ):
                out_dtype = function.ret_struct_info.dtype
                return self.builder_.call_te(
                    te_transposed_matmul,
                    call.args[1],   # x: becomes A (left operand)
                    call.args[0],   # w: becomes B, read untransposed
                    primfunc_name_hint="NT_matmul",
                )
        return super().visit_call_(call)

When a call targets a composite function with the attribute "Composite" == "transpose_matmul_fuse", it is replaced with a call_te invocation of the te_transposed_matmul compute function.

TE Compute: te_transposed_matmul

The inner function te_transposed_matmul generates a tensor expression that computes A @ B.T without materializing the transpose. It handles several edge cases:

1D input handling:

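# a, b: the te.Tensor inputs of te_transposed_matmul
a_shape = list(a.shape)
b_shape = list(b.shape)
a_prepended = False
b_appended = False
if len(a_shape) == 1:
    a_prepended = True      # treat a 1-D A as a row vector [1, K]
    a_shape.insert(0, 1)
if len(b_shape) == 1:
    b_appended = True       # give a 1-D B a trailing unit dimension
    b_shape.append(1)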

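Broadcasting: The compute handles operands of different rank by computing an offset, the difference between the two ranks. The first offset batch indices of the output are used only by the higher-rank tensor. For each remaining batch dimension, a dimension of size 1 is broadcast by indexing it at 0; otherwise it must match the other operand's dimension, and both tensors use the output index.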
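The batch-index mapping can be sketched in plain Python under simplifying assumptions: static integer shapes (the real compute works on possibly symbolic TIR shapes), operands of rank at least 2, and a hypothetical helper name, `map_batch_index`. Given one output batch index, it returns the batch indices used to read A and B.

```python
from typing import List, Sequence, Tuple


def map_batch_index(
    a_batch: Sequence[int], b_batch: Sequence[int], out_idx: Sequence[int]
) -> Tuple[List[int], List[int]]:
    """Map one output batch index to batch indices into A and B.

    a_batch / b_batch are the operands' batch dims (matrix dims excluded);
    out_idx has one entry per output batch dim.
    """
    a_larger = len(a_batch) > len(b_batch)
    offset = abs(len(a_batch) - len(b_batch))
    a_idx: List[int] = []
    b_idx: List[int] = []
    # Leading dims that only the higher-rank operand has.
    for i in range(offset):
        (a_idx if a_larger else b_idx).append(out_idx[i])
    # Dims both operands have: size 1 broadcasts (index 0), otherwise must match.
    for i in range(offset, len(out_idx)):
        a_dim = a_batch[i if a_larger else i - offset]
        b_dim = b_batch[i - offset if a_larger else i]
        if a_dim != b_dim and 1 not in (a_dim, b_dim):
            raise ValueError(f"incompatible batch dims {a_dim} and {b_dim}")
        a_idx.append(0 if a_dim == 1 else out_idx[i])
        b_idx.append(0 if b_dim == 1 else out_idx[i])
    return a_idx, b_idx


# A has batch dims [4, 1], B has [5]; the output batch shape is [4, 5].
print(map_batch_index([4, 1], [5], [3, 4]))  # ([3, 0], [4])
```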

Reduction:

k = te.reduce_axis((0, a_shape[-1]), name="k")

def multiply_compute(idx_reduce):
    # ... index computation with broadcasting ...
    # (dtype below is the captured out_dtype; see "Output dtype handling")
    # Key: b_indices uses idx_reduce as the LAST index (not second-to-last)
    # This effectively reads B in its original layout, achieving the transpose
    if dtype != "":
        return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype)
    return a(*a_indices) * b(*b_indices)

return te.sum(multiply_compute(k), axis=k)

The key is the index computation. For tensor B, the reduction axis k is placed as the last index, matching B's original non-transposed layout, while the output's last spatial index (the column n) addresses B's second-to-last dimension. B is therefore read in its stored order while the result equals A @ B.T.
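Written out per element, let β be an output batch index, β_A and β_B the (possibly broadcast) batch indices it maps to in A and B, and K = a_shape[-1] the reduction length. The compute then evaluates:

```latex
C[\beta, m, n] = \sum_{k=0}^{K-1} A[\beta_A, m, k]\, B[\beta_B, n, k]
               = \sum_{k=0}^{K-1} A[\beta_A, m, k]\, B^{\top}[\beta_B, k, n],
\qquad B^{\top}[\beta_B, k, n] := B[\beta_B, n, k]
```

The second form is the ordinary matmul against B with its last two axes swapped, which is what permute_dims produced before fusion.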

Output dtype handling: The out_dtype variable is captured from the composite function's return type via a nonlocal closure. When non-empty, input elements are cast to this dtype before multiplication.

Output

The generated TIR function is named "NT_matmul", following the BLAS-style naming in which A is used as-is (N) and B is used transposed (T). B itself is still read in its original storage order.

Transformation Summary

Before:  matmul(x, permute_dims(w))
After:   call_tir("NT_matmul", [x, w])

The composite function's parameters are ordered (w, x). The mutator therefore passes call.args[1] (x) first as A and call.args[0] (the weight w, untransposed) second as B.

Dependencies

  • tvm -- Core TVM framework
  • tvm.relax.dpl.pattern -- Dataflow pattern language (is_op, wildcard)
  • tvm.relax.expr_functor -- PyExprMutator and @mutator decorator
  • tvm.te -- Tensor expression for defining the fused compute
