Implementation:LaurentMazare Tch rs Tensor Multinomial

From Leeroopedia


Knowledge Sources
Domains NLP, Text_Generation
Last Updated 2026-02-08 14:00 GMT

Overview

A concrete tool, provided by the tch tensor operations, for sampling token indices from a probability distribution.

Description

Tensor::multinomial samples indices from a multinomial probability distribution. In the context of text generation, it takes the softmax probabilities over the vocabulary and returns a sampled token index. The method is a generated binding to libtorch's torch::multinomial function.
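To make the sampling step concrete, here is a minimal plain-Rust sketch of drawing one index with probability proportional to its weight, using inverse-CDF sampling. It is an illustration of the idea only, not the algorithm libtorch uses internally. The function name sample_index and the caller-supplied uniform draw u are hypothetical and exist just to keep the sketch deterministic.

```rust
/// Pick an index with probability proportional to its weight (inverse-CDF sampling).
/// `u` is a caller-supplied uniform draw in [0, 1); it is an assumption of this sketch.
/// Returns None for negative or non-finite weights, or when the weights sum to zero.
fn sample_index(weights: &[f64], u: f64) -> Option<usize> {
    if weights.iter().any(|w| !w.is_finite() || *w < 0.0) {
        return None;
    }
    let total: f64 = weights.iter().sum();
    if total <= 0.0 {
        return None;
    }
    // Scale the draw by the total so the weights need not sum to one.
    let target = u * total;
    let mut cumulative = 0.0;
    for (i, w) in weights.iter().enumerate() {
        cumulative += w;
        if target < cumulative {
            return Some(i);
        }
    }
    // Floating-point rounding can leave `target` at the very top of the range;
    // fall back to the last entry with non-zero weight.
    weights.iter().rposition(|w| *w > 0.0)
}
```

With weights [0.1, 0.2, 0.7], a draw of u = 0.25 falls in the second interval and selects index 1. Entries with zero weight are never selected.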

Usage

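Use it in the autoregressive generation loop after applying temperature scaling and softmax to the model logits. Draw one token per step (num_samples = 1). With a single draw the replacement flag does not affect the result; the example on this page passes replacement=true.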
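The effect of temperature scaling before softmax can be sketched in plain Rust as follows. This is an illustrative sketch with no tch dependency, and the function name softmax_with_temperature is hypothetical. It divides the logits by the temperature, subtracts the maximum for numerical stability, and normalizes.

```rust
/// Convert logits to probabilities after dividing by `temperature`.
/// Lower temperatures sharpen the distribution; higher ones flatten it.
fn softmax_with_temperature(logits: &[f64], temperature: f64) -> Vec<f64> {
    assert!(temperature > 0.0, "temperature must be positive");
    let scaled: Vec<f64> = logits.iter().map(|l| l / temperature).collect();
    // Subtract the maximum so the exponentials cannot overflow.
    let max = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scaled.iter().map(|s| (s - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```

For logits [1.0, 2.0, 3.0], a temperature of 0.5 puts more probability on the largest logit than a temperature of 2.0 does, which makes sampling less varied at low temperatures.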

Code Reference

Source Location

  • Repository: tch-rs
  • File: src/wrappers/tensor_generated.rs (generated binding)

Signature

impl Tensor {
    pub fn multinomial(&self, num_samples: i64, replacement: bool) -> Tensor
}

Import

use tch::Tensor;

I/O Contract

Inputs

Name Type Required Description
self &Tensor Yes Probability distribution (softmax output), shape [vocab_size] or [1, vocab_size]; values must be non-negative with a non-zero sum
num_samples i64 Yes Number of samples to draw (1 for token generation)
replacement bool Yes Whether to sample with replacement (typically true)

Outputs

Name Type Description
Tensor Tensor Sampled indices (Kind::Int64), shape [num_samples] for a 1-D input or [1, num_samples] for a [1, vocab_size] input

Usage Examples

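use tch::{Kind, Tensor};

// In the generation loop (model, tokens, freqs_cis and temperature are defined elsewhere):
let _no_grad = tch::no_grad_guard();  // Disable gradient tracking during inference

// Assumed here: logits cover the last position only, shape [1, vocab_size]
let logits = model.forward(&tokens, &freqs_cis);
let probs = (logits / temperature).softmax(-1, Kind::Float);
let next_token = probs.multinomial(1, true);  // Sample one token index

// Append to the sequence along dimension 1 (tokens must be declared mutable)
tokens = Tensor::cat(&[tokens, next_token.view([1, 1])], 1);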

Related Pages

Implements Principle
