Implementation:Mlc ai Mlc llm Fuse Transpose Matmul
Overview
FuseTransposeMatmul is a TVM compiler pass that fuses transpose and matrix multiplication operations into a single "NT_matmul" (non-transposed A times transposed-storage B) kernel. Instead of materializing the transposed tensor in memory and then performing the matmul, this pass generates a fused tensor expression (TE) compute that reads B in its original (non-transposed) layout while computing the equivalent of matmul(A, transpose(B)).
File: python/mlc_llm/compiler_pass/fuse_transpose_matmul.py
Architecture
The pass operates in two phases:
- Pattern-based fusion -- Uses relax.transform.FuseOpsByPattern to identify and group permute_dims + matmul into composite functions
- Code generation -- Replaces the composite functions with fused TE compute definitions via the _TransposeMatmulFuser mutator
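The two phases can be illustrated on a toy expression tree. This is only a sketch: the tuple-based IR and the helper names group_transpose_matmul and lower_composites are hypothetical stand-ins, not TVM or Relax APIs, and the operand order of the resulting node is illustrative.

```python
# Toy expression IR: nested tuples such as ("matmul", x, ("permute_dims", w)).
# Hypothetical stand-in for Relax, used only to show the two phases.

def group_transpose_matmul(expr):
    """Phase 1: wrap every matmul(x, permute_dims(w)) in a composite node."""
    if not isinstance(expr, tuple):
        return expr  # leaf: a variable name
    op, *args = expr
    args = [group_transpose_matmul(a) for a in args]
    if (
        op == "matmul"
        and len(args) == 2
        and isinstance(args[1], tuple)
        and args[1][0] == "permute_dims"
    ):
        x, w = args[0], args[1][1]
        return ("composite", "transpose_matmul_fuse", x, w)
    return (op, *args)


def lower_composites(expr):
    """Phase 2: replace each composite node with a call to the fused kernel."""
    if not isinstance(expr, tuple):
        return expr
    op, *args = expr
    args = [lower_composites(a) for a in args]
    if op == "composite" and args[0] == "transpose_matmul_fuse":
        return ("NT_matmul", args[1], args[2])
    return (op, *args)


program = ("add", ("matmul", "x", ("permute_dims", "w")), "bias")
fused = lower_composites(group_transpose_matmul(program))
print(fused)  # prints ('add', ('NT_matmul', 'x', 'w'), 'bias')
```

Keeping matching and lowering separate means the matcher only has to tag candidates, while the lowering step decides what each tagged group becomes.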
Class: FuseTransposeMatmul
@tvm.transform.module_pass(opt_level=0, name="FuseTransposeMatmul")
class FuseTransposeMatmul:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        mod = relax.transform.FuseOpsByPattern(
            [("transpose_matmul_fuse", *_pattern())]
        )(mod)
        transpose_matmul_codegen = _TransposeMatmulFuser(mod)
        for g_var, func in mod.functions_items():
            if isinstance(func, relax.Function):
                func = transpose_matmul_codegen.visit_expr(func)
                transpose_matmul_codegen.builder_.update_func(g_var, func)
        return transpose_matmul_codegen.builder_.get()
Pattern Definition
The _pattern function defines the DPL (Dataflow Pattern Language) pattern for matching:
def _pattern():
    w = wildcard()
    x = wildcard()
    wT = is_op("relax.permute_dims")(w)
    o = is_op("relax.matmul")(x, wT)
    annotations = {"o": o, "w": w, "x": x, "wT": wT}

    def _check(context: relax.transform.PatternCheckContext) -> bool:
        transpose_call = context.annotated_expr["wT"]
        ndim = transpose_call.args[0].struct_info.ndim
        if ndim == -1:
            return False
        if ndim == 2 and transpose_call.attrs.axes is None:
            return True
        axes = list(range(ndim))
        axes[-1], axes[-2] = axes[-2], axes[-1]
        return list(transpose_call.attrs.axes) == axes

    return o, annotations, _check
Pattern: matmul(x, permute_dims(w))
Check conditions:
- The number of dimensions must be known (not -1)
- For 2D tensors: default axes (None) are accepted, because permute_dims then reverses the axes, which for a 2D tensor is a plain transpose
- Otherwise: the explicit axes must equal the identity permutation with only the last two dimensions swapped
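The check conditions above can be restated as a small standalone predicate. This is a sketch only: swaps_last_two_axes is a hypothetical helper that takes plain Python values instead of a PatternCheckContext, and the extra guard for missing axes is an addition made for this illustration.

```python
from typing import Optional, Sequence


def swaps_last_two_axes(ndim: int, axes: Optional[Sequence[int]]) -> bool:
    """Return True if permute_dims with these axes only swaps the last two dims.

    Hypothetical helper mirroring the listed conditions;
    ndim == -1 stands for "rank unknown".
    """
    if ndim == -1:
        return False
    if ndim == 2 and axes is None:
        return True  # default axes reverse a 2-D tensor: a plain transpose
    if axes is None or ndim < 2:
        return False  # guard added for this sketch only
    expected = list(range(ndim))
    expected[-1], expected[-2] = expected[-2], expected[-1]
    return list(axes) == expected


print(swaps_last_two_axes(2, None))          # prints True
print(swaps_last_two_axes(3, [0, 2, 1]))     # prints True
print(swaps_last_two_axes(3, [2, 1, 0]))     # prints False
```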
Class: _TransposeMatmulFuser
The mutator replaces composite function calls with fused TE computations:
@mutator
class _TransposeMatmulFuser(PyExprMutator):
    def visit_call_(self, call: relax.Call) -> relax.Expr:
        # ...
        if isinstance(call.op, relax.GlobalVar):
            function = self.builder_.get()[call.op]
            if (
                "Composite" in function.attrs
                and function.attrs["Composite"] == "transpose_matmul_fuse"
            ):
                out_dtype = function.ret_struct_info.dtype
                return self.builder_.call_te(
                    te_transposed_matmul,
                    call.args[1],  # x, passed as TE operand a
                    call.args[0],  # w in its stored (non-transposed) layout, passed as b
                    primfunc_name_hint="NT_matmul",
                )
        return super().visit_call_(call)
When a call targets a composite function with the attribute "Composite" == "transpose_matmul_fuse", it is replaced with a call_te invocation of the te_transposed_matmul compute function.
TE Compute: te_transposed_matmul
The inner function te_transposed_matmul generates a tensor expression that computes A @ B.T without materializing the transpose. It handles several edge cases:
1D input handling:
if len(a_shape) == 1:
    a_prepended = True
    a_shape.insert(0, 1)
if len(b_shape) == 1:
    b_appended = True
    b_shape.append(1)
Broadcasting: The compute handles broadcasting between tensors of different rank by computing an offset and iterating over the larger tensor's batch dimensions. It checks whether each dimension is 1 (for broadcasting) or must match.
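The shape rule that this kind of broadcasting implies can be sketched in plain Python. The helper nt_matmul_output_shape is hypothetical; it assumes static integer shapes, NumPy-style batch broadcasting, and a B operand of rank 2 or higher stored as (..., N, K). It is not the code the pass uses to derive the output shape.

```python
def nt_matmul_output_shape(a_shape, b_shape):
    """Output shape of A @ B^T, where B is stored as (..., N, K).

    Hypothetical helper for illustration. A 1-D A is treated as (1, K)
    and the added axis is dropped from the result.
    """
    a_shape, b_shape = list(a_shape), list(b_shape)
    a_prepended = len(a_shape) == 1
    if a_prepended:
        a_shape.insert(0, 1)
    if a_shape[-1] != b_shape[-1]:
        raise ValueError("reduction dimensions differ")

    a_batch, b_batch = a_shape[:-2], b_shape[:-2]
    # Left-pad the shorter batch shape with 1s (the rank "offset").
    rank = max(len(a_batch), len(b_batch))
    a_batch = [1] * (rank - len(a_batch)) + a_batch
    b_batch = [1] * (rank - len(b_batch)) + b_batch

    batch = []
    for a_dim, b_dim in zip(a_batch, b_batch):
        if a_dim != b_dim and 1 not in (a_dim, b_dim):
            raise ValueError(f"cannot broadcast {a_dim} with {b_dim}")
        # A dimension of 1 stretches to match the other operand.
        batch.append(b_dim if a_dim == 1 else a_dim)

    m, n = a_shape[-2], b_shape[-2]
    return tuple(batch + ([] if a_prepended else [m]) + [n])


print(nt_matmul_output_shape((3, 4), (5, 4)))           # prints (3, 5)
print(nt_matmul_output_shape((4,), (5, 4)))             # prints (5,)
print(nt_matmul_output_shape((2, 1, 3, 4), (7, 5, 4)))  # prints (2, 7, 3, 5)
```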
Reduction:
k = te.reduce_axis((0, a_shape[-1]), name="k")

def multiply_compute(idx_reduce):
    # ... index computation with broadcasting ...
    # Key: b_indices uses idx_reduce as the LAST index (not second-to-last)
    # This effectively reads B in its original layout, achieving the transpose
    if dtype != "":
        return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype)
    return a(*a_indices) * b(*b_indices)

return te.sum(multiply_compute(k), axis=k)
The critical insight is in the index computation: for tensor B, the reduction axis k is placed as the last index (matching B's original non-transposed layout), while the spatial output index is used for B's second-to-last dimension. This reads B in its stored order while computing the equivalent of A @ B^T.
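The index placement can be checked with a minimal pure-Python model for the 2D case. The functions nt_matmul and matmul_with_materialized_transpose are hypothetical names for this sketch; nested lists stand in for tensors, and no TE is involved.

```python
def nt_matmul(a, b):
    """Compute A @ B^T for 2-D lists, with B stored as N rows of length K."""
    m, k_extent, n = len(a), len(a[0]), len(b)
    out = [[0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            # The reduction index k is the LAST index of both A and B;
            # the output column j selects B's second-to-last axis.
            out[i][j] = sum(a[i][k] * b[j][k] for k in range(k_extent))
    return out


def matmul_with_materialized_transpose(a, b):
    """Reference: build B^T explicitly, then do an ordinary matmul."""
    b_t = [list(col) for col in zip(*b)]  # shape (K, N)
    return [
        [sum(a[i][k] * b_t[k][j] for k in range(len(b_t)))
         for j in range(len(b_t[0]))]
        for i in range(len(a))
    ]


a = [[1, 2, 3], [4, 5, 6]]                          # shape (2, 3) = (M, K)
b = [[1, 0, 1], [0, 1, 0], [2, 2, 2], [1, 1, 0]]    # shape (4, 3) = (N, K)
print(nt_matmul(a, b))  # prints [[4, 2, 12, 3], [10, 5, 30, 9]]
```

Both functions return the same result, but only the second allocates the intermediate b_t, which is the copy the fused kernel avoids.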
Output dtype handling:
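The out_dtype variable is set from the composite function's return struct info (function.ret_struct_info.dtype) and reaches te_transposed_matmul through a nonlocal closure variable. When it is non-empty, both operands are cast to this dtype before the multiplication; otherwise they are multiplied in their own dtypes.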
Output
The generated TIR function is named "NT_matmul": A is used non-transposed (N), while B is stored as the transpose (T) of the matrix that a plain matmul would expect, so the kernel reads B in its original storage order.
Transformation Summary
| Before | After |
|---|---|
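| matmul(x, permute_dims(w)) | call_tir("NT_matmul", [x, w]) |

Note the argument order: call_te receives call.args[1] followed by call.args[0], so the composite call's arguments are swapped before they reach te_transposed_matmul. The TE function takes the non-transposed operand x as A and the weight w, in its stored layout, as B.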
Dependencies
- tvm -- Core TVM framework
- tvm.relax.dpl.pattern -- Dataflow pattern language (is_op, wildcard)
- tvm.relax.expr_functor -- PyExprMutator and @mutator decorator
- tvm.te -- Tensor expression for defining the fused compute