Implementation:Mlc ai Mlc llm Attach Spec Decode Aux
Overview
AttachSpecDecodeAuxFuncs is a TVM compiler pass in the MLC-LLM project that attaches auxiliary TIR (Tensor IR) functions required for speculative decoding to the IRModule. Speculative decoding is a technique that uses a smaller "draft" model to predict multiple tokens, which a larger "verifier" model then validates in parallel, improving inference throughput. This pass adds the scatter and gather primitives that are essential for manipulating probability distributions and hidden states during that process.
File: python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py
Purpose
The pass serves two primary purposes:
- Attach scatter_probs and gather_probs TIR functions for manipulating probability distributions during speculative decoding. These are always added.
- Conditionally attach scatter_hidden_states and gather_hidden_states Relax functions when the model exposes a prefill_to_last_hidden_states function. These higher-level functions wrap corresponding TIR primitives and include tensor parallelism support via collective communication (broadcast from worker 0).
Class: AttachSpecDecodeAuxFuncs
The main pass is registered as a TVM module pass at optimization level 0.
@tvm.transform.module_pass(opt_level=0, name="AttachSpecDecodeAuxFuncs")
class AttachSpecDecodeAuxFuncs:
    tensor_parallel_shards: int

    def __init__(self, tensor_parallel_shards: int):
        self.tensor_parallel_shards = tensor_parallel_shards
Parameters
| Parameter | Type | Description |
|---|---|---|
| tensor_parallel_shards | int | Number of tensor parallel shards. When greater than 1, indices are broadcast from worker 0 before scatter/gather operations. |
transform_module
The entry point of the pass. It clones the module and uses a BlockBuilder to add functions:
def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
    mod = mod.clone()
    bb = BlockBuilder(mod)
    bb.add_func(
        _get_scatter_2d_inplace(dtype="float32", global_symbol="scatter_probs"),
        "scatter_probs",
    )
    bb.add_func(
        _get_gather_2d_inplace(dtype="float32", global_symbol="gather_probs"),
        "gather_probs",
    )
    if "prefill_to_last_hidden_states" in mod:
        hidden_states_struct_info = mod["prefill_to_last_hidden_states"].ret_struct_info.fields[0]
        dtype = hidden_states_struct_info.dtype
        _add_gather_hidden_states(bb, self.tensor_parallel_shards, dtype)
        _add_scatter_hidden_states(bb, self.tensor_parallel_shards, dtype)
    return bb.finalize()
The function unconditionally adds scatter_probs and gather_probs (both operating on float32 data). If the module contains prefill_to_last_hidden_states, it also reads the hidden state dtype from the first field of that function's return struct info and adds the hidden state scatter/gather functions with that dtype.
TIR Primitive Functions
_get_scatter_2d_inplace
Generates a TIR PrimFunc that performs an in-place 2D scatter operation. Given a source tensor of shape (batch_size, n) and an index tensor of shape (batch_size,), it scatters source rows into the destination tensor of shape (m, n) at positions specified by indices:
def _get_scatter_2d_inplace(dtype: str, global_symbol: str):
    @T.prim_func
    def _scatter_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        T.func_attr({"global_symbol": global_symbol, "tir.noalias": True})
        batch_size = T.int32(is_size_var=True)
        m = T.int32(is_size_var=True)
        n = T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (batch_size, n), dtype)
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        dst = T.match_buffer(var_dst, (m, n), dtype)
        for b, j in T.grid(batch_size, n):
            with T.sblock("scatter_2d"):
                vb, vj = T.axis.remap("SS", [b, j])
                dst[indices[vb], vj] = src[vb, vj]

    return _scatter_2d
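The indexing rule dst[indices[vb], vj] = src[vb, vj] can be checked against a plain-Python reference. The sketch below is not part of MLC-LLM: the helper name scatter_2d_reference and the sample values are hypothetical, and nested lists stand in for the TIR buffers.

```python
from typing import List


def scatter_2d_reference(
    src: List[List[float]], indices: List[int], dst: List[List[float]]
) -> List[List[float]]:
    """Hypothetical reference: write row b of src into row indices[b] of dst, in place."""
    if len(src) != len(indices):
        raise ValueError("src and indices must share the batch dimension")
    for b, row in enumerate(src):
        target_row = indices[b]
        for j, value in enumerate(row):
            # Same element-wise rule as the TIR block body.
            dst[target_row][j] = value
    return dst


# dst has m = 4 rows; src has batch_size = 2 rows; n = 2 columns.
dst = [[0.0, 0.0] for _ in range(4)]
src = [[0.1, 0.9], [0.7, 0.3]]
scatter_2d_reference(src, [3, 1], dst)
print(dst)  # rows 1 and 3 are overwritten, rows 0 and 2 are untouched
```

Rows of dst that no index points to keep their previous contents, which is why the operation is described as in-place. If two entries of indices name the same row, this sequential reference keeps the later write; it makes no claim about what the compiled kernel does in that case.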
_get_gather_2d_inplace
Generates the complementary TIR PrimFunc for in-place 2D gather. Given a source tensor of shape (m, n) and index tensor of shape (batch_size,), it gathers rows from the source into a destination tensor of shape (batch_size, n):
dst[vb, vj] = src[indices[vb], vj]
The logic mirrors scatter but reverses the direction of data movement.
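A plain-Python reference makes the mirrored indexing concrete. As with any illustration here, this is a sketch rather than MLC-LLM code: gather_2d_reference and the sample data are hypothetical names chosen for the example.

```python
from typing import List


def gather_2d_reference(
    src: List[List[float]], indices: List[int], dst: List[List[float]]
) -> List[List[float]]:
    """Hypothetical reference: copy row indices[b] of src into row b of dst, in place."""
    if len(indices) != len(dst):
        raise ValueError("indices and dst must share the batch dimension")
    for b, source_row in enumerate(indices):
        for j in range(len(dst[b])):
            # Same element-wise rule as dst[vb, vj] = src[indices[vb], vj].
            dst[b][j] = src[source_row][j]
    return dst


# src has m = 3 rows; dst has batch_size = 2 rows; n = 2 columns.
table = [[0.1, 0.9], [0.5, 0.5], [0.8, 0.2]]
picked = [[0.0, 0.0], [0.0, 0.0]]
gather_2d_reference(table, [2, 0], picked)
print(picked)  # [[0.8, 0.2], [0.1, 0.9]]
```

Unlike scatter, repeated indices are harmless for gather: the same source row is simply copied into more than one destination row.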
Relax Wrapper Functions
_add_scatter_hidden_states
Builds a Relax function scatter_hidden_states that wraps the scatter TIR primitive. When tensor parallelism is enabled (tensor_parallel_shards > 1), it broadcasts the index tensor from worker 0 before calling the TIR function via relax.op.call_tir_inplace:
def _add_scatter_hidden_states(bb: BlockBuilder, tensor_parallel_shards: int, dtype: str):
    batch_size = tir.SizeVar("batch_size", "int64")
    m = tir.SizeVar("m", "int64")
    n = tir.SizeVar("n", "int64")
    src = relax.Var("src", struct_info=TensorStructInfo([batch_size, n], dtype))
    indices = relax.Var("indices", struct_info=TensorStructInfo([batch_size], "int32"))
    dst = relax.Var("dst", struct_info=TensorStructInfo([m, n], dtype))
    with bb.function("scatter_hidden_states", [src, indices, dst]):
        with bb.dataflow():
            if tensor_parallel_shards > 1:
                indices = relax.op.ccl.broadcast_from_worker0(indices)
            output = bb.emit_output(
                relax.op.call_tir_inplace(...)
            )
        gv = bb.emit_func_output(output)
    return gv
_add_gather_hidden_states
Builds a Relax function gather_hidden_states with the same structure as scatter_hidden_states, but it wraps the gather TIR primitive and so performs the inverse operation. It also supports tensor parallelism by broadcasting the indices from worker 0.
Function Summary
| Function Name | Type | Dtype | Description |
|---|---|---|---|
| scatter_probs | TIR PrimFunc | float32 | Scatters probability rows to indexed positions |
| gather_probs | TIR PrimFunc | float32 | Gathers probability rows from indexed positions |
| scatter_hidden_states | Relax Function | model hidden dtype | Scatters hidden state rows with tensor parallel support |
| gather_hidden_states | Relax Function | model hidden dtype | Gathers hidden state rows with tensor parallel support |
Tensor Parallelism Support
When tensor_parallel_shards > 1, the hidden state functions insert a broadcast_from_worker0 collective communication call before the scatter/gather operation. This ensures all workers receive consistent index arrays during distributed inference. The probability functions (scatter_probs and gather_probs) do not include this broadcast step, as they are lower-level TIR functions.
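The effect of the broadcast can be illustrated with a toy single-process simulation. Everything in the sketch below is an assumption made for illustration: workers are modeled as list entries, broadcast_from_worker0_sim is a hypothetical stand-in for the real collective, and the placeholder indices on worker 1 are invented to show what the broadcast overwrites.

```python
from typing import List


def broadcast_from_worker0_sim(per_worker: List[List[int]]) -> List[List[int]]:
    """Toy stand-in for a broadcast collective: every worker receives worker 0's value."""
    return [list(per_worker[0]) for _ in per_worker]


def scatter_rows(src: List[List[float]], indices: List[int], dst: List[List[float]]) -> None:
    """Row-wise in-place scatter: dst[indices[b]] = src[b]."""
    for b, row in enumerate(src):
        dst[indices[b]] = list(row)


num_workers = 2
# Assumed scenario: only worker 0 holds meaningful indices before the broadcast.
per_worker_indices = [[2, 0], [0, 0]]
per_worker_indices = broadcast_from_worker0_sim(per_worker_indices)

src = [[1.0, 2.0], [3.0, 4.0]]
per_worker_dst = [[[0.0, 0.0] for _ in range(3)] for _ in range(num_workers)]
for worker in range(num_workers):
    scatter_rows(src, per_worker_indices[worker], per_worker_dst[worker])

print(per_worker_dst[0] == per_worker_dst[1])  # True
```

Without the broadcast step, worker 1 in this toy setup would scatter both rows into row 0 and end up with a destination buffer that differs from worker 0's.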
Dependencies
- tvm -- Core TVM framework including IRModule, relax, and tir modules
- tvm.relax.BlockBuilder -- Used to construct and finalize the modified IRModule
- tvm.script.tir -- TIR script DSL for defining PrimFuncs