Implementation:Mlc ai Mlc llm Top P Pivot

From Leeroopedia


Knowledge Sources
Domains: Sampling, TVM Tensor IR, LLM Operators
Last Updated: 2026-02-09 19:00 GMT

Overview

TVM TensorIR (TIR) implementations of top-p (nucleus) sampling for LLM token generation: one kernel finds the probability pivot that cuts off the top-p mass, and the other renormalizes the distribution above that pivot.

Description

The top_p_pivot module provides two functions, each of which builds a GPU kernel (a TIR PrimFunc) for top-p (nucleus) sampling:

  • top_p_pivot: A binary-search-like kernel that finds the pivot probability separating the tokens that make up the top-p cumulative probability mass from the rest. It iteratively narrows a search range [L, R] using several pivot candidates per iteration (pN pivots uniformly spaced between L and R). For each candidate it computes lsum (the sum of probabilities >= pivot), lmin (the smallest probability >= pivot), and cmin (the number of elements equal to lmin). A candidate is a valid pivot when lsum >= top_p AND top_p > lsum - cmin * lmin; since lsum - cmin * lmin is the mass strictly above lmin, this means the kept tokens reach top_p only once the tokens tied at lmin are included. The statistics are combined with cross-thread reductions, and early stopping is checked every K iterations. The kernel uses shared memory for communication between threads and local memory for per-thread computations.
  • top_p_renorm: After the pivot is determined, this kernel renormalizes the probability distribution by keeping values >= pivot (dividing by lsum) and zeroing out values below the pivot. This creates a valid probability distribution for downstream sampling.
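To make the validity condition concrete, here is a minimal pure-Python sketch for a single batch row. The helper names `pivot_stats` and `is_valid_pivot` are hypothetical and not part of mlc_llm; the kernel computes the same three quantities with cross-thread reductions on the GPU rather than with a Python loop.

```python
import math
from typing import List, Tuple


def pivot_stats(prob: List[float], pivot: float) -> Tuple[float, float, int]:
    """Return (lsum, lmin, cmin) for one row and one pivot candidate."""
    kept = [p for p in prob if p >= pivot]
    lsum = math.fsum(kept)                    # sum of probabilities >= pivot
    lmin = min(kept, default=math.inf)        # smallest kept probability (inf if none)
    cmin = sum(1 for p in kept if p == lmin)  # how many kept values tie at lmin
    return lsum, lmin, cmin


def is_valid_pivot(prob: List[float], pivot: float, top_p: float) -> bool:
    """Validity test: lsum >= top_p and top_p > lsum - cmin * lmin."""
    lsum, lmin, cmin = pivot_stats(prob, pivot)
    if cmin == 0:  # nothing kept, so lsum == 0 cannot reach top_p > 0
        return False
    # lsum - cmin * lmin is the mass strictly above lmin.
    return lsum >= top_p and top_p > lsum - cmin * lmin
```

For example, with prob = [0.5, 0.25, 0.125, 0.125] and top_p = 0.6, a pivot of 0.25 is valid (lsum = 0.75, and the mass above lmin is 0.5), while 0.125 is not (the mass above lmin = 0.125 is already 0.75).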

Both functions take the compilation target and use get_max_num_threads_per_block to size their thread blocks, so the generated kernels respect the device's thread-block size limit.
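As a rough model of what the target controls, the following hedged sketch clamps the thread count to a device limit, lets each thread stride over one row of probabilities, and combines the per-thread partial sums. `choose_block_size`, `block_strided_lsum`, and the default request of 1024 threads are illustrative assumptions, not helpers or values taken from mlc_llm; on a GPU the final step is a cross-thread reduction rather than a Python sum.

```python
import math
from typing import List


def choose_block_size(device_max_threads: int, requested: int = 1024) -> int:
    """Hypothetical: clamp the requested threads per block to the device limit."""
    return min(requested, device_max_threads)


def block_strided_lsum(prob: List[float], pivot: float, num_threads: int) -> float:
    """Simulate one thread block computing lsum for a single row."""
    partial = [0.0] * num_threads
    for tx in range(num_threads):
        # Thread tx visits indices tx, tx + num_threads, tx + 2 * num_threads, ...
        for j in range(tx, len(prob), num_threads):
            if prob[j] >= pivot:
                partial[tx] += prob[j]
    # On a GPU this is a cross-thread reduction; here it is a plain sum.
    return math.fsum(partial)
```

For instance, choose_block_size(256) returns 256, so a device limited to 256 threads per block gets a 256-thread kernel even though 1024 was requested.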

Usage

Use these operators as part of the token sampling pipeline in MLC LLM. The top_p_pivot function is invoked during text generation to implement nucleus sampling, where only tokens within the top-p cumulative probability mass are considered. The top_p_renorm function is then applied to produce a valid renormalized distribution from which the next token is sampled.
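The last step of that pipeline, drawing a token from a renormalized row, can be sketched with inverse-CDF sampling. This is a generic technique shown for illustration only; `sample_token` is a hypothetical helper, and MLC LLM's own sampling kernels are not reproduced here.

```python
from typing import List


def sample_token(renorm_prob: List[float], u: float) -> int:
    """Inverse-CDF draw: return the first index whose cumulative probability exceeds u.

    u is assumed to be a uniform sample in [0, 1), e.g. from random.random().
    Tokens zeroed out by renormalization can never be selected.
    """
    cumulative = 0.0
    for idx, p in enumerate(renorm_prob):
        cumulative += p
        if cumulative > u:
            return idx
    # Rounding can leave the total slightly below u; fall back to the last kept token.
    return max(i for i, p in enumerate(renorm_prob) if p > 0.0)
```

A typical call would be `sample_token(row, random.random())`, where `row` is one row of the renormalized output.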

Code Reference

Source Location

Signature

def top_p_pivot(pN, target: tvm.target.Target) -> tvm.tir.PrimFunc

def top_p_renorm(target: tvm.target.Target = None) -> tvm.tir.PrimFunc

Import

from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm

I/O Contract

top_p_pivot

Parameter | Type | Description
pN | int | Number of pivot candidates per iteration
target | tvm.target.Target | The TVM compilation target (determines the maximum thread-block size)

Returned PrimFunc buffers:

Buffer | Shape | Dtype | Direction | Description
prob | (B, N) | float32 | Input | Probability distribution for each batch element
top_p_arr | (B,) | float32 | Input | Per-batch top-p threshold values
init_pivots | (B, pN) | float32 | Input | Initial pivot candidates (descending order)
final_pivot | (B,) | float32 | Output | The computed pivot value per batch element
final_lsum | (B,) | float32 | Output | Sum of probabilities >= pivot per batch element
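The shape and ordering requirements of the top_p_pivot inputs can be written as a small host-side check. This sketch is hypothetical (mlc_llm does not necessarily validate its inputs this way) and uses nested Python lists in place of device buffers.

```python
from typing import Sequence


def check_top_p_pivot_inputs(prob: Sequence[Sequence[float]],
                             top_p_arr: Sequence[float],
                             init_pivots: Sequence[Sequence[float]],
                             pN: int) -> None:
    """Hypothetical check: prob is (B, N), top_p_arr is (B,), init_pivots is (B, pN)."""
    B = len(prob)
    if len(top_p_arr) != B or len(init_pivots) != B:
        raise ValueError("prob, top_p_arr and init_pivots must share batch size B")
    N = len(prob[0]) if B else 0
    for b in range(B):
        if len(prob[b]) != N:
            raise ValueError("every prob row must have length N")
        row = init_pivots[b]
        if len(row) != pN:
            raise ValueError("init_pivots must have shape (B, pN)")
        # Candidates must be in descending order within each row.
        if any(row[i] < row[i + 1] for i in range(pN - 1)):
            raise ValueError("init_pivots rows must be in descending order")
```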

top_p_renorm

Returned PrimFunc buffers:

Buffer | Shape | Dtype | Direction | Description
prob | (B, N) | float32 | Input | Original probability distribution
final_pivot | (B,) | float32 | Input | The pivot cutoff from top_p_pivot
final_lsum | (B,) | float32 | Input | Sum of probabilities >= pivot
renorm_prob | (B, N) | float32 | Output | Renormalized probabilities (0 if below pivot)
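Per row, the renormalization rule is a single comprehension. The following pure-Python reference uses hypothetical names and lists instead of device buffers, and mirrors the top_p_renorm contract: prob is (B, N), while final_pivot and final_lsum are (B,).

```python
from typing import List


def top_p_renorm_row(prob: List[float], pivot: float, lsum: float) -> List[float]:
    """Keep p >= pivot scaled by 1 / lsum; zero out everything below the pivot."""
    return [p / lsum if p >= pivot else 0.0 for p in prob]


def top_p_renorm_batch(prob: List[List[float]],
                       final_pivot: List[float],
                       final_lsum: List[float]) -> List[List[float]]:
    """Apply the per-row rule to every batch element."""
    return [top_p_renorm_row(row, pv, ls)
            for row, pv, ls in zip(prob, final_pivot, final_lsum)]
```

With prob = [0.5, 0.25, 0.125, 0.125], pivot 0.2 and lsum 0.75, the row becomes roughly [0.667, 0.333, 0.0, 0.0], which sums to 1.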

Algorithm Details

The top_p_pivot kernel uses an iterative multi-pivot binary search:

  1. Initialize search bounds: L = 1 - top_p, R = epsilon
  2. At each iteration, place pN pivots uniformly between L and R
  3. For each pivot candidate, reduce across all threads to compute lsum, lmin, cmin
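  4. Check validity: if some candidate satisfies lsum >= top_p AND top_p > lsum - cmin * lmin, that candidate is the pivot
  5. Otherwise, shrink the search range using the failed candidates and repeat: a candidate with lsum < top_p is too high (too little mass is kept), while one with lsum - cmin * lmin >= top_p is too low (the mass strictly above lmin already reaches top_p)
  6. Early stopping: every K=32 thread-level iterations, the kernel checks whether the remaining probability mass is below the smallest pivot; if so, no remaining element can reach any pivot (each is at most that remaining mass), so the scan can stop early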
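The search can be prototyped on the CPU for one row. The sketch below is a simplified reference, not the kernel: it keeps a bracket [lo, hi] with lsum(lo) >= top_p and lsum(hi) < top_p, starting from lo = 0 and hi above the largest probability instead of the kernel's L = 1 - top_p, R = epsilon initialization. It does not take init_pivots, omits early stopping and parallel reductions, and assumes 0 < top_p <= sum(prob). The name `find_top_p_pivot` and the `max_iters` cap are hypothetical.

```python
import math
from typing import List, Tuple


def _stats(prob: List[float], pivot: float) -> Tuple[float, float, int]:
    """(lsum, lmin, cmin) for one pivot candidate."""
    kept = [p for p in prob if p >= pivot]
    lmin = min(kept, default=math.inf)
    return math.fsum(kept), lmin, sum(1 for p in kept if p == lmin)


def find_top_p_pivot(prob: List[float], top_p: float, pN: int = 4,
                     max_iters: int = 200) -> Tuple[float, float]:
    """Return (pivot, lsum) for one row; assumes 0 < top_p <= sum(prob)."""
    lo = 0.0              # keeps every token, so lsum(lo) >= top_p
    hi = max(prob) + 1.0  # keeps no token, so lsum(hi) == 0 < top_p
    for _ in range(max_iters):
        step = (hi - lo) / (pN + 1)
        # pN candidates uniformly spaced strictly inside (lo, hi), in descending order.
        pivots = [hi - (i + 1) * step for i in range(pN)]
        new_lo, new_hi = lo, hi
        for pivot in pivots:
            lsum, lmin, cmin = _stats(prob, pivot)
            if lsum >= top_p and top_p > lsum - cmin * lmin:
                return pivot, lsum           # valid pivot found
            if lsum < top_p:
                new_hi = min(new_hi, pivot)  # too high: too little mass kept
            else:
                new_lo = max(new_lo, pivot)  # too low: mass above lmin already reaches top_p
        lo, hi = new_lo, new_hi
    # Fallback if float spacing runs out: lo still keeps at least top_p of the mass.
    return lo, _stats(prob, lo)[0]
```

Because the candidates are evenly spaced, each unsuccessful iteration shrinks the bracket by a factor of pN + 1, and in exact arithmetic a candidate eventually lands among the valid pivots; a larger pN trades more work per iteration for fewer iterations.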

Usage Examples

```python
import tvm

from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm

# Create the top-p pivot TIR function for a GPU target
target = tvm.target.Target("cuda")
pivot_func = top_p_pivot(pN=4, target=target)

# Create the renormalization function
renorm_func = top_p_renorm(target=target)

# These PrimFuncs are typically composed into a TVM IRModule
# and compiled for GPU execution as part of the sampling pipeline.
```
