

Implementation:Mlc ai Mlc llm Attach Support Info

From Leeroopedia


Overview

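attach_support_info.py contains a collection of TVM compiler passes that attach supporting metadata to an IRModule. Most of them set attributes on Relax functions. AttachAdditionalPrimFuncs instead adds extra TIR PrimFuncs to the module, and AttachSequenceLengthPaddingFactor records a value in the model metadata dictionary without changing the module. All of them are registered with opt_level=0, so the PassContext optimization level never causes them to be skipped. The MLC-LLM compilation pipeline uses them to supply information needed by downstream passes and components such as memory planning, CUDA graph capture, pipeline parallelism, and sequence length padding.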

File: python/mlc_llm/compiler_pass/attach_support_info.py
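A small pure-Python model can illustrate what registering at opt_level=0 means. This is a sketch, not TVM code. ToyPass, ToyPassContext, run_sequential, mark_memory_plan and HypotheticalO3Pass are hypothetical names, and the module is a plain dict. The gating rule is modeled on how TVM's PassContext decides whether a pass inside tvm.transform.Sequential runs: disabled passes are skipped, required passes always run, and otherwise the context's opt_level must be at least the pass's opt_level.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Set

# Hypothetical toy IRModule: function name -> attribute dict (not TVM's IRModule)
ToyModule = Dict[str, Dict[str, object]]


@dataclass
class ToyPass:
    name: str
    opt_level: int
    fn: Callable[[ToyModule], ToyModule]


@dataclass
class ToyPassContext:
    opt_level: int = 2  # tvm.transform.PassContext also defaults to opt_level=2
    required_pass: Set[str] = field(default_factory=set)
    disabled_pass: Set[str] = field(default_factory=set)

    def pass_enabled(self, p: ToyPass) -> bool:
        # Modeled precedence: disabled beats required, required always runs,
        # otherwise the pass runs only if its opt_level <= the context's level.
        if p.name in self.disabled_pass:
            return False
        if p.name in self.required_pass:
            return True
        return self.opt_level >= p.opt_level


def run_sequential(passes: List[ToyPass], mod: ToyModule, ctx: ToyPassContext) -> ToyModule:
    # Apply each enabled pass in order, threading the module through
    for p in passes:
        if ctx.pass_enabled(p):
            mod = p.fn(mod)
    return mod


def mark_memory_plan(mod: ToyModule) -> ToyModule:
    return {name: {**attrs, "relax.memory_plan_dynamic_func_output": True}
            for name, attrs in mod.items()}


passes = [
    ToyPass("AttachMemoryPlanAttr", 0, mark_memory_plan),
    ToyPass("HypotheticalO3Pass", 3, lambda m: {}),  # would empty the module if it ran
]
result = run_sequential(passes, {"decode": {}}, ToyPassContext(opt_level=0))
print(result)  # {'decode': {'relax.memory_plan_dynamic_func_output': True}}
```

Even with the context at its lowest level, the level-0 annotation pass still runs. The higher-level pass is skipped.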

Passes

AttachVariableBounds

Attaches TIR variable upper bounds to every Relax function in the module. The main purpose is to let the memory planning pass size buffers whose shapes depend on symbolic variables.

# Imports shared by the excerpts on this page
from typing import Any, Dict, List

import tvm
from tvm import IRModule, relax, tir

@tvm.transform.module_pass(opt_level=0, name="AttachVariableBounds")
class AttachVariableBounds:
    def __init__(self, variable_bounds: Dict[str, int]):
        # RWKV workloads report max_seq_len as -1, so keep only positive bounds
        self.variable_bounds = {k: v for k, v in variable_bounds.items() if v > 0}
        self.non_negative_var = ["vocab_size"]

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        for g_var, func in mod.functions_items():
            if isinstance(func, relax.Function):
                mod[g_var] = func.with_attr("tir_var_upper_bound", self.variable_bounds).with_attr(
                    "tir_non_negative_var", self.non_negative_var
                )
        return mod

Key details:

  • Filters out non-positive bounds (specifically handling RWKV models that may report -1 for max_seq_len)
  • Attaches two attributes: tir_var_upper_bound (the bounds dictionary) and tir_non_negative_var (hardcoded to ["vocab_size"])
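The filtering and attribute attachment in AttachVariableBounds can be sketched without TVM by modeling the IRModule as a plain dictionary. In this hypothetical toy, each entry maps a function name to a (kind, attrs) pair. attach_variable_bounds and the bound names seq_len and batch_size are illustrative assumptions. The -1 for max_seq_len mirrors the RWKV case described above.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical toy IRModule: name -> (kind, attrs), where kind is "relax" or "tir"
ToyModule = Dict[str, Tuple[str, Dict[str, Any]]]


def attach_variable_bounds(mod: ToyModule, variable_bounds: Dict[str, int]) -> ToyModule:
    # Drop non-positive bounds, e.g. a max_seq_len of -1 reported for RWKV models
    bounds = {k: v for k, v in variable_bounds.items() if v > 0}
    non_negative_var: List[str] = ["vocab_size"]
    out: ToyModule = {}
    for name, (kind, attrs) in mod.items():
        if kind == "relax":
            # Build a new attrs dict instead of mutating, like TVM's copying with_attr
            attrs = {**attrs,
                     "tir_var_upper_bound": bounds,
                     "tir_non_negative_var": non_negative_var}
        out[name] = (kind, attrs)  # TIR functions pass through unchanged
    return out


mod = {"prefill": ("relax", {}), "softmax_kernel": ("tir", {})}
out = attach_variable_bounds(mod, {"seq_len": 4096, "max_seq_len": -1, "batch_size": 80})
print(out["prefill"][1]["tir_var_upper_bound"])  # {'seq_len': 4096, 'batch_size': 80}
print(out["softmax_kernel"][1])                  # {}
```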

AttachAdditionalPrimFuncs

Adds extra TIR PrimFuncs to the IRModule under the names given as dictionary keys. Each function's global_symbol attribute is set to that same name.

@tvm.transform.module_pass(opt_level=0, name="AttachAdditionalPrimFuncs")
class AttachAdditionalPrimFuncs:
    def __init__(self, functions: Dict[str, tir.PrimFunc]):
        self.functions = functions

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        for func_name, func in self.functions.items():
            mod[func_name] = func.with_attr("global_symbol", func_name)
        return mod
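A toy version of AttachAdditionalPrimFuncs uses the same hypothetical dictionary model of an IRModule, not TVM's API. It shows that the dictionary key serves both as the module-level name and as the global_symbol value. attach_additional_prim_funcs and hypothetical_copy_kernel are illustrative names only.

```python
from typing import Any, Dict, Tuple

# Hypothetical toy IRModule: name -> (kind, attrs), where kind is "relax" or "tir"
ToyModule = Dict[str, Tuple[str, Dict[str, Any]]]


def attach_additional_prim_funcs(mod: ToyModule,
                                 functions: Dict[str, Dict[str, Any]]) -> ToyModule:
    out = dict(mod)  # shallow copy so the caller's module is left unchanged
    for func_name, attrs in functions.items():
        # The key is both the module-level name and the exported symbol name
        out[func_name] = ("tir", {**attrs, "global_symbol": func_name})
    return out


mod = {"decode": ("relax", {})}
out = attach_additional_prim_funcs(mod, {"hypothetical_copy_kernel": {}})
print(sorted(out))                         # ['decode', 'hypothetical_copy_kernel']
print(out["hypothetical_copy_kernel"][1])  # {'global_symbol': 'hypothetical_copy_kernel'}
```

In TVM, global_symbol is the attribute that determines a function's externally visible symbol name. Keeping it equal to the key means a compiled function can be looked up by the same string it was registered under in the module.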

AttachMemoryPlanAttr

Sets the attribute relax.memory_plan_dynamic_func_output to True on every Relax function in the module, enabling memory planning for dynamic function outputs. TIR functions are left untouched.

@tvm.transform.module_pass(opt_level=0, name="AttachMemoryPlanAttr")
class AttachMemoryPlanAttr:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        for g_var, func in mod.functions_items():
            if isinstance(func, relax.Function):
                mod[g_var] = func.with_attr("relax.memory_plan_dynamic_func_output", True)
        return mod
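Here is a minimal sketch of AttachMemoryPlanAttr on the same hypothetical toy module. Only Relax functions receive the flag, and TIR functions pass through unchanged. attach_memory_plan_attr and the function names are illustrative, not part of MLC-LLM.

```python
from typing import Any, Dict, Tuple

# Hypothetical toy IRModule: name -> (kind, attrs), where kind is "relax" or "tir"
ToyModule = Dict[str, Tuple[str, Dict[str, Any]]]


def attach_memory_plan_attr(mod: ToyModule) -> ToyModule:
    out: ToyModule = {}
    for name, (kind, attrs) in mod.items():
        if kind == "relax":
            attrs = {**attrs, "relax.memory_plan_dynamic_func_output": True}
        out[name] = (kind, attrs)  # TIR functions keep their original attrs
    return out


out = attach_memory_plan_attr({"decode": ("relax", {}), "rms_norm": ("tir", {})})
print(out)
# {'decode': ('relax', {'relax.memory_plan_dynamic_func_output': True}), 'rms_norm': ('tir', {})}
```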

AttachCUDAGraphSymbolicCaptureHints

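Attaches CUDA graph capture hints to the Relax functions named in the hints dictionary. These hints guide the CUDA graph rewriting pass by specifying which symbolic variables should be captured during graph recording. Note that the pass is registered under the name AttachCUDAGraphCaptureHints, which differs from its class name.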

@tvm.transform.module_pass(opt_level=0, name="AttachCUDAGraphCaptureHints")
class AttachCUDAGraphSymbolicCaptureHints:
    def __init__(self, hints: Dict[str, List[str]]):
        self.hints = hints

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        for g_var, func in mod.functions_items():
            func_name = g_var.name_hint
            if isinstance(func, relax.Function):
                if func_name in self.hints:
                    mod[g_var] = func.with_attr(
                        "relax.rewrite_cuda_graph.capture_symbolic_vars",
                        self.hints[func_name],
                    )
        return mod

For each hinted function, the attribute relax.rewrite_cuda_graph.capture_symbolic_vars holds the list of symbolic variable names that should be captured when CUDA graphs are used. Functions that do not appear in the hints dictionary are left unchanged.
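The selection logic of AttachCUDAGraphSymbolicCaptureHints can be modeled the same way. In this hypothetical sketch, a function receives the attribute only when it is a Relax function and its name appears in hints. attach_cuda_graph_hints, the function names, and the symbolic variable names are assumptions.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical toy IRModule: name -> (kind, attrs), where kind is "relax" or "tir"
ToyModule = Dict[str, Tuple[str, Dict[str, Any]]]
CAPTURE_KEY = "relax.rewrite_cuda_graph.capture_symbolic_vars"


def attach_cuda_graph_hints(mod: ToyModule, hints: Dict[str, List[str]]) -> ToyModule:
    out: ToyModule = {}
    for name, (kind, attrs) in mod.items():
        # Only Relax functions that are named in `hints` receive the attribute
        if kind == "relax" and name in hints:
            attrs = {**attrs, CAPTURE_KEY: hints[name]}
        out[name] = (kind, attrs)
    return out


mod = {"decode": ("relax", {}), "prefill": ("relax", {}), "decode_kernel": ("tir", {})}
hints = {"decode": ["batch_size"], "decode_kernel": ["n"], "missing_fn": ["m"]}
out = attach_cuda_graph_hints(mod, hints)
print(out["decode"][1])         # {'relax.rewrite_cuda_graph.capture_symbolic_vars': ['batch_size']}
print(out["prefill"][1])        # {}
print(out["decode_kernel"][1])  # {}
```

Hint entries for TIR functions, or for names absent from the module, are silently ignored. This matches the isinstance and membership checks in the pass.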

AttachPipelineParallelStages

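Attaches the number of pipeline parallel stages (the pipeline_parallel_shards constructor argument) as the pipeline_parallel_stages attribute. The pass only touches a fixed set of model inference entry-point functions; all other functions are left unchanged.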

@tvm.transform.module_pass(opt_level=0, name="AttachPipelineParallelStages")
class AttachPipelineParallelStages:
    def __init__(self, pipeline_parallel_shards: int):
        self.pipeline_parallel_shards = pipeline_parallel_shards

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        for g_var, func in mod.functions_items():
            func_name = g_var.name_hint
            if not isinstance(func, relax.Function) or func_name not in [
                "prefill", "decode",
                "prefill_to_last_hidden_states", "decode_to_last_hidden_states",
                "batch_prefill", "batch_decode", "batch_verify",
                "batch_prefill_to_last_hidden_states",
                "batch_decode_to_last_hidden_states",
                "batch_verify_to_last_hidden_states",
            ]:
                continue
            mod[g_var] = func.with_attr("pipeline_parallel_stages", self.pipeline_parallel_shards)
        return mod

Targeted functions:

| Category | Function names |
|---|---|
| Single-request | prefill, decode |
| Hidden states | prefill_to_last_hidden_states, decode_to_last_hidden_states |
| Batched | batch_prefill, batch_decode, batch_verify |
| Batched hidden states | batch_prefill_to_last_hidden_states, batch_decode_to_last_hidden_states, batch_verify_to_last_hidden_states |
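The entry-point filter of AttachPipelineParallelStages can be sketched with the ten function names from the table. attach_pipeline_stages and create_hypothetical_cache are hypothetical, and the toy module is a plain dictionary rather than a TVM IRModule.

```python
from typing import Any, Dict, Tuple

# Hypothetical toy IRModule: name -> (kind, attrs), where kind is "relax" or "tir"
ToyModule = Dict[str, Tuple[str, Dict[str, Any]]]

# The ten entry points listed in the table above
PIPELINE_ENTRY_POINTS = frozenset({
    "prefill", "decode",
    "prefill_to_last_hidden_states", "decode_to_last_hidden_states",
    "batch_prefill", "batch_decode", "batch_verify",
    "batch_prefill_to_last_hidden_states",
    "batch_decode_to_last_hidden_states",
    "batch_verify_to_last_hidden_states",
})


def attach_pipeline_stages(mod: ToyModule, pipeline_parallel_shards: int) -> ToyModule:
    out: ToyModule = {}
    for name, (kind, attrs) in mod.items():
        if kind == "relax" and name in PIPELINE_ENTRY_POINTS:
            attrs = {**attrs, "pipeline_parallel_stages": pipeline_parallel_shards}
        out[name] = (kind, attrs)  # everything else is copied as-is
    return out


mod = {"batch_decode": ("relax", {}), "create_hypothetical_cache": ("relax", {})}
out = attach_pipeline_stages(mod, 2)
print(out["batch_decode"][1])               # {'pipeline_parallel_stages': 2}
print(out["create_hypothetical_cache"][1])  # {}
```

The sketch stores the names in a frozenset instead of the list literal used in the pass. Membership results are identical, and the set avoids rebuilding the list on every loop iteration.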

AttachSequenceLengthPaddingFactor

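Determines a sequence length padding factor and records it in the model metadata. Unlike the other passes, it does not modify the IRModule. It writes into the metadata dictionary passed to its constructor and returns the module unchanged. It matters only for the NVIDIA SM100a (Blackwell) architecture, when the model uses CUTLASS groupwise-scaled FP8 (e4m3fn) GEMM or batched GEMM kernels.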

@tvm.transform.module_pass(opt_level=0, name="AttachSequenceLengthPaddingFactor")
class AttachSequenceLengthPaddingFactor:
    def __init__(self, target: tvm.target.Target, metadata: Dict[str, Any]):
        self.target = target
        self.metadata = metadata

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # Uses an internal visitor to scan the IRModule
        padding_factor = _Visitor(self.target).run(mod)
        if padding_factor > 1:
            self.metadata["seqlen_padding_factor"] = padding_factor
        return mod

The internal _Visitor class inspects every relax.call_dps_packed call in the module, looking for either of two CUTLASS kernels:

  • cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn
  • cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn

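The padding factor starts at 1. Whenever one of these kernels is found and the target is SM100a, the factor is updated to the least common multiple of its current value and 4 (via math.lcm). It therefore becomes 4 after the first match and stays there. The pass records seqlen_padding_factor only when the final factor exceeds 1. In that case, sequence lengths are padded to multiples of 4 for these CUTLASS operations.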
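The padding-factor logic can be sketched in pure Python. This is a simplified model, not the pass itself. Instead of visiting Relax IR, find_padding_factor takes a hypothetical flat list of (operator name, callee symbol) pairs. The SM100a check is modeled as a comparison of an arch string against "sm_100a", which is an assumption about how the target is identified. pad_to_multiple illustrates what padding to a multiple of the factor means; it is not taken from the MLC-LLM runtime.

```python
import math
from typing import Any, Dict, Iterable, Tuple

CUTLASS_FP8_KERNELS = frozenset({
    "cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn",
    "cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn",
})


def find_padding_factor(calls: Iterable[Tuple[str, str]], arch: str) -> int:
    # calls: hypothetical flattened (operator name, callee symbol) pairs
    factor = 1  # neutral starting value for lcm
    for op_name, callee in calls:
        if op_name != "relax.call_dps_packed":
            continue
        # Assumption: the SM100a check is modeled as an arch-string comparison
        if callee in CUTLASS_FP8_KERNELS and arch == "sm_100a":
            factor = math.lcm(factor, 4)
    return factor


def attach_padding_factor(calls: Iterable[Tuple[str, str]], arch: str,
                          metadata: Dict[str, Any]) -> Dict[str, Any]:
    factor = find_padding_factor(calls, arch)
    if factor > 1:  # only record a factor that actually changes padding
        metadata["seqlen_padding_factor"] = factor
    return metadata


def pad_to_multiple(seq_len: int, factor: int) -> int:
    # Illustration only: round seq_len up to the next multiple of factor
    return (seq_len + factor - 1) // factor * factor


calls = [
    ("relax.call_dps_packed", "cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn"),
    ("relax.call_dps_packed", "cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn"),
    ("relax.call_tir", "hypothetical_matmul"),
]
print(attach_padding_factor(calls, "sm_100a", {}))  # {'seqlen_padding_factor': 4}
print(attach_padding_factor(calls, "sm_90a", {}))   # {}
print(pad_to_multiple(13, 4))                       # 16
```

The pass uses math.lcm rather than plain assignment. If another kernel ever required a different factor, combining the requirements this way would still yield a value divisible by each of them (for example, math.lcm(4, 6) is 12).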

Pass Summary

| Class name | TVM registration name | Purpose | Parameters |
|---|---|---|---|
| AttachVariableBounds | AttachVariableBounds | Memory planning bounds | variable_bounds: Dict[str, int] |
| AttachAdditionalPrimFuncs | AttachAdditionalPrimFuncs | Add extra TIR functions | functions: Dict[str, tir.PrimFunc] |
| AttachMemoryPlanAttr | AttachMemoryPlanAttr | Enable dynamic output memory planning | None |
| AttachCUDAGraphSymbolicCaptureHints | AttachCUDAGraphCaptureHints | CUDA graph symbolic variable hints | hints: Dict[str, List[str]] |
| AttachPipelineParallelStages | AttachPipelineParallelStages | Pipeline parallelism config | pipeline_parallel_shards: int |
| AttachSequenceLengthPaddingFactor | AttachSequenceLengthPaddingFactor | Sequence length padding for SM100a | target: tvm.target.Target, metadata: Dict[str, Any] |

Dependencies

  • tvm -- Core TVM framework (IRModule, relax, tir, ir.Op)
  • tvm.relax.expr_functor -- PyExprVisitor and @visitor decorator for AST traversal
  • math.lcm -- Used for computing the least common multiple of padding factors
