Implementation:Mlc ai Mlc llm Attach Support Info
Overview
attach_support_info.py contains a collection of TVM compiler passes that attach various supportive metadata and attributes to functions in an IRModule. These passes operate at optimization level 0 and are used during the MLC-LLM compilation pipeline to annotate Relax and TIR functions with information needed by downstream passes such as memory planning, CUDA graph capture, pipeline parallelism, and sequence length padding.
File: python/mlc_llm/compiler_pass/attach_support_info.py
Passes
AttachVariableBounds
Attaches TIR variable upper bounds to each Relax function, which primarily assists the memory planning pass in determining buffer sizes.
@tvm.transform.module_pass(opt_level=0, name="AttachVariableBounds")
class AttachVariableBounds:
    """Module pass that tags each Relax function with TIR variable bound attributes."""

    def __init__(self, variable_bounds: Dict[str, int]):
        """Record the bounds that will be attached.

        Parameters
        ----------
        variable_bounds : Dict[str, int]
            Mapping from TIR variable name to its upper bound. Entries whose
            value is not strictly positive are dropped.
        """
        self.non_negative_var = ["vocab_size"]
        # RWKV workloads carry a max_seq_len of -1, so only positive bounds are kept.
        positive_bounds = {}
        for var_name, upper_bound in variable_bounds.items():
            if upper_bound > 0:
                positive_bounds[var_name] = upper_bound
        self.variable_bounds = positive_bounds

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """Attach ``tir_var_upper_bound`` and ``tir_non_negative_var`` to Relax functions.

        Functions that are not ``relax.Function`` instances are left untouched.
        The (updated) input module is returned.
        """
        for g_var, func in mod.functions_items():
            if not isinstance(func, relax.Function):
                continue
            with_bounds = func.with_attr("tir_var_upper_bound", self.variable_bounds)
            mod[g_var] = with_bounds.with_attr("tir_non_negative_var", self.non_negative_var)
        return mod
Key details:
- Filters out non-positive bounds (specifically handling RWKV models that may report -1 for max_seq_len)
- Attaches two attributes: tir_var_upper_bound (the bounds dictionary) and tir_non_negative_var (hardcoded to ["vocab_size"])
AttachAdditionalPrimFuncs
Attaches extra TIR PrimFuncs directly to the IRModule. Each function is annotated with its global_symbol attribute.
@tvm.transform.module_pass(opt_level=0, name="AttachAdditionalPrimFuncs")
class AttachAdditionalPrimFuncs:
    """Module pass that inserts extra TIR PrimFuncs into the IRModule."""

    def __init__(self, functions: Dict[str, tir.PrimFunc]):
        """Keep the name -> PrimFunc mapping to insert when the pass runs."""
        self.functions = functions

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """Store each PrimFunc in ``mod`` under its name and return the module.

        Every inserted function gets its ``global_symbol`` attribute set to the
        same name it is stored under.
        """
        for symbol, prim_func in self.functions.items():
            exported = prim_func.with_attr("global_symbol", symbol)
            mod[symbol] = exported
        return mod
AttachMemoryPlanAttr
Annotates all Relax functions to enable dynamic function output memory planning. This is a simple pass that adds the attribute relax.memory_plan_dynamic_func_output set to True on every Relax function.
@tvm.transform.module_pass(opt_level=0, name="AttachMemoryPlanAttr")
class AttachMemoryPlanAttr:
    """Module pass that flags every Relax function for dynamic output memory planning."""

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """Set ``relax.memory_plan_dynamic_func_output`` to True on Relax functions.

        Non-Relax functions are skipped; the (updated) input module is returned.
        """
        for g_var, func in mod.functions_items():
            if not isinstance(func, relax.Function):
                continue
            mod[g_var] = func.with_attr("relax.memory_plan_dynamic_func_output", True)
        return mod
AttachCUDAGraphSymbolicCaptureHints
Attaches CUDA graph capture hints to specific Relax functions. These hints guide the CUDA graph rewriting pass by specifying which symbolic variables should be captured during graph recording.
@tvm.transform.module_pass(opt_level=0, name="AttachCUDAGraphCaptureHints")
class AttachCUDAGraphSymbolicCaptureHints:
    """Module pass that attaches symbolic-variable capture hints to Relax functions."""

    def __init__(self, hints: Dict[str, List[str]]):
        """Keep the mapping from function name to the hint list to attach."""
        self.hints = hints

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """Attach ``relax.rewrite_cuda_graph.capture_symbolic_vars`` where a hint exists.

        Only Relax functions whose global name appears in ``self.hints`` are
        modified; everything else is left as-is. The module is returned.
        """
        for g_var, func in mod.functions_items():
            name = g_var.name_hint
            has_hint = isinstance(func, relax.Function) and name in self.hints
            if not has_hint:
                continue
            symbolic_vars = self.hints[name]
            mod[g_var] = func.with_attr(
                "relax.rewrite_cuda_graph.capture_symbolic_vars", symbolic_vars
            )
        return mod
The attribute relax.rewrite_cuda_graph.capture_symbolic_vars holds, for each hinted function, the list of symbolic variable names that should be captured when CUDA graphs are used.
AttachPipelineParallelStages
Attaches the number of pipeline parallel stages to relevant Relax functions. This pass only targets a specific set of model inference entry-point functions.
@tvm.transform.module_pass(opt_level=0, name="AttachPipelineParallelStages")
class AttachPipelineParallelStages:
    """Module pass that records the pipeline-parallel stage count on selected functions."""

    def __init__(self, pipeline_parallel_shards: int):
        """Keep the value that will be written as ``pipeline_parallel_stages``."""
        self.pipeline_parallel_shards = pipeline_parallel_shards

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """Attach ``pipeline_parallel_stages`` to the targeted Relax functions.

        A function is modified only when it is a ``relax.Function`` and its
        global name is one of the names listed below. The module is returned.
        """
        # Names of the functions that receive the attribute.
        targeted_names = (
            "prefill",
            "decode",
            "prefill_to_last_hidden_states",
            "decode_to_last_hidden_states",
            "batch_prefill",
            "batch_decode",
            "batch_verify",
            "batch_prefill_to_last_hidden_states",
            "batch_decode_to_last_hidden_states",
            "batch_verify_to_last_hidden_states",
        )
        for g_var, func in mod.functions_items():
            name = g_var.name_hint
            if isinstance(func, relax.Function) and name in targeted_names:
                mod[g_var] = func.with_attr(
                    "pipeline_parallel_stages", self.pipeline_parallel_shards
                )
        return mod
Targeted functions:
| Category | Function Names |
|---|---|
| Single-request | prefill, decode |
| Hidden states | prefill_to_last_hidden_states, decode_to_last_hidden_states |
| Batched | batch_prefill, batch_decode, batch_verify |
| Batched hidden states | batch_prefill_to_last_hidden_states, batch_decode_to_last_hidden_states, batch_verify_to_last_hidden_states |
AttachSequenceLengthPaddingFactor
Determines and attaches a sequence length padding factor to the model metadata. This is specifically relevant for NVIDIA SM100a (Blackwell) architecture when using CUTLASS grouped/scaled GEMM operations.
@tvm.transform.module_pass(opt_level=0, name="AttachSequenceLengthPaddingFactor")
class AttachSequenceLengthPaddingFactor:
    """Module pass that records a sequence length padding factor in the metadata dict."""

    def __init__(self, target: tvm.target.Target, metadata: Dict[str, Any]):
        """Keep the compilation target and the metadata dict updated by the pass."""
        self.metadata = metadata
        self.target = target

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """Compute the padding factor for ``mod`` and store it when it exceeds 1.

        The module itself is returned unchanged; the only effect is the
        optional ``seqlen_padding_factor`` entry written into ``self.metadata``.
        """
        # The internal visitor scans the IRModule and yields the factor.
        visitor = _Visitor(self.target)
        factor = visitor.run(mod)
        if not factor > 1:
            return mod
        self.metadata["seqlen_padding_factor"] = factor
        return mod
The internal _Visitor class inspects all relax.call_dps_packed calls looking for specific CUTLASS kernels:
- cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn
- cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn
When such kernels are found on SM100a targets, the padding factor is set to the least common multiple of the current factor and 4 (using math.lcm). This ensures sequence lengths are padded to multiples of 4 for these particular CUTLASS operations.
Pass Summary
| Pass Name | TVM Registration Name | Purpose | Parameters |
|---|---|---|---|
| AttachVariableBounds | AttachVariableBounds | Memory planning bounds | variable_bounds: Dict[str, int] |
| AttachAdditionalPrimFuncs | AttachAdditionalPrimFuncs | Add extra TIR functions | functions: Dict[str, tir.PrimFunc] |
| AttachMemoryPlanAttr | AttachMemoryPlanAttr | Enable dynamic output planning | None |
| AttachCUDAGraphSymbolicCaptureHints | AttachCUDAGraphCaptureHints | CUDA graph symbolic vars | hints: Dict[str, List[str]] |
| AttachPipelineParallelStages | AttachPipelineParallelStages | Pipeline parallelism config | pipeline_parallel_shards: int |
| AttachSequenceLengthPaddingFactor | AttachSequenceLengthPaddingFactor | Sequence padding for SM100a | target, metadata |
Dependencies
- tvm -- Core TVM framework (IRModule, relax, tir, ir.Op)
- tvm.relax.expr_functor -- PyExprVisitor and @visitor decorator for AST traversal
- math.lcm -- Used for computing the least common multiple of padding factors