Implementation:InternLM Lmdeploy Sampling
| Knowledge Sources | |
|---|---|
| Domains | Text Generation, Token Sampling |
| Last Updated | 2026-02-07 15:00 GMT |
Overview
Implements the token sampling module that selects output tokens from processed logits using top-k, top-p, and min-p filtering strategies with optional log-probability output.
Description
The Sampling class performs the final token selection step in the generation pipeline. It inherits from BaseGenerationParam and supports four operations: Setup, Forward, Fetch, and Update.
Setup(): Extracts per-request sampling parameters (top_k, top_p, min_p) from the batch's request caches. Computes aggregate statistics (max_topk, min_topk, min_topp, max_minp) to determine which filtering kernels are needed. Copies parameters from pinned host buffers to device memory. Also checks if any request requires log-probability output.
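The aggregate-statistics step in Setup() can be sketched on the host as follows. This is a minimal illustration, not lmdeploy's code. RequestParams, BatchStats, Stages, ComputeStats and SelectStages are hypothetical names. The "disabled" conventions (top_k == 0, top_p == 1, min_p == 0) are assumptions inferred from the stage conditions listed for Forward().

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical per-request parameters; field names follow the prose.
// Assumed conventions: top_k == 0 disables top-k, top_p == 1 disables top-p,
// min_p == 0 disables min-p.
struct RequestParams {
    int   top_k;
    float top_p;
    float min_p;
};

// Batch-wide aggregates, as described for Setup().
struct BatchStats {
    int   max_topk;
    int   min_topk;
    float min_topp;
    float max_minp;
};

// Reduces the per-request parameters of a non-empty batch to the four aggregates.
BatchStats ComputeStats(const std::vector<RequestParams>& batch) {
    assert(!batch.empty());
    BatchStats s{batch[0].top_k, batch[0].top_k, batch[0].top_p, batch[0].min_p};
    for (const RequestParams& p : batch) {
        s.max_topk = std::max(s.max_topk, p.top_k);
        s.min_topk = std::min(s.min_topk, p.top_k);
        s.min_topp = std::min(s.min_topp, p.top_p);
        s.max_minp = std::max(s.max_minp, p.min_p);
    }
    return s;
}

// Which Forward() stages the aggregates enable (conditions taken from the text).
struct Stages {
    bool topk_sort_filter;  // max_topk > 0: some request uses top-k
    bool topp_sort;         // min_topk == 0: some request skips top-k
    bool topp_minp_filter;  // min_topp < 1 or max_minp > 0
};

Stages SelectStages(const BatchStats& s) {
    return {s.max_topk > 0, s.min_topk == 0, s.min_topp < 1.f || s.max_minp > 0.f};
}
```

A mixed batch, for example one request with top_k = 40 and another with top_k = 0 and min_p = 0.05, enables all three filtering stages. A batch where every request uses top-k and neither top-p nor min-p needs only the top-k stage.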
Forward(): Executes a multi-stage sampling pipeline:
- Top-K sort/filter: If any request uses top-k (max_topk > 0), invokes invokeTopKSortFilter to sort logits and retain only the top-k candidates.
- Top-P sort: If any request skips top-k (min_topk == 0), applies softmax followed by invokeTopPSort to sort by probability.
- Top-P / Min-P filter: If any request uses top-p (min_topp < 1) or min-p (max_minp > 0), applies invokeTopPMinPFilter to further filter candidates.
- Sampling: Calls invokeSampling to randomly select tokens from the filtered distribution using cuRAND states. Optionally outputs sampled log-probabilities, indices, and counts.
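For intuition, the filtering and sampling stages can be approximated by a sequential CPU reference for a single batch row. This is a hypothetical sketch, not the CUDA kernels. SampleRow is an illustrative name, and several details are assumptions rather than lmdeploy's exact semantics:

- the stage order shown here;
- top-p keeps the smallest prefix whose cumulative probability reaches top_p;
- min-p drops tokens below min_p times the top probability.

The uniform draw u stands in for a cuRAND state.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>

// Hypothetical single-row CPU reference. Assumed conventions: top_k == 0
// disables top-k, top_p == 1 disables top-p, min_p == 0 disables min-p.
// `u` is a uniform random number in [0, 1) standing in for a cuRAND draw.
int SampleRow(const std::vector<float>& logits, int top_k, float top_p,
              float min_p, float u) {
    assert(!logits.empty());
    const int n = static_cast<int>(logits.size());

    // Sort candidate indices by logit, descending (stable for deterministic ties).
    std::vector<int> idx(n);
    std::iota(idx.begin(), idx.end(), 0);
    std::stable_sort(idx.begin(), idx.end(),
                     [&](int a, int b) { return logits[a] > logits[b]; });

    // Top-k: keep only the k highest-scoring candidates.
    int kept = (top_k > 0) ? std::min(top_k, n) : n;

    // Softmax over the kept candidates (subtract the max for numerical stability).
    std::vector<float> prob(kept);
    const float max_logit = logits[idx[0]];
    float sum = 0.f;
    for (int i = 0; i < kept; ++i) {
        prob[i] = std::exp(logits[idx[i]] - max_logit);
        sum += prob[i];
    }
    for (float& p : prob) p /= sum;

    // Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    if (top_p < 1.f) {
        float cum = 0.f;
        int i = 0;
        while (i < kept) {
            cum += prob[i++];
            if (cum >= top_p) break;
        }
        kept = i;
    }

    // Min-p: probabilities are sorted, so stop at the first one below the threshold.
    if (min_p > 0.f) {
        const float threshold = min_p * prob[0];
        int i = 1;  // the top candidate always survives
        while (i < kept && prob[i] >= threshold) ++i;
        kept = i;
    }

    // Sample by inverse CDF over the (implicitly renormalized) survivors.
    float total = 0.f;
    for (int i = 0; i < kept; ++i) total += prob[i];
    const float target = u * total;
    float cum = 0.f;
    for (int i = 0; i < kept; ++i) {
        cum += prob[i];
        if (target < cum) return idx[i];
    }
    return idx[kept - 1];  // guard against rounding when u is close to 1
}
```

Min-p compares each probability with the top probability. That ratio is unchanged by renormalization, so it does not matter here whether the top-p survivors are renormalized before the min-p check.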
Fetch(): Copies sampled log-probability data from device to pinned host memory when log-probability output is requested.
Update(): Writes sampled log-probability values, indices, and counts into the per-request output tensors at the correct sequence position offsets.
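Fetch() and Update() together amount to a device-to-host copy followed by a per-request scatter. The host-side scatter can be sketched as below. This sketch assumes hypothetical [batch, batch_width] host buffers and a caller-supplied output row step. The real offset computation and tensor shapes in sampling.cc are not reproduced here, and RequestLogprobOutputs and ScatterLogprobs are illustrative names.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical host-side view of one request's logprob outputs
// ("logprob_vals", "logprob_indexes", "logprob_nums" in the text),
// stored row-major as [max_steps, width] and [max_steps].
struct RequestLogprobOutputs {
    int width;
    std::vector<float> vals;
    std::vector<int>   indexes;
    std::vector<int>   nums;
};

// Copies batch slot `slot` of the fetched host buffers (assumed row-major
// [batch, batch_width]) into row `step` of one request's outputs, clamping
// the count to the output width.
void ScatterLogprobs(const std::vector<float>& sampled_logprobs,
                     const std::vector<int>&   sampled_indices,
                     const std::vector<int>&   sampled_nums,
                     int batch_width, int slot, int step,
                     RequestLogprobOutputs& out) {
    assert(step >= 0 && step < static_cast<int>(out.nums.size()));
    const int n   = std::min(sampled_nums[slot], out.width);
    const int src = slot * batch_width;
    const int dst = step * out.width;
    std::copy_n(sampled_logprobs.begin() + src, n, out.vals.begin() + dst);
    std::copy_n(sampled_indices.begin() + src, n, out.indexes.begin() + dst);
    out.nums[step] = n;
}
```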
Usage
Used within the Generation module. Called at kSetup to configure sampling parameters, kForward to perform token selection, kFetch to retrieve log-probabilities, and kUpdate to write results to request output buffers.
Code Reference
Source Location
- Repository: InternLM_Lmdeploy
- File: src/turbomind/generation/sampling.h
- File: src/turbomind/generation/sampling.cc
- Lines: sampling.h 1-37, sampling.cc 1-242
Signature
```cpp
class Sampling: public BaseGenerationParam {
public:
    explicit Sampling(const BaseGenerationParam& base, int phases);

    void Setup(int phase, TensorMap& env);
    void Forward(int phase, TensorMap& env);
    void Fetch(int phase, TensorMap& env);
    void Update(int phase, TensorMap& env);

private:
    std::vector<std::shared_ptr<SamplingData>> data_;

    // host buffers
    Buffer_<int>   kept_;
    Buffer_<int>   top_k_;
    Buffer_<float> top_p_;
    Buffer_<float> min_p_;
    Buffer_<float> sampled_logprobs_buf_;
    Buffer_<int>   sampled_indices_buf_;
    Buffer_<int>   sampled_nums_buf_;
};
```
Import
```cpp
#include "src/turbomind/generation/sampling.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| base | BaseGenerationParam | Yes | Base parameters (max_batch_size, vocab_size, vocab_size_padded) |
| phases | int | Yes | Number of pipeline phases |
| env["batch"] | BatchData* | Yes (Setup/Fetch/Update) | Batch data with request caches |
| env["copy"] | BatchCopy* | Yes (Setup/Fetch) | Batched copy utility (host-to-device in Setup, device-to-host in Fetch) |
| env["logits"] | Tensor_<float> | Yes (Forward) | Processed logits tensor (batch_size, vocab_size_padded) |
| env["curand_state"] | Tensor | Yes (Forward) | cuRAND states for random sampling |
| env["output_ids"] | Tensor | Yes (Forward) | Output tensor for sampled token IDs |
| env["sequence_length"] | Tensor | Yes (Forward) | Current sequence lengths (updated by sampling) |
Outputs
| Name | Type | Description |
|---|---|---|
| env["output_ids"] (modified) | Tensor | Populated with sampled token IDs |
| env["sequence_length"] (modified) | Tensor | Updated sequence lengths |
| Per-request logprob outputs | Tensors in Request.outputs | "logprob_vals", "logprob_indexes", "logprob_nums" when requested |
Usage Examples
```cpp
// Construction (inside Generation module)
Sampling sampler(base_param, phases);

// Setup: extract per-request top_k, top_p, min_p
sampler.Setup(phase, env);

// Forward: sort, filter, and sample tokens
sampler.Forward(phase, env);

// Fetch: copy logprobs from device to host
sampler.Fetch(phase, env);

// Update: write logprobs to per-request output buffers
sampler.Update(phase, env);
```