Implementation:Mlc ai Mlc llm Top P Pivot
| Knowledge Sources | |
|---|---|
| Domains | Sampling, TVM Tensor IR, LLM Operators |
| Last Updated | 2026-02-09 19:00 GMT |
Overview
TVM TensorIR (TIR) kernels for top-p (nucleus) sampling in LLM token generation. One kernel finds the probability pivot that marks the top-p cutoff, and the other renormalizes the probability distribution over the tokens that survive it.
Description
The top_p_pivot module provides two functions for top-p (nucleus) sampling. Each one builds a GPU kernel as a TIR PrimFunc:
- top_p_pivot: A binary-search-like algorithm that finds a pivot probability such that keeping every token with probability >= pivot retains at least top_p of the probability mass, and the cutoff cannot be raised any further. The search interval runs from R (lower end) up to L (upper end). Each iteration narrows it using pN pivot candidates placed between R and L. For each candidate, the kernel computes lsum (sum of probabilities >= pivot), lmin (minimum probability >= pivot), and cmin (number of elements equal to lmin). A valid pivot satisfies lsum >= top_p AND top_p > lsum - cmin * lmin. In words, enough mass is kept, and dropping the tokens tied at the smallest kept value would fall below top_p. The scan uses cross-thread reductions and checks for early stopping every K iterations. The kernel uses shared memory for communication between threads and local memory for per-thread state.
- top_p_renorm: Once the pivot is known, this kernel renormalizes each row. Values >= pivot are divided by lsum, and values below the pivot are set to zero. The result is a valid probability distribution for downstream sampling.
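The per-pivot statistics and the validity test can be checked on a small row with a sequential plain-Python sketch. This is not the GPU kernel, and the helper names `pivot_stats` and `is_valid_pivot` are hypothetical. It shows that valid pivots form an interval. For the row below with top_p = 0.8, any pivot in (0.1, 0.2] keeps exactly the tokens 0.5, 0.2, 0.2.

```python
from typing import List, Tuple


def pivot_stats(prob: List[float], pivot: float) -> Tuple[float, float, int]:
    """Return (lsum, lmin, cmin) for one pivot over one row of probabilities."""
    kept = [p for p in prob if p >= pivot]
    lsum = sum(kept)                               # mass kept by this pivot
    lmin = min(kept) if kept else float("inf")     # smallest kept probability
    cmin = sum(1 for p in kept if p == lmin)       # how many kept values tie at lmin
    return lsum, lmin, cmin


def is_valid_pivot(lsum: float, lmin: float, cmin: int, top_p: float) -> bool:
    # Enough mass is kept, and dropping every token tied at lmin would fall below top_p.
    return lsum >= top_p and top_p > lsum - cmin * lmin


prob = [0.5, 0.2, 0.2, 0.1]
top_p = 0.8
for pivot in (0.3, 0.2, 0.15, 0.1):
    print(pivot, is_valid_pivot(*pivot_stats(prob, pivot), top_p))
# Valid pivots form the interval (0.1, 0.2]: 0.3 keeps too little, 0.1 keeps too much.
```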
Both kernels derive their thread-block size from the compilation target, using get_max_num_threads_per_block so that the block never exceeds the device limit. For top_p_renorm the target argument is optional and defaults to None.
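Combining per-thread partial results is straightforward for lsum and lmin, but easy to get wrong for cmin. The sequential simulation below assumes a strided layout in which virtual thread `tx` visits indices `tx, tx + T, tx + 2T, ...` for a block of `T` threads. `T` is capped by a device limit in the spirit of get_max_num_threads_per_block. The helper names and the preferred block size of 1024 are hypothetical. lsum is a plain sum and lmin a plain minimum, but a thread's local cmin only counts if its local minimum equals the block-wide lmin.

```python
from typing import List, Tuple

INF = float("inf")


def thread_partials(prob: List[float], pivot: float, tx: int,
                    num_threads: int) -> Tuple[float, float, int]:
    """Partial (lsum, lmin, cmin) for the strided slice owned by one virtual thread."""
    lsum, lmin, cmin = 0.0, INF, 0
    for idx in range(tx, len(prob), num_threads):
        q = prob[idx]
        if q >= pivot:
            lsum += q
            if q < lmin:
                lmin, cmin = q, 1
            elif q == lmin:
                cmin += 1
    return lsum, lmin, cmin


def block_reduce(prob: List[float], pivot: float, max_threads_per_block: int,
                 preferred: int = 1024) -> Tuple[float, float, int]:
    # Cap the virtual block size by the device limit (hypothetical parameters).
    num_threads = min(preferred, max_threads_per_block)
    parts = [thread_partials(prob, pivot, tx, num_threads) for tx in range(num_threads)]
    lsum = sum(p[0] for p in parts)   # sum-reduce
    lmin = min(p[1] for p in parts)   # min-reduce, then conceptually broadcast
    # Only threads whose local minimum equals the global minimum contribute to cmin.
    cmin = sum(p[2] for p in parts if p[1] == lmin)
    return lsum, lmin, cmin


prob = [0.05, 0.3, 0.1, 0.1, 0.25, 0.1, 0.1]
print(block_reduce(prob, pivot=0.1, max_threads_per_block=3))
```

With three threads, one thread's local minimum is 0.25. Summing its cmin naively would overcount the tokens tied at the true minimum of 0.1.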
Usage
Use these operators as part of the token sampling pipeline in MLC LLM. The PrimFunc built by top_p_pivot runs during text generation to implement nucleus sampling, in which only tokens inside the top-p cumulative probability mass remain candidates. The PrimFunc built by top_p_renorm then produces the renormalized distribution from which the next token is sampled.
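The final step, drawing the next token from a renormalized row, can be sketched with inverse-CDF sampling. This is a plain-Python illustration, not MLC LLM's sampler. The function name and the caller-supplied uniform draw `u` are assumptions. Tokens zeroed by renormalization can never be selected, because they do not increase the cumulative sum.

```python
import random
from typing import List, Optional


def sample_from_renormalized(renorm_prob: List[float], u: Optional[float] = None) -> int:
    """Return the index of the first token whose cumulative probability exceeds u."""
    if u is None:
        u = random.random()  # uniform in [0, 1)
    cumulative = 0.0
    last_nonzero = -1
    for idx, p in enumerate(renorm_prob):
        if p > 0.0:
            last_nonzero = idx
        cumulative += p
        if u < cumulative:
            return idx
    # Rounding can leave the total slightly below 1; fall back to the last kept token.
    return last_nonzero


# A renormalized row: the token below the pivot is already zero.
row = [0.5 / 0.9, 0.2 / 0.9, 0.2 / 0.9, 0.0]
print(sample_from_renormalized(row, u=0.3))   # 0
print(sample_from_renormalized(row, u=0.7))   # 1
print(sample_from_renormalized(row, u=0.95))  # 2
```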
Code Reference
Source Location
- Repository: Mlc_ai_Mlc_llm
- File: python/mlc_llm/op/top_p_pivot.py
Signature
```python
def top_p_pivot(pN, target: tvm.target.Target) -> tvm.tir.PrimFunc
def top_p_renorm(target: tvm.target.Target = None) -> tvm.tir.PrimFunc
```
Import
```python
from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm
```
I/O Contract
top_p_pivot
| Parameter | Type | Description |
|---|---|---|
| pN | int | Number of pivot candidates per iteration |
| target | tvm.target.Target | The TVM compilation target (determines max thread block size) |
Returned PrimFunc buffers:
| Buffer | Shape | Dtype | Direction | Description |
|---|---|---|---|---|
| prob | (B, N) | float32 | Input | Probability distribution for each batch element |
| top_p_arr | (B,) | float32 | Input | Per-batch top-p threshold values |
| init_pivots | (B, pN) | float32 | Input | Initial pivot candidates (descending order) |
| final_pivot | (B,) | float32 | Output | The computed pivot value per batch element |
| final_lsum | (B,) | float32 | Output | Sum of probabilities >= pivot per batch element |
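The buffer shapes above can be illustrated with plain Python lists for a batch of two rows. This page does not describe how MLC LLM chooses init_pivots. The uniform spacing between 1 - top_p and a small epsilon used here is an assumption chosen to match the search interval under Algorithm Details. `make_init_pivots`, `check_pivot_inputs`, and `EPS` are hypothetical.

```python
from typing import List, Sequence, Tuple

EPS = 1e-7  # hypothetical lower end R of the search interval


def make_init_pivots(top_p: float, pN: int, eps: float = EPS) -> List[float]:
    """Hypothetical helper: pN pivots strictly between L = 1 - top_p and R = eps, descending."""
    L, R = 1.0 - top_p, eps
    step = (L - R) / (pN + 1)
    return [L - (i + 1) * step for i in range(pN)]


def check_pivot_inputs(prob: Sequence[Sequence[float]], top_p_arr: Sequence[float],
                       init_pivots: Sequence[Sequence[float]], pN: int) -> Tuple[int, int]:
    """Validate the (B, N), (B,), (B, pN) shapes and the descending-order requirement."""
    B, N = len(prob), len(prob[0])
    if any(len(row) != N for row in prob):
        raise ValueError("prob must have shape (B, N)")
    if len(top_p_arr) != B:
        raise ValueError("top_p_arr must have shape (B,)")
    if len(init_pivots) != B or any(len(row) != pN for row in init_pivots):
        raise ValueError("init_pivots must have shape (B, pN)")
    if any(row[i] <= row[i + 1] for row in init_pivots for i in range(pN - 1)):
        raise ValueError("each row of init_pivots must be in descending order")
    # final_pivot and final_lsum would each be allocated with shape (B,).
    return B, N


pN = 4
prob = [[0.5, 0.2, 0.2, 0.1], [0.25, 0.25, 0.25, 0.25]]
top_p_arr = [0.8, 0.6]
init_pivots = [make_init_pivots(p, pN) for p in top_p_arr]
print(check_pivot_inputs(prob, top_p_arr, init_pivots, pN))  # (2, 4)
```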
top_p_renorm
Returned PrimFunc buffers:
| Buffer | Shape | Dtype | Direction | Description |
|---|---|---|---|---|
| prob | (B, N) | float32 | Input | Original probability distribution |
| final_pivot | (B,) | float32 | Input | The pivot cutoff from top_p_pivot |
| final_lsum | (B,) | float32 | Input | Sum of probabilities >= pivot |
| renorm_prob | (B, N) | float32 | Output | Renormalized probabilities (0 if below pivot) |
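The element-wise rule behind renorm_prob can be written as a sequential reference. The helper name `top_p_renorm_reference` is hypothetical, and nested lists stand in for the (B, N) buffers. Each entry at or above its row's pivot is divided by that row's lsum, and every other entry becomes zero.

```python
from typing import List


def top_p_renorm_reference(prob: List[List[float]], final_pivot: List[float],
                           final_lsum: List[float]) -> List[List[float]]:
    """Sequential reference for the renormalization step (hypothetical helper)."""
    out = []
    for b, row in enumerate(prob):
        pivot, lsum = final_pivot[b], final_lsum[b]
        # Keep entries at or above the pivot, rescaled by the kept mass; zero the rest.
        out.append([p / lsum if p >= pivot else 0.0 for p in row])
    return out


renorm = top_p_renorm_reference([[0.5, 0.2, 0.2, 0.1]], final_pivot=[0.2], final_lsum=[0.9])
print([round(p, 4) for p in renorm[0]])  # [0.5556, 0.2222, 0.2222, 0.0]
```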
Algorithm Details
The top_p_pivot kernel uses an iterative multi-pivot binary search:
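- Initialize the search interval: L = 1 - top_p (upper end) and R = epsilon (lower end). For a row that sums to 1, L is a safe upper end. The tokens excluded by a valid pivot hold at most 1 - top_p of the mass. The valid pivots begin just above the largest excluded probability, so that probability is also at most 1 - top_p.
- Evaluate pN pivot candidates per iteration. The first iteration uses init_pivots, and later iterations place pN pivots uniformly between R and L.
- For each candidate, every thread scans its share of the row, and cross-thread reductions combine the partial lsum, lmin, and cmin.
- Check validity: if a candidate satisfies lsum >= top_p AND top_p > lsum - cmin * lmin, it is the pivot.
- Otherwise, narrow the interval and repeat. lsum never increases as the pivot rises. A candidate with lsum < top_p is therefore too high and becomes the new L. A candidate with lsum >= top_p that still fails the second condition is too low and becomes the new R.
- Early stopping: every K = 32 iterations of the per-thread scan loop, the threads reduce the probability mass scanned so far. If the remaining mass is below the smallest pivot, no unscanned token can reach any pivot, so the scan of that row stops early.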
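The steps above can be simulated for a single row in plain Python. This is a sequential sketch, not the kernel. It uses no threads, and `chunk` stands in for the number of elements scanned between early-stop checks. The names `scan`, `uniform_pivots`, and `find_top_p_pivot` are assumptions. So are the value of `EPS` and the fallback used when the interval collapses without a valid pivot.

```python
from typing import List, Optional, Sequence, Tuple

EPS = 1e-7  # hypothetical: lower end R and stopping tolerance for the interval width


def scan(prob: Sequence[float], pivots: Sequence[float], chunk: int):
    """(lsum, lmin, cmin) for each pivot; pivots must be in descending order.

    `chunk` stands in for the K * TX elements a thread block scans between
    early-stop checks (an assumption about granularity, not taken from the kernel).
    """
    lsum = [0.0] * len(pivots)
    lmin = [float("inf")] * len(pivots)
    cmin = [0] * len(pivots)
    total = 0.0
    for start in range(0, len(prob), chunk):
        for q in prob[start:start + chunk]:
            total += q
            for i, pv in enumerate(pivots):
                if q >= pv:
                    lsum[i] += q
                    if q < lmin[i]:
                        lmin[i], cmin[i] = q, 1
                    elif q == lmin[i]:
                        cmin[i] += 1
        # Each unscanned value is at most the unscanned mass, so if that mass is
        # below the smallest pivot, the rest of the row cannot change any statistic.
        if 1.0 - total < pivots[-1]:
            break
    return lsum, lmin, cmin


def uniform_pivots(L: float, R: float, pN: int) -> List[float]:
    """pN pivots strictly between R and L, in descending order."""
    step = (L - R) / (pN + 1)
    return [L - (i + 1) * step for i in range(pN)]


def find_top_p_pivot(prob: Sequence[float], top_p: float, pN: int = 4,
                     init_pivots: Optional[Sequence[float]] = None,
                     chunk: int = 32) -> Tuple[float, float]:
    """Return (pivot, lsum) for one row whose probabilities sum to 1."""
    L, R = 1.0 - top_p, EPS  # L: upper end of the interval, R: lower end
    pivots = list(init_pivots) if init_pivots else uniform_pivots(L, R, pN)
    while L - R > EPS:
        lsum, lmin, cmin = scan(prob, pivots, chunk)
        for i, pv in enumerate(pivots):
            if lsum[i] >= top_p and top_p > lsum[i] - cmin[i] * lmin[i]:
                return pv, lsum[i]
        # lsum never increases as the pivot rises: too little mass means the
        # candidate is too high (new upper end), otherwise too low (new lower end).
        for i, pv in enumerate(pivots):
            if lsum[i] < top_p:
                L = min(L, pv)
            else:
                R = max(R, pv)
        pivots = uniform_pivots(L, R, pN)
    # Assumed fallback: R is the highest candidate seen that kept at least top_p
    # of the mass (or the initial epsilon).
    return R, scan(prob, [R], chunk)[0][0]


print(find_top_p_pivot([0.4, 0.3, 0.2, 0.1], top_p=0.5))  # pivot in (0.2, 0.3], lsum ~0.7
print(find_top_p_pivot([0.01] * 100, top_p=0.5))          # ties force all 100 tokens to be kept
```

The second call needs several narrowing rounds before a candidate drops to 0.01 or below. Because all 100 tokens tie at lmin, the only valid cut keeps every one of them.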
Usage Examples
```python
import tvm
from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm

# Create the top-p pivot TIR function for a GPU target
target = tvm.target.Target("cuda")
pivot_func = top_p_pivot(pN=4, target=target)
# Create the renormalization function
renorm_func = top_p_renorm(target=target)
# These PrimFuncs are typically composed into a TVM IRModule
# and compiled for GPU execution as part of the sampling pipeline.
```