
Implementation:Mlc ai Mlc llm Pipeline Parallel Rewrite

From Leeroopedia


Knowledge Sources
Domains: Compiler Pass, Pipeline Parallelism, Distributed Inference
Last Updated: 2026-02-09 19:00 GMT

Overview

A TVM compiler pass that rewrites IR functions annotated for pipeline parallelism into separate per-stage functions, connected by inter-group send/receive operations, for distributed inference across multiple GPU groups.

Description

The PipelineParallelRewrite pass is a TVM module-level transformation (opt_level=0) that converts a monolithic model function into multiple pipeline-parallel stage functions. Each stage runs on a different Disco worker group and communicates intermediate results via send/receive primitives.

The pass operates in several phases:

Stage Extraction (_extract_pipeline_stages): The pass scans the function's dataflow block for calls to mlc.pipeline_parallel_stage_boundary, which mark the boundaries between pipeline stages. The arguments of each boundary call identify the tensors sent to the next stage. The bindings before the first boundary, between consecutive boundaries, and after the last boundary each form one stage. For each stage, the pass records which variables it receives from the previous stage and which it sends to the next stage.
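The splitting logic can be sketched in plain Python. This is a toy model, not the pass itself: the hypothetical `Binding` dataclass stands in for a Relax binding, and the sketch assumes that later stages refer directly to the variables passed to the boundary call, whereas the real pass works on Relax IR.

```python
from dataclasses import dataclass
from typing import List, Tuple

BOUNDARY_OP = "mlc.pipeline_parallel_stage_boundary"


@dataclass
class Binding:
    """Toy stand-in (hypothetical) for a Relax binding `var = op(*args)`."""
    var: str
    op: str
    args: Tuple[str, ...]


def extract_pipeline_stages(
    bindings: List[Binding],
) -> Tuple[List[List[Binding]], List[List[str]], List[List[str]]]:
    """Split a flat binding list at boundary markers.

    Returns (stages, receive_vars, send_vars), one entry per stage.
    Toy assumption: later stages reference the boundary's arguments directly.
    """
    stages: List[List[Binding]] = [[]]
    receive_vars: List[List[str]] = [[]]  # the first stage receives nothing
    send_vars: List[List[str]] = []
    for binding in bindings:
        if binding.op == BOUNDARY_OP:
            # The boundary's arguments are sent by the current stage
            # and received by the next stage, which starts out empty.
            send_vars.append(list(binding.args))
            receive_vars.append(list(binding.args))
            stages.append([])
        else:
            stages[-1].append(binding)
    send_vars.append([])  # the last stage sends nothing
    return stages, receive_vars, send_vars


# Usage: a two-stage toy model split after "layer0".
program = [
    Binding("h0", "embed", ("input_ids",)),
    Binding("h1", "layer0", ("h0",)),
    Binding("b0", BOUNDARY_OP, ("h1",)),
    Binding("h2", "layer1", ("h1",)),
    Binding("logits", "lm_head", ("h2",)),
]
stages, recv, send = extract_pipeline_stages(program)
print([[b.var for b in stage] for stage in stages])  # [['h0', 'h1'], ['h2', 'logits']]
print(recv)  # [[], ['h1']]
print(send)  # [['h1'], []]
```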

Required Parameter Analysis (_analyze_required_func_params): For each stage, the pass uses an IR visitor (_RequiredFuncParamAnalyzer) to determine which of the original function's parameters are actually used in that stage. This avoids passing unnecessary data to each stage function.
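A minimal sketch of this analysis on a flat toy representation (hypothetical `(var, op, args)` tuples rather than Relax expressions). The real _RequiredFuncParamAnalyzer visits nested IR expressions; here each stage simply collects every name its bindings read, and the result keeps the original parameter order.

```python
from typing import List, Sequence, Set, Tuple

# Hypothetical flat binding: (result_var, op_name, argument_names).
ToyBinding = Tuple[str, str, Tuple[str, ...]]


def analyze_required_func_params(
    pipeline_stages: Sequence[Sequence[ToyBinding]],
    func_params: Sequence[str],
) -> List[List[str]]:
    """Return, for each stage, the original parameters that stage reads."""
    required: List[List[str]] = []
    for stage in pipeline_stages:
        used: Set[str] = set()
        for _var, _op, args in stage:
            used.update(args)
        # Filter func_params (not `used`) so the original order is preserved.
        required.append([p for p in func_params if p in used])
    return required


stages = [
    [("h0", "embed", ("input_ids",)), ("h1", "layer0", ("h0", "kv_cache"))],
    [("h2", "layer1", ("h1", "kv_cache")), ("logits", "lm_head", ("h2",))],
]
print(analyze_required_func_params(stages, ["input_ids", "kv_cache"]))
# [['input_ids', 'kv_cache'], ['kv_cache']]
```

Only the first stage needs `input_ids`, so later stage functions can omit it from their signatures.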

Stage Function Creation (_create_stage_func): The _PipelineParallelRewriter (a PyExprMutator) creates a new function for each stage. Each stage function:

  • Receives tensors from the previous stage via runtime.disco.recv_from_prev_group.
  • Processes its portion of the computation by visiting and transforming the original bindings.
  • Sends output tensors to the next stage via runtime.disco.send_to_next_group.
  • Includes only the required function parameters plus shape variables and packed model parameters.
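The receive, compute, send structure of each stage function can be modeled with one Python thread per worker group. This only simulates runtime behavior, not the generated IR: the `queue.Queue` channels are hypothetical stand-ins for runtime.disco.recv_from_prev_group and runtime.disco.send_to_next_group, and each `compute` callable plays the role of one stage's slice of the model.

```python
import queue
import threading
from typing import Callable, List, Optional

Stage = Callable[[float], float]


def run_stage(
    compute: Stage,
    inbox: Optional[queue.Queue],
    outbox: Optional[queue.Queue],
    model_input: float,
    results: List[float],
) -> None:
    # Hypothetical stand-in for recv_from_prev_group: every stage except
    # the first blocks until the previous group has sent its output.
    x = model_input if inbox is None else inbox.get()
    y = compute(x)  # this stage's slice of the computation
    if outbox is not None:
        outbox.put(y)  # stand-in for send_to_next_group
    else:
        results.append(y)  # the last stage holds the final output


def run_pipeline(stage_fns: List[Stage], model_input: float) -> float:
    # One channel between each pair of neighboring groups.
    channels = [queue.Queue() for _ in range(len(stage_fns) - 1)]
    results: List[float] = []
    threads = []
    for i, fn in enumerate(stage_fns):
        inbox = channels[i - 1] if i > 0 else None
        outbox = channels[i] if i < len(channels) else None
        t = threading.Thread(target=run_stage, args=(fn, inbox, outbox, model_input, results))
        threads.append(t)
        t.start()
    for t in threads:
        t.join()
    return results[0]


print(run_pipeline([lambda v: v + 1, lambda v: v * 2, lambda v: v - 3], 4.0))  # 7.0
```

All threads start at once, but the blocking `get()` enforces stage order, mirroring how a later group waits on its predecessor.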

Shape Variable Handling: When bindings reference shape variables (TIR variables used in tensor dimensions) that are not defined as function parameters, the pass creates new shape variables and records the old-to-new mapping in undefined_shape_vars_remap. The new variables are collected into a shape parameter passed to each stage function. Shape variables that appear in packed-parameter unpacking operations are handled separately through undefined_param_shape_vars_remap.
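A toy sketch of the remapping step, assuming shape variables are represented by their names and that a fresh variable is named by prefixing `new_` (a hypothetical convention; the pass itself creates new TIR variables). Each undefined variable is recorded in the remap once and collected, in first-seen order, for the stage's shape parameter.

```python
from typing import Dict, List, Sequence, Set


def remap_undefined_shape_vars(
    binding_shape_vars: Sequence[Sequence[str]],
    defined: Set[str],
    undefined_shape_vars_remap: Dict[str, str],
) -> List[str]:
    """Create a fresh replacement for each shape variable the stage does not define.

    Updates `undefined_shape_vars_remap` in place and returns the fresh
    variables in first-seen order; they form the stage's shape parameter.
    """
    shape_param: List[str] = []
    for shape_vars in binding_shape_vars:  # shape vars used by one binding
        for sv in shape_vars:
            if sv in defined or sv in undefined_shape_vars_remap:
                continue  # already usable in this stage
            fresh = "new_" + sv  # hypothetical naming for the new variable
            undefined_shape_vars_remap[sv] = fresh
            shape_param.append(fresh)
    return shape_param


remap: Dict[str, str] = {}
shape_param = remap_undefined_shape_vars(
    [["batch", "seq_len"], ["seq_len", "total_seq_len"]],
    defined={"batch"},
    undefined_shape_vars_remap=remap,
)
print(shape_param)  # ['new_seq_len', 'new_total_seq_len']
print(remap)  # {'seq_len': 'new_seq_len', 'total_seq_len': 'new_total_seq_len'}
```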

Packed Parameter Handling: The original function's packed_params parameter is remapped in each stage. When a stage unpacks a parameter via TupleGetItem, the pass emits a vm.builtin.tuple_getitem call with appropriate struct info, optionally adding a match_cast when new shape variables are introduced.
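The unpacking rewrite can be illustrated by emitting pseudo-IR text. The output notation, the `_raw` temporary name, and the choice to give the raw call an opaque `Object` type before the match_cast are assumptions of this sketch, not the pass's exact output; `rewrite_param_unpack` is a hypothetical helper.

```python
from typing import Dict, List, Sequence


def rewrite_param_unpack(
    var: str,
    index: int,
    shape: Sequence[str],
    dtype: str,
    packed_params: str,
    new_shape_vars: Dict[str, str],
) -> List[str]:
    """Return pseudo-IR lines (hypothetical notation) that unpack packed_params[index]."""
    dims = ", ".join(new_shape_vars.get(d, d) for d in shape)
    tensor_sinfo = f"Tensor(({dims}), {dtype})"
    call = f"vm.builtin.tuple_getitem({packed_params}, {index})"
    if not any(d in new_shape_vars for d in shape):
        # No fresh shape variables: annotate the call result directly.
        return [f"{var} = {call} : {tensor_sinfo}"]
    # Fresh shape variables must be bound from the runtime value, so the
    # sketch emits the call with an opaque type followed by a match_cast.
    tmp = f"{var}_raw"
    return [f"{tmp} = {call} : Object", f"{var} = match_cast({tmp}, {tensor_sinfo})"]


print(rewrite_param_unpack("w_q", 3, ["4096", "4096"], "float16", "packed_params", {}))
# ['w_q = vm.builtin.tuple_getitem(packed_params, 3) : Tensor((4096, 4096), float16)']
print(rewrite_param_unpack("emb", 0, ["vocab_size", "4096"], "float16", "packed_params",
                           {"vocab_size": "new_vocab_size"}))
# ['emb_raw = vm.builtin.tuple_getitem(packed_params, 0) : Object',
#  'emb = match_cast(emb_raw, Tensor((new_vocab_size, 4096), float16))']
```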

Entry Function Rewrite: The original function is replaced with a dispatch function that calls mlc.multi_gpu.DispatchFunctionByGroup, which routes execution to the appropriate stage function based on the current Disco worker group ID.
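The dispatch behavior can be modeled as indexing a list of stage functions by group ID. `dispatch_function_by_group` is a hypothetical Python stand-in for mlc.multi_gpu.DispatchFunctionByGroup, whose real calling convention may differ; the stage function names follow the "prefill_stage0" pattern shown in the usage example.

```python
from typing import Any, Callable, Sequence


def dispatch_function_by_group(
    stage_funcs: Sequence[Callable[..., Any]], group_id: int
) -> Callable[..., Any]:
    """Hypothetical model: pick the stage function for worker group `group_id`."""
    if not 0 <= group_id < len(stage_funcs):
        raise ValueError(f"group_id {group_id} out of range for {len(stage_funcs)} stages")
    return stage_funcs[group_id]


def prefill_stage0(x: str) -> str:
    return f"stage0({x})"


def prefill_stage1(x: str) -> str:
    return f"stage1({x})"


stage_funcs = [prefill_stage0, prefill_stage1]
for gid in range(len(stage_funcs)):
    # Every group calls the same entry point; only the selected stage differs.
    print(dispatch_function_by_group(stage_funcs, gid)("tokens"))
# stage0(tokens)
# stage1(tokens)
```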

The pass copies the original function's attributes (such as "num_input") onto each stage function, except "global_symbol" and "pipeline_parallel_stages", which are removed from the stage functions.
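The attribute handling amounts to copying a mapping minus two keys. In this sketch a plain `dict` stands in for the function's attribute map, and `stage_func_attrs` is a hypothetical helper name.

```python
from typing import Any, Dict

# Attributes the stage functions do not inherit from the original function.
DROPPED_STAGE_ATTRS = frozenset({"global_symbol", "pipeline_parallel_stages"})


def stage_func_attrs(func_attrs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: copy the original attributes for one stage function."""
    return {k: v for k, v in func_attrs.items() if k not in DROPPED_STAGE_ATTRS}


original = {"num_input": 3, "global_symbol": "prefill", "pipeline_parallel_stages": 2}
print(stage_func_attrs(original))  # {'num_input': 3}
```

Dropping "global_symbol" is consistent with the stage functions being reached through the dispatch function rather than exported as entry points of their own.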

Usage

This pass is applied during MLC LLM model compilation when the model is annotated for pipeline parallelism. It enables splitting a large model across multiple GPU groups, where each group processes one pipeline stage and communicates intermediate activations with neighboring groups. This is essential for serving models too large to fit on a single GPU or GPU group.

Code Reference

Source Location

Signature

from typing import List, Tuple

import tvm
from tvm import IRModule, relax
from tvm.relax.expr_functor import PyExprMutator

@tvm.transform.module_pass(opt_level=0, name="PipelineParallelRewrite")
class PipelineParallelRewrite:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        ...

# Internal rewriter (PyExprMutator)
class _PipelineParallelRewriter(PyExprMutator):
    def transform(self) -> IRModule: ...
    def _create_stage_func(self, func_name, stage_bindings, required_func_params,
                           stage_receive_vars, stage_send_vars, func_attrs,
                           func_output) -> Tuple[tvm.ir.GlobalVar, List[relax.Expr]]: ...

# Internal helpers
def _extract_pipeline_stages(func) -> Tuple[List[List[relax.Binding]],
                                             List[List[relax.Var]],
                                             List[List[relax.Var]]]: ...
def _analyze_required_func_params(pipeline_stages, func_params) -> List[List[relax.Var]]: ...

Import

from mlc_llm.compiler_pass.pipeline_parallel_rewrite import PipelineParallelRewrite

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| mod | IRModule | Yes | The TVM IRModule containing functions annotated with the "pipeline_parallel_stages" attribute. |

The input functions must satisfy these requirements:

| Requirement | Description |
|---|---|
| pipeline_parallel_stages attr | Integer attribute specifying the number of pipeline stages. |
| num_input attr | Integer attribute specifying the number of regular input parameters (before packed_params). |
| Single dataflow block | The function body must have exactly one dataflow block. |
| Stage boundary markers | Calls to "mlc.pipeline_parallel_stage_boundary" must separate the stages. |
| packed_params parameter | The last parameter must be named "packed_params" with ObjectStructInfo. |
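These requirements can be checked before running the pass. This sketch operates on a hypothetical, simplified `ToyFunction` record rather than a real Relax function, and it assumes that N stages are separated by exactly N - 1 boundary calls.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

BOUNDARY_OP = "mlc.pipeline_parallel_stage_boundary"


@dataclass
class ToyFunction:
    """Hypothetical simplified view of a Relax function, for precondition checks."""
    attrs: Dict[str, Any]
    params: List[Tuple[str, str]]  # (name, struct-info kind), in order
    dataflow_blocks: List[List[str]]  # op names called in each dataflow block


def check_pipeline_input(func: ToyFunction) -> List[str]:
    """Return human-readable violations of the input requirements (empty if OK)."""
    problems: List[str] = []
    num_stages = func.attrs.get("pipeline_parallel_stages")
    if not isinstance(num_stages, int):
        problems.append("missing integer attr 'pipeline_parallel_stages'")
    if not isinstance(func.attrs.get("num_input"), int):
        problems.append("missing integer attr 'num_input'")
    if len(func.dataflow_blocks) != 1:
        problems.append("body must contain exactly one dataflow block")
    elif isinstance(num_stages, int):
        # Assumption of this sketch: N stages need exactly N - 1 boundaries.
        n_boundaries = func.dataflow_blocks[0].count(BOUNDARY_OP)
        if n_boundaries != num_stages - 1:
            problems.append(f"expected {num_stages - 1} stage boundaries, found {n_boundaries}")
    if not func.params or func.params[-1] != ("packed_params", "Object"):
        problems.append("last parameter must be 'packed_params' with ObjectStructInfo")
    return problems


ok = ToyFunction(
    attrs={"num_input": 2, "pipeline_parallel_stages": 2},
    params=[("input_ids", "Tensor"), ("kv_cache", "Object"), ("packed_params", "Object")],
    dataflow_blocks=[["embed", "layer0", BOUNDARY_OP, "layer1", "lm_head"]],
)
print(check_pipeline_input(ok))  # []
```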

Outputs

| Name | Type | Description |
|---|---|---|
| mod | IRModule | The transformed IRModule, with the original function replaced by a dispatch function and N new stage functions added (one per pipeline stage). |

Each stage function in the output module:

  • Receives tensors via runtime.disco.recv_from_prev_group (every stage except the first).
  • Processes its computation slice.
  • Sends tensors via runtime.disco.send_to_next_group (every stage except the last).

Usage Examples

import tvm
from mlc_llm.compiler_pass.pipeline_parallel_rewrite import PipelineParallelRewrite

# `mod` is assumed to be an IRModule whose functions carry the pipeline-parallel
# annotations listed under I/O Contract; apply the pass to it
pp_pass = PipelineParallelRewrite()
with tvm.transform.PassContext():
    rewritten_mod = pp_pass(mod)

# The resulting module will contain:
# - A dispatch function (replaces the original) that calls
#   mlc.multi_gpu.DispatchFunctionByGroup
# - Stage functions like "prefill_stage0", "prefill_stage1", etc.
# - Each stage function uses runtime.disco.recv_from_prev_group
#   and runtime.disco.send_to_next_group for inter-group communication
