

Implementation:Mlc ai Mlc llm Attach Spec Decode Aux

From Leeroopedia


Overview

AttachSpecDecodeAuxFuncs is a TVM compiler pass in the MLC-LLM project that attaches the auxiliary functions required for speculative decoding to the IRModule: TIR (Tensor IR) primitives and, when the model supports it, Relax wrappers around them. Speculative decoding uses a smaller "draft" model to propose several tokens, which a larger "verifier" model then checks in a single parallel pass, improving inference throughput. The pass adds the scatter and gather primitives used to manipulate probability distributions and hidden states during that process.
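To make the draft/verify idea concrete, the following is a minimal pure-Python sketch of one verification step, assuming greedy (argmax) decoding on both models. It is an illustration only, not MLC-LLM code: the function greedy_accept is hypothetical, and sampling-based speculative decoding replaces the exact-match test with a probabilistic accept/reject rule over the two models' distributions.

```python
from typing import List


def greedy_accept(draft_tokens: List[int], verifier_tokens: List[int]) -> List[int]:
    """Hypothetical helper: tokens kept after one greedy speculative step.

    draft_tokens: the k tokens proposed by the draft model.
    verifier_tokens: the k + 1 tokens the verifier picks greedily at each
        position, obtained from one parallel forward pass over the draft.
    """
    assert len(verifier_tokens) == len(draft_tokens) + 1
    accepted: List[int] = []
    for proposed, verified in zip(draft_tokens, verifier_tokens):
        if proposed != verified:
            # First disagreement: keep the verifier's token and stop.
            accepted.append(verified)
            return accepted
        accepted.append(proposed)
    # Every draft token matched, so the verifier's extra token comes for free.
    accepted.append(verifier_tokens[-1])
    return accepted
```

For example, greedy_accept([5, 7, 9], [5, 7, 2, 4]) keeps [5, 7, 2]: two draft tokens are confirmed and the verifier corrects the third, so one verifier pass yields three tokens instead of one.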

File: python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py

Purpose

The pass serves two primary purposes:

  1. Attach scatter_probs and gather_probs TIR functions for manipulating probability distributions during speculative decoding. These are always added.
  2. Conditionally attach scatter_hidden_states and gather_hidden_states Relax functions when the model exposes a prefill_to_last_hidden_states function. These higher-level functions wrap corresponding TIR primitives and include tensor parallelism support via collective communication (broadcast from worker 0).

Class: AttachSpecDecodeAuxFuncs

The main pass is registered as a TVM module pass at optimization level 0.

@tvm.transform.module_pass(opt_level=0, name="AttachSpecDecodeAuxFuncs")
class AttachSpecDecodeAuxFuncs:
    tensor_parallel_shards: int

    def __init__(self, tensor_parallel_shards: int):
        self.tensor_parallel_shards = tensor_parallel_shards

Parameters

  • tensor_parallel_shards (int) -- Number of tensor parallel shards. When greater than 1, the hidden-state functions broadcast their index tensor from worker 0 before the scatter/gather operation; the probability functions are unaffected.

transform_module

The entry point of the pass. It clones the input module, so the caller's module is not mutated, and uses a relax.BlockBuilder to add the new functions:

def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
    mod = mod.clone()
    bb = BlockBuilder(mod)
    bb.add_func(
        _get_scatter_2d_inplace(dtype="float32", global_symbol="scatter_probs"),
        "scatter_probs",
    )
    bb.add_func(
        _get_gather_2d_inplace(dtype="float32", global_symbol="gather_probs"),
        "gather_probs",
    )
    if "prefill_to_last_hidden_states" in mod:
        hidden_states_struct_info = mod["prefill_to_last_hidden_states"].ret_struct_info.fields[0]
        dtype = hidden_states_struct_info.dtype
        _add_gather_hidden_states(bb, self.tensor_parallel_shards, dtype)
        _add_scatter_hidden_states(bb, self.tensor_parallel_shards, dtype)
    return bb.finalize()

The function unconditionally adds scatter_probs and gather_probs (both operating on float32 data). If the module contains prefill_to_last_hidden_states, it also reads the hidden-state dtype from the first field of that function's return struct info and adds the hidden-state gather/scatter functions with that dtype.
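The decision logic of transform_module can be modelled without TVM as a plain function over a dictionary standing in for the IRModule. This is a hypothetical sketch: planned_aux_funcs and the "function name to hidden-state dtype" mapping are illustrative assumptions, not part of MLC-LLM.

```python
from typing import Dict, List, Tuple


def planned_aux_funcs(module_funcs: Dict[str, str]) -> List[Tuple[str, str]]:
    """Hypothetical model: (function name, dtype) pairs the pass would attach.

    module_funcs maps each function name in the module to the dtype of the
    first field of its return value (a stand-in for ret_struct_info.fields[0]).
    """
    # Probability helpers are always attached and always use float32.
    planned = [("scatter_probs", "float32"), ("gather_probs", "float32")]
    if "prefill_to_last_hidden_states" in module_funcs:
        # Hidden-state helpers inherit the dtype the model returns.
        dtype = module_funcs["prefill_to_last_hidden_states"]
        planned += [("gather_hidden_states", dtype), ("scatter_hidden_states", dtype)]
    return planned
```

A float16 model that exposes prefill_to_last_hidden_states therefore gets four functions, two in float32 and two in float16, while a model without it gets only the two probability helpers.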

TIR Primitive Functions

_get_scatter_2d_inplace

Generates a TIR PrimFunc that performs an in-place 2D scatter operation. Given a source tensor of shape (batch_size, n) and an index tensor of shape (batch_size,), it scatters source rows into the destination tensor of shape (m, n) at positions specified by indices:

from tvm.script import tir as T

def _get_scatter_2d_inplace(dtype: str, global_symbol: str):
    @T.prim_func
    def _scatter_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        T.func_attr({"global_symbol": global_symbol, "tir.noalias": True})
        batch_size = T.int32(is_size_var=True)
        m = T.int32(is_size_var=True)
        n = T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (batch_size, n), dtype)
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        dst = T.match_buffer(var_dst, (m, n), dtype)
        for b, j in T.grid(batch_size, n):
            with T.sblock("scatter_2d"):
                vb, vj = T.axis.remap("SS", [b, j])
                dst[indices[vb], vj] = src[vb, vj]
    return _scatter_2d
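The loop nest above can be checked against a plain-Python reference that applies the same assignment, dst[indices[b], j] = src[b, j], to nested lists. This is a sketch for illustration; scatter_2d_reference is a hypothetical name, and it assumes the indices are distinct, because with duplicate indices a parallel GPU kernel gives no guarantee about which write lands last.

```python
from typing import List


def scatter_2d_reference(src: List[List[float]], indices: List[int],
                         dst: List[List[float]]) -> None:
    """Hypothetical reference for _scatter_2d: dst[indices[b]][j] = src[b][j]."""
    assert len(indices) == len(src)
    for b in range(len(src)):          # mirrors the batch_size axis of T.grid
        for j in range(len(src[b])):   # mirrors the n axis of T.grid
            dst[indices[b]][j] = src[b][j]
    # Returns None: like the PrimFunc, the result is written into dst in place,
    # and rows of dst not named in indices keep their old contents.
```

For example, scattering [[1.0, 2.0], [3.0, 4.0]] with indices [3, 0] into a 4 x 2 zero buffer writes the first row to position 3 and the second row to position 0.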

_get_gather_2d_inplace

Generates the complementary TIR PrimFunc for in-place 2D gather. Given a source tensor of shape (m, n) and index tensor of shape (batch_size,), it gathers rows from the source into a destination tensor of shape (batch_size, n):

dst[vb, vj] = src[indices[vb], vj]

The logic mirrors scatter but reverses the direction of data movement.
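A matching plain-Python reference for the gather direction is sketched below; gather_2d_reference is a hypothetical name used only for illustration. It reads rows of src at the given indices and writes them densely into dst.

```python
from typing import List


def gather_2d_reference(src: List[List[float]], indices: List[int],
                        dst: List[List[float]]) -> None:
    """Hypothetical reference for the gather PrimFunc: dst[b][j] = src[indices[b]][j]."""
    assert len(indices) == len(dst)
    for b in range(len(dst)):          # one output row per index
        for j in range(len(dst[b])):   # copy the full row of width n
            dst[b][j] = src[indices[b]][j]
```

Because the two operations are inverses on the selected rows, gathering with the same distinct indices right after a scatter returns exactly the rows that were scattered.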

Relax Wrapper Functions

_add_scatter_hidden_states

Builds a Relax function scatter_hidden_states that wraps the scatter TIR primitive. When tensor parallelism is enabled (tensor_parallel_shards > 1), it broadcasts the index tensor from worker 0 before calling the TIR function via relax.op.call_tir_inplace:

from tvm import relax, tir
from tvm.relax import BlockBuilder, TensorStructInfo

def _add_scatter_hidden_states(bb: BlockBuilder, tensor_parallel_shards: int, dtype: str):
    batch_size = tir.SizeVar("batch_size", "int64")
    m = tir.SizeVar("m", "int64")
    n = tir.SizeVar("n", "int64")
    src = relax.Var("src", struct_info=TensorStructInfo([batch_size, n], dtype))
    indices = relax.Var("indices", struct_info=TensorStructInfo([batch_size], "int32"))
    dst = relax.Var("dst", struct_info=TensorStructInfo([m, n], dtype))
    with bb.function("scatter_hidden_states", [src, indices, dst]):
        with bb.dataflow():
            if tensor_parallel_shards > 1:
                indices = relax.op.ccl.broadcast_from_worker0(indices)
            output = bb.emit_output(
                relax.op.call_tir_inplace(...)
            )
        gv = bb.emit_func_output(output)
    return gv
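The control flow of this Relax wrapper can be mirrored in plain Python to show its two observable properties: the optional broadcast of indices, and the fact that the result is the destination buffer itself rather than a new tensor (the conceptual meaning of an in-place call). This is a hypothetical model; scatter_hidden_states_model and its broadcast_from_worker0 callback are illustrative stand-ins, not the Relax API.

```python
from typing import Callable, List, Optional

Matrix = List[List[float]]


def scatter_hidden_states_model(
    src: Matrix,
    indices: List[int],
    dst: Matrix,
    tensor_parallel_shards: int,
    broadcast_from_worker0: Optional[Callable[[List[int]], List[int]]] = None,
) -> Matrix:
    """Hypothetical model of the scatter_hidden_states wrapper."""
    if tensor_parallel_shards > 1:
        # Mirror the wrapper: replace this worker's indices with worker 0's copy.
        assert broadcast_from_worker0 is not None
        indices = broadcast_from_worker0(indices)
    # Same row-wise scatter as the TIR primitive, writing into dst in place.
    for b, row in enumerate(src):
        dst[indices[b]][:] = row
    # In-place semantics: the "output" is the mutated dst buffer itself.
    return dst
```

With tensor_parallel_shards=1 the indices are used as given; with more shards, whatever the callback returns (here standing in for worker 0's indices) decides where rows land.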

_add_gather_hidden_states

Builds a Relax function gather_hidden_states with the same structure as scatter_hidden_states, but wrapping the gather primitive instead. When tensor_parallel_shards > 1, it likewise broadcasts the index tensor from worker 0 before the gather.

Function Summary

| Function name | Type | Dtype | Description |
|---|---|---|---|
| scatter_probs | TIR PrimFunc | float32 | Scatters probability rows to indexed positions |
| gather_probs | TIR PrimFunc | float32 | Gathers probability rows from indexed positions |
| scatter_hidden_states | Relax function | Model hidden-state dtype | Scatters hidden-state rows, with tensor parallel support |
| gather_hidden_states | Relax function | Model hidden-state dtype | Gathers hidden-state rows, with tensor parallel support |

Tensor Parallelism Support

When tensor_parallel_shards > 1, the hidden-state functions insert a broadcast_from_worker0 collective communication call before the scatter/gather operation. This ensures that all workers use the same index array (worker 0's) during distributed inference. The probability functions (scatter_probs and gather_probs) do not include this broadcast step: they are attached as bare TIR PrimFuncs, with no Relax wrapper in which a collective call could be inserted.
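The effect of the broadcast can be illustrated with a toy simulation in which each worker's copy of the index array is one list entry. This is a hypothetical sketch: simulate_broadcast_from_worker0 and all_workers_agree are illustrative names, and in the compiled model the same effect comes from a single relax.op.ccl.broadcast_from_worker0 call inside the Relax function.

```python
from typing import List


def simulate_broadcast_from_worker0(per_worker_indices: List[List[int]]) -> List[List[int]]:
    """Hypothetical simulation: every worker receives a copy of worker 0's indices."""
    root = per_worker_indices[0]
    return [list(root) for _ in per_worker_indices]


def all_workers_agree(per_worker_indices: List[List[int]]) -> bool:
    """True when every worker would scatter/gather the same rows."""
    return all(idx == per_worker_indices[0] for idx in per_worker_indices)
```

If two workers held [0, 2] and [0, 1], they would touch different rows of their hidden-state shards; after the simulated broadcast both hold [0, 2], so every shard is updated at the same positions.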

Dependencies

  • tvm -- Core TVM framework including IRModule, relax, and tir modules
  • tvm.relax.BlockBuilder -- Used to construct and finalize the modified IRModule
  • tvm.script.tir -- TIR script DSL for defining PrimFuncs
