

Implementation:InternLM Lmdeploy SamplingTopkKernels

From Leeroopedia


Knowledge Sources
Domains: GPU_Kernels, Sampling
Last Updated: 2026-02-07 15:00 GMT

Overview

CUDA kernels for top-K sampling, including batched top-K sampling with an optional top-P threshold, cuRAND state initialization, and a top-K sort-and-filter operation.

Description

This header provides the top-K sampling pipeline.

invokeBatchTopKSampling() performs the full top-K sampling operation. It sorts the log-probabilities to find the top-K candidates for each sequence (K may differ per sequence, up to max_top_k), optionally applies a top-P probability threshold, draws a sample using cuRAND, and updates the output IDs, sequence lengths, cumulative log probabilities, and sampled log-probs. It follows a two-pass workspace convention: the first call, with a null workspace, only writes the required size to workspace_size; the second call, with an allocated workspace of that size, performs the sampling.

InitializeRandomStates() initializes per-sequence cuRAND states from per-sequence random seeds, with an optional mask.

TopKSortFilterParams and invokeTopKSortFilter() provide a standalone top-K sort-and-filter step that produces sorted logits, sorted indices, and per-sequence kept counts, suitable for use before a separate sampling step.
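The masked seeding behaviour of InitializeRandomStates() can be pictured with a host-side sketch. It swaps cuRAND for std::mt19937_64 purely for illustration and assumes that a null mask means every sequence is seeded; RngState and initializeRandomStatesHost are hypothetical names, not lmdeploy APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Illustrative stand-in for curandState_t (assumption: not what the kernel uses).
using RngState = std::mt19937_64;

// Hypothetical host analogue of InitializeRandomStates(): re-seed states[i]
// from random_seeds[i] wherever mask[i] is true. A null mask is assumed to
// mean "seed every sequence". Unmasked states keep their current position.
void initializeRandomStatesHost(std::vector<RngState>& states,
                                const std::vector<std::uint64_t>& random_seeds,
                                const bool* mask) {
    assert(random_seeds.size() == states.size());
    for (std::size_t i = 0; i < states.size(); ++i) {
        if (mask == nullptr || mask[i]) {
            states[i].seed(random_seeds[i]);
        }
    }
}
```

Under these assumptions, a masked call re-seeds only the selected sequences and leaves the random streams of the others untouched.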

Usage

Use these kernels when the sampling strategy requires selecting from the top-K most probable tokens, optionally combined with a top-P threshold for nucleus sampling.
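To make the top-K plus top-P combination concrete, here is a single-sequence CPU reference. It is a sketch, not the kernel: it assumes the top-P threshold is applied to the renormalized top-K mass (the uniform draw u is scaled by top_p times the top-K mass and inverted against the sorted cumulative sum), and it takes u as an argument instead of reading a cuRAND state. Sample and sampleTopKTopP are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>

struct Sample {
    int token;      // vocabulary index of the drawn token
    float logprob;  // its log-probability under the input distribution
};

// Host reference for one sequence. Keeps the k most probable tokens, then draws
// by inverse CDF against u * top_p * (top-k mass), so top_p < 1 restricts the
// draw to the leading part of the sorted top-k list. u must lie in [0, 1).
Sample sampleTopKTopP(const std::vector<float>& logprobs, int k, float top_p, float u) {
    assert(k >= 1 && k <= static_cast<int>(logprobs.size()));
    assert(top_p > 0.f && top_p <= 1.f && u >= 0.f && u < 1.f);

    std::vector<int> idx(logprobs.size());
    std::iota(idx.begin(), idx.end(), 0);
    // Only the first k entries need to be ordered (descending by log-prob).
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return logprobs[a] > logprobs[b]; });

    float topk_mass = 0.f;
    for (int i = 0; i < k; ++i) topk_mass += std::exp(logprobs[idx[i]]);

    const float threshold = u * top_p * topk_mass;
    float cumulative = 0.f;
    for (int i = 0; i < k; ++i) {
        cumulative += std::exp(logprobs[idx[i]]);
        if (cumulative > threshold) return {idx[i], logprobs[idx[i]]};
    }
    // Only reachable through rounding; fall back to the last kept candidate.
    return {idx[k - 1], logprobs[idx[k - 1]]};
}
```

With probabilities {0.1, 0.5, 0.3, 0.1} and k = 2, only tokens 1 and 2 can be drawn; lowering top_p to 0.5 shrinks the threshold enough that token 1 is chosen even for u close to 1.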

Code Reference

Source Location

Signature

template<typename T>
void invokeBatchTopKSampling(
    void* workspace, size_t& workspace_size, const T* log_probs,
    int* ids, int* sequence_length, bool* finished,
    float* cum_log_probs, float* output_log_probs,
    float* sampled_logprobs, uint32_t* sampled_indexes, uint32_t* sampled_nums,
    curandState_t* curandstate, const int max_top_k, const int* top_ks,
    const float top_p, const float* top_ps,
    const int vocab_size_padded, const int* end_ids,
    cudaStream_t stream, const int batch_size, const bool* skip_decode);

void InitializeRandomStates(curandState_t* states, const uint64_t* random_seeds,
    const bool* mask, size_t batch_size, cudaStream_t stream);

struct TopKSortFilterParams {
    void* logits; void* sorted_logits; int* sorted_indices;
    int* kept; int* top_ks; int max_top_k;
    int batch_size; int vocab_size; int vocab_size_padded;
};

template<typename T>
void invokeTopKSortFilter(TopKSortFilterParams& params, cudaStream_t stream);
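A host-side reference can clarify what the sort-and-filter step produces. TopKSortFilterHost and topKSortFilterHost are hypothetical; the sketch assumes float logits, input rows of stride vocab_size_padded whose trailing padding columns are ignored, output rows of stride max_top_k, and kept[b] == top_ks[b] with 1 <= top_ks[b] <= max_top_k. The real kernel's output layout and edge-case handling (for example K = 0) may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical host mirror of TopKSortFilterParams using std::vector buffers.
struct TopKSortFilterHost {
    std::vector<float> logits;         // [batch_size, vocab_size_padded]
    std::vector<float> sorted_logits;  // [batch_size, max_top_k] (assumed layout)
    std::vector<int> sorted_indices;   // [batch_size, max_top_k] (assumed layout)
    std::vector<int> kept;             // [batch_size]
    std::vector<int> top_ks;           // [batch_size]
    int max_top_k, batch_size, vocab_size, vocab_size_padded;
};

void topKSortFilterHost(TopKSortFilterHost& p) {
    const std::size_t out_size = static_cast<std::size_t>(p.batch_size) * p.max_top_k;
    p.sorted_logits.assign(out_size, 0.f);
    p.sorted_indices.assign(out_size, -1);
    p.kept.assign(p.batch_size, 0);
    for (int b = 0; b < p.batch_size; ++b) {
        const float* row = p.logits.data() + static_cast<std::size_t>(b) * p.vocab_size_padded;
        const int k = p.top_ks[b];
        assert(k >= 1 && k <= p.max_top_k && p.max_top_k <= p.vocab_size);
        std::vector<int> idx(p.vocab_size);  // padding columns are never candidates
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                          [&](int x, int y) { return row[x] > row[y]; });
        const std::size_t base = static_cast<std::size_t>(b) * p.max_top_k;
        for (int i = 0; i < k; ++i) {
            p.sorted_logits[base + i] = row[idx[i]];
            p.sorted_indices[base + i] = idx[i];
        }
        p.kept[b] = k;  // leading entries a later sampling step may draw from
    }
}
```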

Import

#include "src/turbomind/kernels/sampling_topk_kernels.h"

I/O Contract

Inputs

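Name               Type              Required  Description
workspace          void*             Yes       Device scratch buffer; pass nullptr to query workspace_size
log_probs          const T*          Yes       Log-probability distribution over the (padded) vocabulary
max_top_k          int               Yes       Maximum K value across the batch
top_ks             const int*        Yes       Per-sequence K values
top_p              float             Yes       Batch-wide top-P threshold (scalar counterpart of top_ps)
top_ps             const float*      No        Per-sequence top-P thresholds
curandstate        curandState_t*    Yes       Per-sequence random states
vocab_size_padded  int               Yes       Padded vocabulary size (row stride of log_probs)
end_ids            const int*        Yes       End-of-sequence token ID(s)
stream             cudaStream_t      Yes       CUDA stream on which the work is launched
batch_size         int               Yes       Number of sequences
skip_decode        const bool*       No        Per-sequence skip flags (nullptr in the size query)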

Outputs

Name              Type      Description
ids               int*      Sampled output token IDs
sequence_length   int*      Updated sequence lengths
finished          bool*     Per-sequence finished flags
cum_log_probs     float*    Updated cumulative log probabilities
workspace_size    size_t&   Required workspace size (written in query mode)
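How these output buffers might evolve across one decoding step is sketched below. The update rules are assumptions modeled on the description above, not taken from the kernel source: already-finished sequences emit their end ID and keep their length and score, while live sequences record the drawn token, add its log-probability to cum_log_probs, grow by one, and become finished when the token equals their end ID. BatchState and updateAfterSampling are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host model of the per-sequence state touched by one sampling step.
struct BatchState {
    std::vector<int> ids;               // token written this step
    std::vector<int> sequence_length;
    std::vector<bool> finished;
    std::vector<float> cum_log_probs;
};

// Assumed update rules (see lead-in); not confirmed against the kernel.
void updateAfterSampling(BatchState& s, const std::vector<int>& tokens,
                         const std::vector<float>& token_logprobs,
                         const std::vector<int>& end_ids) {
    const std::size_t n = tokens.size();
    assert(token_logprobs.size() == n && end_ids.size() == n && s.ids.size() == n);
    for (std::size_t b = 0; b < n; ++b) {
        if (s.finished[b]) {  // finished sequences only repeat their end ID
            s.ids[b] = end_ids[b];
            continue;
        }
        s.ids[b] = tokens[b];
        s.cum_log_probs[b] += token_logprobs[b];
        s.sequence_length[b] += 1;
        s.finished[b] = (tokens[b] == end_ids[b]);
    }
}
```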

Usage Examples

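#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "src/turbomind/kernels/sampling_topk_kernels.h"

using namespace turbomind;

// log_probs, rand_states, top_ks, top_ps, end_ids, the output buffers and
// stream are assumed to be device allocations/handles prepared by the caller.

// Pass 1: workspace == nullptr, so only ws_size is written
size_t ws_size = 0;
invokeBatchTopKSampling<half>(nullptr, ws_size, log_probs,
    nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
    rand_states, max_k, top_ks, 0.9f, top_ps,
    vocab_padded, end_ids, stream, batch_size, nullptr);

// Pass 2: allocate the workspace, then run the actual sampling
void* workspace = nullptr;
cudaMalloc(&workspace, ws_size);
invokeBatchTopKSampling<half>(workspace, ws_size, log_probs,
    output_ids, seq_lengths, finished, cum_logprobs, out_logprobs,
    sampled_lp, sampled_idx, sampled_n,
    rand_states, max_k, top_ks, 0.9f, top_ps,
    vocab_padded, end_ids, stream, batch_size, skip_decode);

// Release the workspace once the stream is done with it
cudaStreamSynchronize(stream);
cudaFree(workspace);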
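The query-then-execute convention in the example above can be illustrated without a GPU. sumWithScratch is a hypothetical stand-in whose "work" is deliberately trivial; only the calling convention matters: a null workspace makes the function report its byte requirement through workspace_size, and a second call with a buffer of at least that size does the work. In CUDA code the buffer would come from cudaMalloc (or a caching allocator) rather than a std::vector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical function using the two-pass workspace convention.
// Pass 1 (workspace == nullptr): only writes the byte requirement.
// Pass 2: uses the caller-provided buffer as scratch and does the work,
// which here is copying data into scratch and summing it.
float sumWithScratch(void* workspace, std::size_t& workspace_size,
                     const std::vector<float>& data) {
    const std::size_t needed = sizeof(float) * data.size();
    if (workspace == nullptr) {
        workspace_size = needed;
        return 0.f;
    }
    assert(workspace_size >= needed);
    float* scratch = static_cast<float*>(workspace);
    float total = 0.f;
    for (std::size_t i = 0; i < data.size(); ++i) {
        scratch[i] = data[i];
        total += scratch[i];
    }
    return total;
}
```

Separating the size query from execution lets the caller own allocation, for example allocating once and reusing the buffer while the batch shape stays the same.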
