

Implementation:Mlc ai Mlc llm Batch Spec Verify

From Leeroopedia


Overview

The Batch Spec Verify module implements a GPU kernel for batch-level token tree verification in speculative decoding. It is located at python/mlc_llm/op/batch_spec_verify.py (177 lines).

Speculative decoding is an optimization technique in which a smaller "draft" model proposes multiple candidate tokens organized as a tree, and a larger "target" model verifies them. This module provides the TIR (Tensor IR) primitive function that walks the token tree, accepts or rejects each draft token with a probabilistic acceptance test, and renormalizes the target model's probabilities whenever a token is rejected.

Source File

  • File: python/mlc_llm/op/batch_spec_verify.py
  • Lines: 177
  • Module: mlc_llm.op.batch_spec_verify

Dependencies

  • tvm.script.tir (used as T in the code below): TIR script decorator and primitives for defining the GPU kernel

Function: batch_spec_verify

def batch_spec_verify(vocab_size):

Returns a TIR primitive function parameterized by vocab_size. The returned function performs batched tree-based speculative verification on GPU.

Parameters (of returned TIR function)

  • draft_probs, shape (num_nodes, vocab_size), float32: Draft model probability distribution at each tree node
  • draft_tokens, shape (num_nodes,), int32: Token ID proposed at each tree node
  • model_probs, shape (num_nodes, vocab_size), float32: Target model probability distribution stored on each parent node (updated in place)
  • token_tree_first_child, shape (num_nodes,), int32: First-child pointer for each node (-1 if leaf)
  • token_tree_next_sibling, shape (num_nodes,), int32: Next-sibling pointer for each node (-1 if none)
  • uniform_samples, shape (num_nodes,), float32: Per-node uniform random samples for the acceptance test
  • token_tree_parent_ptr, shape (nbatch,), int32: Current parent pointer per batch element (input: tree root; output: last accepted node)
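The two pointer buffers encode the token tree in a first-child / next-sibling layout. As a minimal sketch, assuming the tree is first available as a parent array with -1 for the root (a hypothetical input format, not something batch_spec_verify accepts), the hypothetical helper build_sibling_arrays below derives both buffers:

```python
from typing import List, Tuple


def build_sibling_arrays(parents: List[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: parent array (-1 = root) -> first-child / next-sibling arrays.

    Children keep the order in which they appear in `parents`.
    """
    n = len(parents)
    first_child = [-1] * n   # -1 marks a leaf
    next_sibling = [-1] * n  # -1 marks the last child of its parent
    last_child = [-1] * n    # most recently linked child per parent
    for node, parent in enumerate(parents):
        if parent < 0:
            continue  # the root has no parent to link into
        if first_child[parent] == -1:
            first_child[parent] = node
        else:
            next_sibling[last_child[parent]] = node
        last_child[parent] = node
    return first_child, next_sibling


# Root 0 has children 1 and 2; node 1 has a single child 3.
first_child, next_sibling = build_sibling_arrays([-1, 0, 0, 1])
print(first_child)   # [1, 3, -1, -1]
print(next_sibling)  # [-1, 2, -1, -1]
```

Siblings form a singly linked list, so the kernel can enumerate the candidates under one parent with nothing more than repeated lookups into token_tree_next_sibling.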

Storage Convention

An important design detail is the difference between how draft_probs and model_probs are indexed:

  • draft_probs[node_id, :] stores the probability that was used to sample the token at that node.
  • model_probs[node_id, :] stores the target model's probability at a node, used to verify its children.

This asymmetry is intentional: sibling tokens under the same parent may have been sampled from different draft distributions, so each draft row belongs to the node it produced, whereas the target model's distribution for the next position is unique per parent.
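The convention is easiest to see on a small tree. The sketch below uses NumPy arrays as stand-ins for the kernel buffers; acceptance_inputs is a hypothetical helper that gathers the three values the acceptance test needs for one parent/child pair.

```python
import numpy as np


def acceptance_inputs(model_probs, draft_probs, draft_tokens, parent, child):
    """Hypothetical helper: gather (token, p, q) for one parent/child pair."""
    token = int(draft_tokens[child])
    p = float(model_probs[parent, token])  # target row is stored on the PARENT
    q = float(draft_probs[child, token])   # draft row is stored on the CHILD itself
    return token, p, q


# Root 0 with children 1 and 2, each sampled from its own draft distribution.
draft_tokens = np.array([0, 1, 2, 3])
model_probs = np.array([[0.1, 0.6, 0.2, 0.1]] + [[0.25] * 4] * 3)
draft_probs = np.array([
    [0.25, 0.25, 0.25, 0.25],
    [0.1, 0.7, 0.1, 0.1],    # distribution that sampled node 1
    [0.4, 0.1, 0.4, 0.1],    # a different distribution that sampled node 2
    [0.25, 0.25, 0.25, 0.25],
])
print(acceptance_inputs(model_probs, draft_probs, draft_tokens, 0, 1))  # (1, 0.6, 0.7)
print(acceptance_inputs(model_probs, draft_probs, draft_tokens, 0, 2))  # (2, 0.2, 0.4)
```

Both siblings read the same target row model_probs[0, :], but each reads its own draft row.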

GPU Execution Model

TX = 1024  # threads per block for vocabulary parallelism

for _bx in T.thread_binding(0, nbatch, thread="blockIdx.x"):
    for _tx in T.thread_binding(0, TX, thread="threadIdx.x"):
        ...  # per-batch verification loop

  • blockIdx.x: Each block processes one batch element.
  • threadIdx.x: 1024 threads per block collaborate on vocabulary-sized operations (the residual computation, its reduction, and the renormalized write-back).
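Threads stride over the vocabulary in steps of TX, so thread tx touches token ids tx, tx + TX, tx + 2 * TX, and so on, with a bounds check for the final partial tile. A plain-Python sketch of that index pattern (thread_indices is a hypothetical name, and the vocabulary size of 32000 is only an example) confirms that the tiles cover every id exactly once:

```python
def thread_indices(tx: int, vocab_size: int, TX: int = 1024) -> list:
    """Vocabulary indices visited by thread `tx` under the k = i * TX + tx pattern."""
    n_iters = -(-vocab_size // TX)  # ceiling division, the role T.ceildiv plays in the kernel
    return [i * TX + tx for i in range(n_iters) if i * TX + tx < vocab_size]


vocab_size = 32000  # example value only
covered = sorted(k for tx in range(1024) for k in thread_indices(tx, vocab_size))
assert covered == list(range(vocab_size))  # every token id is visited exactly once
print(len(thread_indices(0, vocab_size)), len(thread_indices(1023, vocab_size)))  # 32 31
```

With 32000 entries, threads 0 to 255 handle 32 indices and the rest handle 31, so the work per thread is nearly balanced.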

Verification Algorithm

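For each batch element, the kernel runs a while-loop that walks the token tree starting from the root stored in token_tree_parent_ptr[b]. The current child starts as the root's first child; accepting a child descends to that child's first child, rejecting it moves to its next sibling, and the loop ends once the current child pointer is -1: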

Step 1: Acceptance Test

For the current child node, thread 0 computes the acceptance criterion:

if tx == 0:
    child_token[0] = draft_tokens[child_ptr[0]]
    p_child[0] = model_probs[parent_ptr[0], child_token[0]]
    q_child[0] = draft_probs[child_ptr[0], child_token[0]]
    uniform_sample[0] = uniform_samples[child_ptr[0]]
    pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0]

The acceptance condition is p(x) >= u * q(x), where:

  • p(x) is the target model probability for token x
  • q(x) is the draft model probability for token x
  • u is uniform_samples[child], a uniform random sample in [0, 1)

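This is the standard speculative-sampling test u <= p(x) / q(x) multiplied through by q(x) >= 0. Using the product u * q(x) instead of the ratio p(x) / q(x) means the kernel never divides, so a token whose draft probability is zero is simply accepted rather than causing a division by zero.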
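A small sketch of the two equivalent forms (accept_mul and accept_div are illustrative names, not functions in mlc_llm) shows that they agree whenever q(x) > 0 and that only the multiplicative form handles q(x) = 0:

```python
def accept_mul(p: float, q: float, u: float) -> bool:
    """Multiplicative form used by the kernel: p(x) >= u * q(x)."""
    return p >= u * q


def accept_div(p: float, q: float, u: float) -> bool:
    """Textbook form u <= p(x) / q(x); raises ZeroDivisionError when q(x) == 0."""
    return u <= p / q


# Both forms agree whenever q > 0 ...
for p, q, u in [(0.6, 0.7, 0.5), (0.1, 0.9, 0.99), (0.3, 0.2, 0.9)]:
    assert accept_mul(p, q, u) == accept_div(p, q, u)

# ... and the multiplicative form still works when the draft assigned q(x) = 0.
print(accept_mul(0.0, 0.0, 0.42))  # True: the token is accepted
```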

Step 2: Accept

If the acceptance test passes, the child becomes the new parent and the walk descends to its first child:

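# pred_local[0] is this thread's copy of pred_shared[0], read after a shared-memory sync
if pred_local[0]:
    # The accepted child becomes the new parent; continue with its first child
    parent_ptr[0] = child_ptr[0]
    child_ptr[0] = token_tree_first_child[child_ptr[0]]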

Step 3: Reject and Renormalize

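If the token is rejected, the target probabilities at the parent are replaced by the residual distribution max(p - q, 0) / sum(max(p - q, 0)), where p is model_probs[parent, :] and q is the draft distribution of the rejected child, draft_probs[child, :]. The first pass computes the clamped residual and each thread's partial sum: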

# Phase 1: Compute the clamped residual max(p - q, 0) and this thread's partial sum
psum[0] = 0.0  # reset the accumulator for this rejection
for i in T.serial(T.ceildiv(vocab_size, TX)):
    k = i * TX + tx  # strided index: each thread covers every TX-th vocabulary entry
    if k < vocab_size:
        model_prob_local[0] = model_probs[parent_ptr[0], k]
        draft_prob_local[0] = draft_probs[child_ptr[0], k]
        model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0)
        psum[0] += model_prob_local[0]

Each thread's partial sum psum[0] is then combined with a block-wide all-reduce, which leaves the total residual mass in t0[0] on every thread:

with T.sblock("block_cross_thread"):
    T.attr(
        T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
        "reduce_scope",
        T.reinterpret("handle", T.uint64(0)),
    )
    T.tvm_thread_allreduce(T.uint32(1), psum[0], True, t0[0], tx, dtype="handle")
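On the host, Phase 1 and the all-reduce can be simulated with NumPy. The sketch below (residual_and_total is a hypothetical name) keeps one partial sum per simulated thread over the same strided indices and then adds them up, which is the value the reduction hands to every thread as t0[0]:

```python
import numpy as np


def residual_and_total(model_row, draft_row, TX=1024):
    """Host-side sketch of Phase 1 plus the block-wide all-reduce.

    Returns the per-entry residual max(p - q, 0) and the total t0 that the
    all-reduce would hand to every thread.
    """
    residual = np.maximum(model_row - draft_row, 0.0)
    # psum[tx] mirrors thread tx's private accumulator over k = tx, tx + TX, ...
    psum = np.array([residual[tx::TX].sum() for tx in range(TX)])
    t0 = psum.sum()  # the reduction step: combine all partial sums
    return residual, t0


model_row = np.array([0.1, 0.1, 0.7, 0.1])  # target distribution at the parent
draft_row = np.array([0.0, 0.9, 0.1, 0.0])  # draft distribution of the rejected child
residual, t0 = residual_and_total(model_row, draft_row)
print(round(float(t0), 6))  # 0.8
```

Summing per-thread partials and then reducing gives the same total as summing the residual directly, up to floating-point rounding.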

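If the residual mass t0[0] is below 1e-7, the target and draft distributions are essentially identical, so there is no residual distribution to renormalize and dividing by t0[0] would be numerically unstable. In that degenerate case the kernel accepts the child despite the failed test. Otherwise, the renormalized residual is written back to model_probs[parent, :] and the walk moves on to the next sibling: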

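if t0[0] < 1e-7:
    # Degenerate case: no residual mass left, so accept the child instead of dividing by ~0
    parent_ptr[0] = child_ptr[0]
    child_ptr[0] = token_tree_first_child[child_ptr[0]]
else:
    # model_prob_local holds only one value per thread, so recompute each residual
    # entry before writing the renormalized distribution back
    for i in T.serial(T.ceildiv(vocab_size, TX)):
        k = i * TX + tx
        if k < vocab_size:
            model_prob_local[0] = model_probs[parent_ptr[0], k]
            draft_prob_local[0] = draft_probs[child_ptr[0], k]
            model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0)
            model_probs[parent_ptr[0], k] = model_prob_local[0] / t0[0]
    # Try the next sibling against the updated parent distribution
    child_ptr[0] = token_tree_next_sibling[child_ptr[0]]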
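The rejection branch for a single parent/child pair reduces to a few NumPy lines. In this sketch, reject_update is a hypothetical helper that returns the row the kernel would write to model_probs[parent, :], or None when the degenerate-case threshold makes the kernel accept the child instead:

```python
import numpy as np

EPS = 1e-7  # same threshold as the kernel's t0[0] < 1e-7 check


def reject_update(model_row, draft_row):
    """Sketch of the rejection branch for one parent/child pair.

    Returns the renormalized residual distribution that would overwrite
    model_probs[parent, :], or None in the degenerate case where the kernel
    accepts the child instead.
    """
    residual = np.maximum(model_row - draft_row, 0.0)
    total = residual.sum()
    if total < EPS:
        return None  # no residual mass: the child is accepted
    return residual / total


print(reject_update(np.array([0.1, 0.1, 0.7, 0.1]), np.array([0.0, 0.9, 0.1, 0.0])))
# approximately [0.125, 0.0, 0.75, 0.125]
print(reject_update(np.array([0.5, 0.5]), np.array([0.5, 0.5])))  # None
```

Because the next sibling is tested against the updated row, a later sibling can only be accepted with probability mass that the rejected draft did not already cover.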

Step 4: Finalize

After the loop completes (no more children to explore), thread 0 writes the final parent pointer:

if tx == 0:
    token_tree_parent_ptr[b] = parent_ptr[0]
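Putting Steps 1 to 4 together, the sketch below is a sequential NumPy reference with a single thread of control per batch element. It is an illustrative re-implementation of the walk described above, not code from mlc_llm; spec_verify_reference is a hypothetical name, and the kernel's thread-level parallelism and synchronization are deliberately left out. A reference like this can be useful for checking kernel outputs on small inputs.

```python
import numpy as np


def spec_verify_reference(draft_probs, draft_tokens, model_probs,
                          first_child, next_sibling, uniform_samples,
                          parent_ptr, eps=1e-7):
    """Sequential NumPy sketch of the per-batch tree walk (hypothetical helper).

    Mirrors Steps 1-4: model_probs and parent_ptr are updated in place.
    """
    for b in range(len(parent_ptr)):
        parent = int(parent_ptr[b])            # start at the tree root
        child = int(first_child[parent])
        while child != -1:                     # -1: no candidates left at this depth
            token = int(draft_tokens[child])
            p = model_probs[parent, token]
            q = draft_probs[child, token]
            if p < uniform_samples[child] * q:                      # Step 1 failed
                residual = np.maximum(model_probs[parent] - draft_probs[child], 0.0)
                total = residual.sum()
                if total >= eps:                                    # Step 3
                    model_probs[parent] = residual / total
                    child = int(next_sibling[child])                # try the next sibling
                    continue
                # Degenerate case (total < eps): fall through and accept
            parent, child = child, int(first_child[child])          # Step 2: descend
        parent_ptr[b] = parent                                      # Step 4
    return parent_ptr


# Root 0 has children 1 and 2; node 1 has a single child 3.
first_child = np.array([1, 3, -1, -1])
next_sibling = np.array([-1, 2, -1, -1])
draft_tokens = np.array([0, 1, 2, 3])
model_probs = np.array([[0.1, 0.1, 0.7, 0.1]] + [[0.25] * 4] * 3)
draft_probs = np.array([[0.25] * 4, [0.0, 0.9, 0.1, 0.0], [0.25] * 4, [0.25] * 4])
uniform_samples = np.array([0.0, 0.99, 0.5, 0.5])
parent_ptr = np.array([0])
spec_verify_reference(draft_probs, draft_tokens, model_probs, first_child,
                      next_sibling, uniform_samples, parent_ptr)
print(parent_ptr)  # [2]: node 1 is rejected, its sibling 2 is accepted
```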

Synchronization

The kernel uses T.tvm_storage_sync("shared") at critical points to ensure:

  1. All threads see the acceptance predicate computed by thread 0.
  2. All reads of model_probs are complete before any writes during renormalization.
  3. Shared state is consistent before loop exit checks.
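The first synchronization point follows a common pattern: one thread writes a value to shared memory, every thread waits at a barrier, and then all threads read it. A CPU analogy using Python's threading.Barrier is sketched below; it illustrates the pattern only and is not how TVM lowers tvm_storage_sync. broadcast_predicate is a hypothetical name.

```python
import threading


def broadcast_predicate(num_threads, p, q, u):
    """Simulate thread 0 computing the predicate and all threads reading it after a barrier."""
    shared = {"pred": None}          # stands in for pred_shared[0] in shared memory
    seen = [None] * num_threads      # what each thread reads into its local copy
    barrier = threading.Barrier(num_threads)

    def worker(tx):
        if tx == 0:
            shared["pred"] = p >= u * q  # only thread 0 evaluates the acceptance test
        barrier.wait()                   # plays the role of T.tvm_storage_sync("shared")
        seen[tx] = shared["pred"]        # safe: the write happened before the barrier

    threads = [threading.Thread(target=worker, args=(tx,)) for tx in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return seen


print(broadcast_predicate(8, p=0.6, q=0.7, u=0.5))  # [True, True, True, True, True, True, True, True]
```

Without the barrier, a thread other than thread 0 could read the predicate before it was written and take the wrong branch.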

Post-Verification Sampling

After this function returns, a follow-up sampling step must sample from model_probs[token_tree_parent_ptr[b], :] to generate one additional token beyond the last accepted position.
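A minimal sketch of that follow-up step, assuming NumPy's random Generator as the sampler: sample_bonus_tokens is a hypothetical helper, and the actual sampling step in MLC LLM may be implemented differently.

```python
import numpy as np


def sample_bonus_tokens(model_probs, parent_ptr, rng):
    """Draw one extra token per batch element from model_probs[parent_ptr[b], :]."""
    vocab_size = model_probs.shape[1]
    return np.array([rng.choice(vocab_size, p=model_probs[int(node)]) for node in parent_ptr])


rng = np.random.default_rng(0)
model_probs = np.array([
    [0.125, 0.0, 0.75, 0.125],
    [0.25, 0.25, 0.25, 0.25],
    [0.0, 0.0, 0.0, 1.0],   # one-hot row: the draw is deterministic
])
print(sample_bonus_tokens(model_probs, np.array([2]), rng))  # [3]
```

If the walk ended after a rejection, the row at the last accepted node has already been renormalized in place, so this draw comes from the residual distribution.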

Categories

  • Speculative Decoding
  • GPU Kernels
  • TVM TIR
  • Token Tree Verification
  • LLM Inference Optimization
