

Implementation:InternLM Lmdeploy Sampling

From Leeroopedia


Knowledge Sources
Domains: Text Generation, Token Sampling
Last Updated: 2026-02-07 15:00 GMT

Overview

Implements the token sampling module that selects output tokens from processed logits using top-k, top-p, and min-p filtering strategies with optional log-probability output.

Description

The Sampling class performs the final token selection step in the generation pipeline. It inherits from BaseGenerationParam and supports four operations: Setup, Forward, Fetch, and Update.

Setup(): Extracts per-request sampling parameters (top_k, top_p, min_p) from the batch's request caches and computes batch-wide aggregates (max_topk, min_topk, min_topp, max_minp) that determine which filtering kernels must run. It then copies the parameters from pinned host buffers to device memory and checks whether any request needs log-probability output.
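The aggregates computed in Setup() map directly onto the kernel-selection conditions listed under Forward(). Below is a minimal host-side sketch. It assumes a hypothetical SamplingParams struct where top_k = 0, top_p = 1, and min_p = 0 mean "disabled", which is consistent with the conditions stated on this page. SamplingParams, SamplingStats, and ComputeSamplingStats are illustrative names, not TurboMind APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

// Hypothetical per-request parameters; field names follow the text, but this is
// not TurboMind's actual request cache type.
struct SamplingParams {
    int   top_k = 0;    // 0 = top-k disabled
    float top_p = 1.f;  // 1.0 = top-p disabled
    float min_p = 0.f;  // 0.0 = min-p disabled
};

// Batch-level aggregates that decide which filtering stages must run.
struct SamplingStats {
    int   max_topk = 0;
    int   min_topk = std::numeric_limits<int>::max();
    float min_topp = 1.f;
    float max_minp = 0.f;

    bool need_topk_sort_filter() const { return max_topk > 0; }  // some request uses top-k
    bool need_topp_sort() const { return min_topk == 0; }        // some request skips top-k
    bool need_topp_minp_filter() const { return min_topp < 1.f || max_minp > 0.f; }
};

SamplingStats ComputeSamplingStats(const std::vector<SamplingParams>& batch) {
    SamplingStats s;
    for (const auto& p : batch) {
        // Sanity-check ranges before folding them into the aggregates.
        assert(p.top_k >= 0 && p.top_p > 0.f && p.top_p <= 1.f && p.min_p >= 0.f);
        s.max_topk = std::max(s.max_topk, p.top_k);
        s.min_topk = std::min(s.min_topk, p.top_k);
        s.min_topp = std::min(s.min_topp, p.top_p);
        s.max_minp = std::max(s.max_minp, p.min_p);
    }
    return s;
}
```

A batch that mixes a top-k-only request with a top-p-only request turns on all three stages. A batch where every request uses only top-k runs just the first stage.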

Forward(): Executes a multi-stage sampling pipeline:

  1. Top-K sort/filter: If any request uses top-k (max_topk > 0), invokes invokeTopKSortFilter to sort logits and retain only the top-k candidates.
  2. Top-P sort: If any request skips top-k (min_topk == 0), applies softmax followed by invokeTopPSort to sort by probability.
  3. Top-P / Min-P filter: If any request uses top-p (min_topp < 1) or min-p (max_minp > 0), applies invokeTopPMinPFilter to further filter candidates.
  4. Sampling: Calls invokeSampling to randomly select tokens from the filtered distribution using cuRAND states. Optionally outputs sampled log-probabilities, indices, and counts.
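The four stages above can be illustrated with a single-row CPU reference of the same filter-then-sample idea. This is a hedged sketch, not a port of invokeTopKSortFilter, invokeTopPSort, invokeTopPMinPFilter, or invokeSampling. It makes several substitutions and assumptions:

- std::mt19937 stands in for cuRAND.
- Probabilities are renormalized after top-k.
- Min-p is measured relative to the highest remaining probability, which is the usual min-p definition.
- The returned log-probability comes from the filtered distribution. This page does not state which distribution TurboMind reports.

Candidate and SampleRow are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical candidate record: token id and its probability.
struct Candidate {
    int   id;
    float prob;
};

// Single-row CPU reference of top-k -> top-p/min-p -> sample (illustrative only).
int SampleRow(const std::vector<float>& logits, int top_k, float top_p, float min_p,
              std::mt19937& rng, float* logprob = nullptr) {
    assert(!logits.empty());

    // Softmax over the vocabulary, subtracting the max logit for stability.
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<Candidate> cands(logits.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        cands[i] = {static_cast<int>(i), std::exp(logits[i] - max_logit)};
        sum += cands[i].prob;
    }
    for (auto& c : cands) c.prob = static_cast<float>(c.prob / sum);

    // Rescale the surviving candidates so their probabilities sum to 1.
    auto renormalize = [&cands] {
        double total = 0.0;
        for (const auto& c : cands) total += c.prob;
        for (auto& c : cands) c.prob = static_cast<float>(c.prob / total);
    };

    // Stages 1-2: sort by descending probability, then keep the top-k if enabled.
    std::sort(cands.begin(), cands.end(),
              [](const Candidate& a, const Candidate& b) { return a.prob > b.prob; });
    if (top_k > 0 && static_cast<std::size_t>(top_k) < cands.size()) cands.resize(top_k);
    renormalize();  // assumption: later filters see probabilities over the kept set

    // Stage 3a: top-p keeps the shortest prefix whose cumulative mass reaches top_p.
    std::size_t keep = cands.size();
    if (top_p < 1.f) {
        double cum = 0.0;
        for (std::size_t i = 0; i < cands.size(); ++i) {
            cum += cands[i].prob;
            if (cum >= top_p) { keep = i + 1; break; }
        }
    }
    // Stage 3b: min-p drops tokens below min_p * (highest probability).
    // Both cuts use the same distribution, so the shorter prefix wins.
    if (min_p > 0.f) {
        const float threshold = min_p * cands.front().prob;
        std::size_t n = 1;  // the most likely token always survives
        while (n < keep && cands[n].prob >= threshold) ++n;
        keep = n;
    }
    cands.resize(keep);
    renormalize();

    // Stage 4: inverse-CDF draw from the filtered distribution.
    const float u = std::uniform_real_distribution<float>(0.f, 1.f)(rng);
    std::size_t pick = cands.size() - 1;  // fallback if rounding leaves cum below u
    double cum = 0.0;
    for (std::size_t i = 0; i < cands.size(); ++i) {
        cum += cands[i].prob;
        if (u < cum) { pick = i; break; }
    }
    // Log-probability under the filtered, renormalized distribution (an assumption).
    if (logprob) *logprob = std::log(cands[pick].prob);
    return cands[pick].id;
}
```

For example, with logits {0, 5, 1, 4} and top_k = 2, only tokens 1 and 3 can be returned. With min_p = 0.5, only token 1 survives, because p(3) / p(1) = e^-1 ≈ 0.37.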

Fetch(): Copies sampled log-probability data from device to pinned host memory when log-probability output is requested.

Update(): Writes sampled log-probability values, indices, and counts into the per-request output tensors at the correct sequence position offsets.
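Once Fetch() has brought the results to host memory, Update() only has to place each step's values in the right row of the request's output tensors. The sketch below rests on three assumptions, and the real TurboMind tensor shapes and offset arithmetic may differ:

- The values and indexes use a hypothetical row-major layout of [max_steps, max_logprobs].
- There is one count per step.
- The row for the token just sampled is sequence_length - prompt_length - 1.

LogprobOutputs and WriteStepLogprobs are illustrative names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical per-request buffers mirroring "logprob_vals", "logprob_indexes",
// and "logprob_nums": row r of vals/indexes holds the entries for generated step r.
struct LogprobOutputs {
    int                max_logprobs;  // row width
    std::vector<float> vals;          // [max_steps * max_logprobs]
    std::vector<int>   indexes;       // [max_steps * max_logprobs], -1 = unused
    std::vector<int>   nums;          // [max_steps], valid entries per row

    LogprobOutputs(int max_steps, int width)
        : max_logprobs(width),
          vals(static_cast<std::size_t>(max_steps) * width, 0.f),
          indexes(static_cast<std::size_t>(max_steps) * width, -1),
          nums(static_cast<std::size_t>(max_steps), 0) {}
};

// Writes one step's host-side results into the request outputs.
// Assumption: sequence_length already includes the token just sampled, so its row is
// sequence_length - prompt_length - 1.
void WriteStepLogprobs(LogprobOutputs& out, int sequence_length, int prompt_length,
                       const float* vals, const int* indexes, int num) {
    const int step = sequence_length - prompt_length - 1;
    assert(step >= 0 && step < static_cast<int>(out.nums.size()));
    const int         n    = std::min(num, out.max_logprobs);  // clamp to row width
    const std::size_t base = static_cast<std::size_t>(step) * out.max_logprobs;
    std::copy(vals, vals + n, out.vals.data() + base);
    std::copy(indexes, indexes + n, out.indexes.data() + base);
    out.nums[step] = n;
}
```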

Usage

Used within the Generation module, which calls Setup at kSetup to configure sampling parameters, Forward at kForward to select tokens, Fetch at kFetch to retrieve log-probabilities, and Update at kUpdate to write results to the request output buffers.

Code Reference

Source Location

Signature

class Sampling: public BaseGenerationParam {
public:
    explicit Sampling(const BaseGenerationParam& base, int phases);

    void Setup(int phase, TensorMap& env);

    void Forward(int phase, TensorMap& env);

    void Fetch(int phase, TensorMap& env);

    void Update(int phase, TensorMap& env);

private:
    std::vector<std::shared_ptr<SamplingData>> data_;

    // host buffers
    Buffer_<int>   kept_;
    Buffer_<int>   top_k_;
    Buffer_<float> top_p_;
    Buffer_<float> min_p_;

    Buffer_<float> sampled_logprobs_buf_;
    Buffer_<int>   sampled_indices_buf_;
    Buffer_<int>   sampled_nums_buf_;
};

Import

#include "src/turbomind/generation/sampling.h"

I/O Contract

Inputs

Name | Type | Required | Description
base | BaseGenerationParam | Yes | Base parameters (max_batch_size, vocab_size, vocab_size_padded)
phases | int | Yes | Number of pipeline phases
env["batch"] | BatchData* | Yes (Setup/Fetch/Update) | Batch data with request caches
env["copy"] | BatchCopy* | Yes (Setup/Fetch) | Host-to-device copy utility
env["logits"] | Tensor_<float> | Yes (Forward) | Processed logits tensor of shape (batch_size, vocab_size_padded)
env["curand_state"] | Tensor | Yes (Forward) | cuRAND states for random sampling
env["output_ids"] | Tensor | Yes (Forward) | Output tensor for sampled token IDs
env["sequence_length"] | Tensor | Yes (Forward) | Current sequence lengths (updated by sampling)

Outputs

Name | Type | Description
env["output_ids"] (modified) | Tensor | Populated with sampled token IDs
env["sequence_length"] (modified) | Tensor | Updated sequence lengths
Per-request logprob outputs | Tensors in Request.outputs | "logprob_vals", "logprob_indexes", and "logprob_nums", written only when log-probabilities are requested

Usage Examples

// Construction (inside Generation module)
Sampling sampler(base_param, phases);

// Setup: extract per-request top_k, top_p, min_p
sampler.Setup(phase, env);

// Forward: sort, filter, and sample tokens
sampler.Forward(phase, env);

// Fetch: copy sampled logprobs from device to pinned host memory (only when requested)
sampler.Fetch(phase, env);

// Update: write logprobs to per-request output buffers
sampler.Update(phase, env);
