Implementation:Mlc ai Mlc llm Batch Spec Verify
Overview
The Batch Spec Verify module implements a GPU kernel for batch-level token tree verification in speculative decoding. It is located at python/mlc_llm/op/batch_spec_verify.py (177 lines).
Speculative decoding is an optimization technique in which a smaller "draft" model proposes multiple candidate tokens organized as a tree and a larger "target" model verifies them. This module provides the TIR (Tensor IR) primitive function that walks the token tree, accepts or rejects each draft token by comparing target and draft probabilities, and renormalizes the target model probabilities after a rejection.
Source File
- File: python/mlc_llm/op/batch_spec_verify.py
- Lines: 177
- Module: mlc_llm.op.batch_spec_verify
Dependencies
| Import | Purpose |
|---|---|
| `tvm.script.tir` (used as `T`) | TIR script decorator and primitives for defining the GPU kernel |
Function: batch_spec_verify
```python
def batch_spec_verify(vocab_size):
    ...  # builds and returns the TIR primitive function
```
Returns a TIR primitive function parameterized by `vocab_size`. The returned function performs batched, tree-based speculative verification on the GPU.
Parameters (of returned TIR function)
| Buffer | Shape | dtype | Description |
|---|---|---|---|
| `draft_probs` | (num_nodes, vocab_size) | float32 | Draft model probability distribution at each tree node |
| `draft_tokens` | (num_nodes,) | int32 | Token ID proposed at each tree node |
| `model_probs` | (num_nodes, vocab_size) | float32 | Target model probability distribution at each parent node (updated in place) |
| `token_tree_first_child` | (num_nodes,) | int32 | First child of each node (-1 if leaf) |
| `token_tree_next_sibling` | (num_nodes,) | int32 | Next sibling of each node (-1 if none) |
| `uniform_samples` | (num_nodes,) | float32 | Per-node uniform random sample for the acceptance test |
| `token_tree_parent_ptr` | (nbatch,) | int32 | Per-batch parent pointer (input: tree root; output: last accepted node) |
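The tree arrays use a first-child / next-sibling encoding. The sketch below uses a hypothetical helper, `build_child_sibling`, which is not part of MLC LLM. It shows how such arrays could be derived from a plain parent list, assuming siblings are linked in increasing node-index order.

```python
def build_child_sibling(parents):
    """Hypothetical helper: convert a parent list into first-child /
    next-sibling arrays (-1 means "no such node").

    parents[i] is the parent of node i; the root has parent -1.
    Siblings are linked in increasing node-index order.
    """
    n = len(parents)
    first_child = [-1] * n
    next_sibling = [-1] * n
    last_child = [-1] * n  # most recently linked child of each node
    for node, parent in enumerate(parents):
        if parent < 0:
            continue  # the root has no parent to link into
        if first_child[parent] == -1:
            first_child[parent] = node
        else:
            next_sibling[last_child[parent]] = node
        last_child[parent] = node
    return first_child, next_sibling


# Root 0 has children 1 and 2; node 1 has child 3.
first_child, next_sibling = build_child_sibling([-1, 0, 0, 1])
print(first_child)   # [1, 3, -1, -1]
print(next_sibling)  # [-1, 2, -1, -1]
```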
Storage Convention
An important design detail is the difference between how draft_probs and model_probs are indexed:
- `draft_probs[node_id, :]` stores the draft distribution that was used to sample the token at that node.
- `model_probs[node_id, :]` stores the target model's distribution at that node, which is used to verify its children.
This asymmetry exists because different child tokens may have been sampled with different draft probabilities, but the target model probability is unique per parent position.
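The following pure-Python sketch makes the indexing concrete. It gathers the three acceptance-test inputs for one child, with lists of lists standing in for the device buffers. The function name `acceptance_inputs` is hypothetical. Note that `p` comes from the parent's row of `model_probs`, while `q` and `u` come from the child's own entries. Two siblings can therefore carry different draft rows yet share one target row.

```python
def acceptance_inputs(parent, child, draft_probs, draft_tokens,
                      model_probs, uniform_samples):
    """Hypothetical helper mirroring the kernel's indexing convention."""
    token = draft_tokens[child]
    p = model_probs[parent][token]  # target prob: stored on the PARENT node
    q = draft_probs[child][token]   # draft prob: stored on the CHILD node
    u = uniform_samples[child]      # per-node uniform sample
    return token, p, q, u


# Two siblings (nodes 1 and 2) of root 0, sampled from different draft rows.
# Rows that the example never reads are zero placeholders.
model_probs = [[0.6, 0.3, 0.1], [0.0] * 3, [0.0] * 3]
draft_probs = [[0.0] * 3, [0.5, 0.4, 0.1], [0.2, 0.2, 0.6]]
draft_tokens = [0, 0, 2]
uniform_samples = [0.0, 0.9, 0.1]
print(acceptance_inputs(0, 1, draft_probs, draft_tokens, model_probs, uniform_samples))
# (0, 0.6, 0.5, 0.9)
print(acceptance_inputs(0, 2, draft_probs, draft_tokens, model_probs, uniform_samples))
# (2, 0.1, 0.6, 0.1)
```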
GPU Execution Model
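```python
TX = 1024  # threads per block for vocabulary parallelism

for _bx in T.thread_binding(0, nbatch, thread="blockIdx.x"):  # one block per batch element
    for _tx in T.thread_binding(0, TX, thread="threadIdx.x"):  # TX threads per block
        ...  # per-batch tree walk (Steps 1-4 below)
```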
- `blockIdx.x`: each block processes one batch element.
- `threadIdx.x`: the 1024 threads of a block collaborate on vocabulary-sized operations (renormalization).
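The vocabulary loops inside the kernel use a block-stride pattern. Thread `tx` handles indices `k = i * TX + tx` for `i` in `range(ceildiv(vocab_size, TX))` and skips any `k >= vocab_size`. The small pure-Python check below uses a tiny illustrative `TX` instead of 1024. It shows that the pattern visits every vocabulary index exactly once across the block.

```python
def ceildiv(a, b):
    return -(-a // b)


def indices_for_thread(tx, vocab_size, TX):
    """Vocabulary indices handled by thread tx in the block-stride loop."""
    out = []
    for i in range(ceildiv(vocab_size, TX)):
        k = i * TX + tx
        if k < vocab_size:  # the last round may run past the vocabulary
            out.append(k)
    return out


TX, vocab_size = 4, 10  # small values for illustration; the kernel uses TX = 1024
per_thread = [indices_for_thread(tx, vocab_size, TX) for tx in range(TX)]
print(per_thread)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
covered = sorted(k for ks in per_thread for k in ks)
assert covered == list(range(vocab_size))  # each index exactly once
```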
Verification Algorithm
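The kernel runs a while-loop that walks the token tree downward from the root. The loop ends when the current child pointer is -1, that is, when the last accepted node has no children or all of its children have been rejected: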
Step 1: Acceptance Test
For the current child node, thread 0 computes the acceptance criterion:
```python
if tx == 0:
    # Thread 0 gathers p, q and u for the current child
    child_token[0] = draft_tokens[child_ptr[0]]
    p_child[0] = model_probs[parent_ptr[0], child_token[0]]
    q_child[0] = draft_probs[child_ptr[0], child_token[0]]
    uniform_sample[0] = uniform_samples[child_ptr[0]]
    # Multiply instead of divide to avoid division by zero
    pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0]
```
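The acceptance condition is `p(x) >= u * q(x)`, where:
- `p(x)` is the target model probability for token `x`
- `q(x)` is the draft model probability for token `x`
- `u` is a uniform random sample in [0, 1)

For `q(x) > 0` this is equivalent to `u <= p(x) / q(x)`, so a draft token is accepted with probability min(1, p(x) / q(x)). Writing the test as a multiplication (`u * q`) rather than a division (`p / q`) avoids division by zero when `q(x) = 0`.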
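The acceptance rule can be checked numerically. The plain-Python sketch below uses `accept` and `acceptance_rate`, which are hypothetical names. It applies `p >= u * q` to many uniform samples and compares the empirical acceptance rate with min(1, p / q). It also shows that `q = 0` needs no special case.

```python
import random


def accept(p, q, u):
    """Speculative acceptance test written without division, as in the kernel."""
    return p >= u * q


def acceptance_rate(p, q, trials=100_000, seed=0):
    rng = random.Random(seed)
    # rng.random() draws u from [0, 1), matching the range of uniform_samples
    return sum(accept(p, q, rng.random()) for _ in range(trials)) / trials


print(acceptance_rate(0.2, 0.4))  # approximately min(1, 0.2 / 0.4) = 0.5
print(acceptance_rate(0.7, 0.4))  # 1.0, since p >= q is always accepted
print(accept(0.3, 0.0, 0.99))     # True: with q = 0 the test is p >= 0
```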
Step 2: Accept
If the acceptance test passes, the algorithm advances deeper into the tree:
```python
if pred_local[0]:
    # Accepted: the child becomes the new parent; continue with its first child
    parent_ptr[0] = child_ptr[0]
    child_ptr[0] = token_tree_first_child[child_ptr[0]]
```
Step 3: Reject and Renormalize
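If the token is rejected, the target distribution at the parent is replaced by the residual distribution max(p - q, 0) / Σ max(p - q, 0). The rejected child's draft distribution is subtracted elementwise, negative entries are clamped to zero, and the result is divided by its sum. The next sibling is then tested against this updated distribution: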
```python
# Phase 1: compute the clamped residual max(p - q, 0) and each thread's partial sum
psum[0] = 0.0  # reset the per-thread accumulator for this rejection
for i in T.serial(T.ceildiv(vocab_size, TX)):
    k = i * TX + tx  # each thread strides over the vocabulary
    if k < vocab_size:
        model_prob_local[0] = model_probs[parent_ptr[0], k]
        draft_prob_local[0] = draft_probs[child_ptr[0], k]
        model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0)
        psum[0] += model_prob_local[0]
```
The per-thread partial sums are then combined across the block with `T.tvm_thread_allreduce`, which leaves the total residual mass in `t0[0]` on every thread:
```python
with T.sblock("block_cross_thread"):
    T.attr(
        T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
        "reduce_scope",
        T.reinterpret("handle", T.uint64(0)),
    )
    T.tvm_thread_allreduce(T.uint32(1), psum[0], True, t0[0], tx, dtype="handle")
```
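Functionally, the phase-1 loop plus the all-reduce computes t0 = Σ_k max(p[k] - q[k], 0), and every thread receives that value. The plain-Python sketch below simulates this with a small illustrative block size. Each simulated thread accumulates its strided partial sum, and the partials are then combined. Here they are combined sequentially; on the GPU this step is a parallel cross-thread reduction. `residual_sum` is a hypothetical name.

```python
def ceildiv(a, b):
    return -(-a // b)


def residual_sum(model_row, draft_row, TX=4):
    """Simulate phase 1 plus the all-reduce: t0 = sum_k max(p[k] - q[k], 0)."""
    vocab_size = len(model_row)
    psum = [0.0] * TX  # one partial sum per simulated thread
    for tx in range(TX):
        for i in range(ceildiv(vocab_size, TX)):
            k = i * TX + tx
            if k < vocab_size:
                psum[tx] += max(model_row[k] - draft_row[k], 0.0)
    return sum(psum)  # the value every thread would see in t0[0]


p = [0.5, 0.3, 0.1, 0.1, 0.0]
q = [0.1, 0.6, 0.1, 0.0, 0.2]
print(residual_sum(p, q))  # approximately 0.5 (0.4 from token 0 plus 0.1 from token 3)
```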
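If the residual sum is near zero (`t0[0] < 1e-7`), the token is accepted despite the failed test, which avoids dividing by a vanishing sum. A vanishing residual means p(x) <= q(x) for (almost) every x. Because both distributions sum to 1, p and q then nearly coincide, and there is no meaningful residual distribution to renormalize into. Otherwise, the renormalized probabilities are written back and verification moves on to the next sibling: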
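```python
if t0[0] < 1e-7:
    # Degenerate case: accept the child anyway and descend
    parent_ptr[0] = child_ptr[0]
    child_ptr[0] = token_tree_first_child[child_ptr[0]]
else:
    # Write the renormalized residual back, then move to the next sibling.
    # model_prob_local holds only one value per thread, so the residual is
    # recomputed for every k before it is normalized.
    for i in T.serial(T.ceildiv(vocab_size, TX)):
        k = i * TX + tx
        if k < vocab_size:
            model_prob_local[0] = model_probs[parent_ptr[0], k]
            draft_prob_local[0] = draft_probs[child_ptr[0], k]
            model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0)
            model_probs[parent_ptr[0], k] = model_prob_local[0] / t0[0]
    child_ptr[0] = token_tree_next_sibling[child_ptr[0]]
```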
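As a CPU reference for the rejection branch, the hypothetical `reject_and_renormalize` function below (plain Python, not part of MLC LLM) computes the residual distribution and applies the same 1e-7 degenerate-case threshold. Otherwise, it renormalizes the parent's row in place and returns the next sibling to try.

```python
EPS = 1e-7  # same threshold as the kernel's degenerate-case check


def reject_and_renormalize(parent, child, model_probs, draft_probs,
                           first_child, next_sibling):
    """Return the new (parent, child) pointers after a failed acceptance test.

    model_probs[parent] is overwritten with max(p - q, 0) / sum(max(p - q, 0))
    unless that sum is below EPS, in which case the child is accepted.
    """
    residual = [max(p - q, 0.0)
                for p, q in zip(model_probs[parent], draft_probs[child])]
    total = sum(residual)
    if total < EPS:
        # p and q (almost) coincide: accept the child and descend
        return child, first_child[child]
    model_probs[parent] = [r / total for r in residual]
    return parent, next_sibling[child]


# Root 0 with sibling children 1 and 2 (no grandchildren).
first_child, next_sibling = [1, -1, -1], [-1, 2, -1]
model_probs = [[0.5, 0.3, 0.2], [0.0] * 3, [0.0] * 3]
draft_probs = [[0.0] * 3, [0.1, 0.6, 0.3], [1.0, 0.0, 0.0]]

print(reject_and_renormalize(0, 1, model_probs, draft_probs, first_child, next_sibling))  # (0, 2)
print(model_probs[0])  # [1.0, 0.0, 0.0]
# The renormalized row now equals node 2's draft row, so the residual is zero.
print(reject_and_renormalize(0, 2, model_probs, draft_probs, first_child, next_sibling))  # (2, -1)
```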
Step 4: Finalize
After the loop completes (no more children to explore), thread 0 writes the final parent pointer:
```python
if tx == 0:
    # Commit the last accepted node for batch element b
    token_tree_parent_ptr[b] = parent_ptr[0]
```
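Steps 1 to 4 can be combined into a sequential CPU reference for one batch element, as in the sketch below. It is a hypothetical pure-Python model of the kernel's control flow, with one "thread" and lists in place of buffers, and it is not MLC LLM code. It can be useful for working out the expected pointer outcome for a small tree.

```python
EPS = 1e-7


def verify_tree(root, draft_probs, draft_tokens, model_probs,
                first_child, next_sibling, uniform_samples):
    """Walk the token tree from `root` and return the last accepted node.

    Accept (p >= u * q) descends to the child; reject renormalizes the
    parent's row and tries the next sibling. model_probs is updated in place.
    """
    parent = root
    child = first_child[parent]
    while child != -1:
        token = draft_tokens[child]
        p = model_probs[parent][token]
        q = draft_probs[child][token]
        if p >= uniform_samples[child] * q:
            parent, child = child, first_child[child]  # Step 2: accept
            continue
        # Step 3: reject and build the residual max(p - q, 0)
        residual = [max(a - b, 0.0)
                    for a, b in zip(model_probs[parent], draft_probs[child])]
        total = sum(residual)
        if total < EPS:
            parent, child = child, first_child[child]  # degenerate: accept
        else:
            model_probs[parent] = [r / total for r in residual]
            child = next_sibling[child]
    return parent  # Step 4: the value written to token_tree_parent_ptr[b]


first_child = [1, 3, -1, -1]  # root 0 -> {1, 2}; node 1 -> {3}
next_sibling = [-1, 2, -1, -1]
draft_tokens = [0, 0, 1, 2]  # the root's entry is unused
draft_probs = [[1 / 3] * 3, [0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]]
model_probs = [[0.2, 0.7, 0.1], [0.3, 0.3, 0.4], [1 / 3] * 3, [1 / 3] * 3]
uniform_samples = [0.0, 0.5, 0.9, 0.0]

print(verify_tree(0, draft_probs, draft_tokens, model_probs,
                  first_child, next_sibling, uniform_samples))
# 2: node 1 is rejected, then its sibling node 2 is accepted
```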
Synchronization
The kernel uses T.tvm_storage_sync("shared") at critical points to ensure:
- All threads see the acceptance predicate computed by thread 0.
- All reads of `model_probs` are complete before any thread writes to it during renormalization.
- Shared state is consistent before loop exit checks.
Post-Verification Sampling
After this function returns, a follow-up sampling step must sample from model_probs[token_tree_parent_ptr[b], :] to generate one additional token beyond the last accepted position.
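The follow-up sampling step is not part of this module. As an illustration only, the hypothetical inverse-CDF sampler below draws the extra token from the row `model_probs[token_tree_parent_ptr[b]]` using one uniform sample. The actual sampler used by MLC LLM may differ.

```python
def sample_token(probs, u):
    """Hypothetical inverse-CDF sampler: return the smallest index whose
    cumulative probability exceeds u, for u drawn from [0, 1)."""
    cumulative = 0.0
    for token, prob in enumerate(probs):
        cumulative += prob
        if u < cumulative:
            return token
    return len(probs) - 1  # guard against a row that sums to just under 1


row = [0.2, 0.3, 0.5]  # stands in for model_probs[token_tree_parent_ptr[b]]
print([sample_token(row, u) for u in (0.1, 0.25, 0.6)])  # [0, 1, 2]
```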
Categories
- Speculative Decoding
- GPU Kernels
- TVM TIR
- Token Tree Verification
- LLM Inference Optimization