Implementation:Romsto Speculative Decoding LogitsProcessor Hierarchy

From Leeroopedia
Knowledge Sources
Domains NLP, Sampling
Last Updated 2026-02-14 04:30 GMT

Overview

Concrete tool that converts model logits into probability distributions and samples tokens from them, offering five interchangeable sampling strategies through the Strategy pattern.

Description

The LogitsProcessor hierarchy implements sampling strategies as an abstract base class with concrete subclasses. The ABC defines the interface: __call__(logits) applies temperature-scaled softmax after an optional filtering step (_process), and sample(probs) selects a token from the resulting distribution.
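To make the effect of the temperature-scaled softmax concrete, the following is a framework-free worked example in plain Python. It is an illustration only: the repository operates on torch tensors, the helper name temperature_softmax is hypothetical, and the masked case assumes that a filtering step marks removed tokens with negative infinity, which this page does not confirm.

```python
import math


def temperature_softmax(logits, temperature):
    """Return softmax(logits / temperature) for a plain list of floats."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
sharp = temperature_softmax(logits, 0.5)  # T < 1 concentrates mass on the top logit
plain = temperature_softmax(logits, 1.0)  # T = 1 is the ordinary softmax
flat = temperature_softmax(logits, 2.0)   # T > 1 spreads mass more evenly

# Assumption: a filter marks a removed token with -inf, so it gets probability 0.
masked = temperature_softmax([2.0, 1.0, float("-inf")], 1.0)
```

The probability of the highest-scoring token decreases as the temperature rises, while its rank never changes, which is why temperature has no effect on an argmax-based strategy.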

The class hierarchy is:

  • LogitsProcessor (ABC) — defines __call__, _process, sample
    • GreedyProcessor — argmax sampling, no filtering
    • MultinomialProcessor — torch.multinomial sampling, no filtering
      • TopKProcessor — top-k filtering + multinomial
      • NucleusProcessor — top-p filtering + multinomial
      • TopKNucleusProcessor — top-k then top-p filtering + multinomial
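The filtering steps named in the hierarchy above can be sketched as standalone functions. This is a minimal sketch under stated assumptions, not the repository's code: the names top_k_filter and top_p_filter are hypothetical, removed tokens are masked with negative infinity, and the nucleus keeps every token whose preceding cumulative probability is still below top_p. The actual _process methods may handle ties and thresholds differently.

```python
import torch
from torch import Tensor


def top_k_filter(logits: Tensor, top_k: int) -> Tensor:
    """Keep the top_k largest logits along the last dim; mask the rest with -inf."""
    top_k = min(top_k, logits.size(-1))
    kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_value, float("-inf"))


def top_p_filter(logits: Tensor, top_p: float) -> Tensor:
    """Keep the smallest prefix of tokens (by probability) whose mass reaches top_p."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    # Probability mass that precedes each token in descending order.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    remove_sorted = mass_before >= top_p
    # Map the removal mask from sorted order back to vocabulary order.
    remove = torch.zeros_like(remove_sorted).scatter(-1, sorted_idx, remove_sorted)
    return logits.masked_fill(remove, float("-inf"))


logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
only_k = top_k_filter(logits, 2)
only_p = top_p_filter(logits, 0.8)
k_then_p = top_p_filter(top_k_filter(logits, 3), 0.8)  # order used by TopKNucleusProcessor
```

In this sketch the combined variant simply composes the two functions, so the nucleus is computed over the distribution that remains after the top-k cut.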

Usage

Import the desired processor class when configuring sampling for any generation function in this repository (speculative_generate, ngram_assisted_speculative_generate, autoregressive_generate). The same processor instance is used for both probability computation and token sampling throughout a generation run.
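The way a single processor instance serves both steps can be sketched with a toy loop. Everything below other than the LogitsProcessor interface (copied from the signature on this page) is hypothetical: SketchGreedy is a stand-in whose method bodies are assumptions, toy_model is a dummy scoring function, and toy_generate is not the repository's autoregressive_generate.

```python
import abc

import torch
import torch.nn.functional as F
from torch import Tensor


class LogitsProcessor(abc.ABC):
    """Interface as shown in the signature on this page."""

    def __init__(self, temperature: float):
        self.temperature = temperature

    def __call__(self, logits: Tensor) -> Tensor:
        proc = self._process(logits)
        return F.softmax(proc / self.temperature, dim=-1)

    @abc.abstractmethod
    def _process(self, logits: Tensor) -> Tensor: ...

    @abc.abstractmethod
    def sample(self, probs: Tensor) -> Tensor: ...


class SketchGreedy(LogitsProcessor):
    """Hypothetical stand-in for GreedyProcessor; the bodies are assumptions."""

    def __init__(self, temperature: float = 1):
        super().__init__(temperature)

    def _process(self, logits: Tensor) -> Tensor:
        return logits  # assumed: no filtering

    def sample(self, probs: Tensor) -> Tensor:
        return torch.argmax(probs, dim=-1).unsqueeze(-1)


def toy_model(input_ids: Tensor, vocab_size: int = 8) -> Tensor:
    """Dummy model: strongly favors (last token + 1) modulo vocab_size."""
    next_ids = (input_ids[:, -1] + 1) % vocab_size
    return 5.0 * F.one_hot(next_ids, vocab_size).float()


def toy_generate(input_ids: Tensor, processor: LogitsProcessor, steps: int) -> Tensor:
    for _ in range(steps):
        logits = toy_model(input_ids)    # (batch, vocab_size)
        probs = processor(logits)        # probability computation
        token = processor.sample(probs)  # token sampling with the same instance
        input_ids = torch.cat([input_ids, token], dim=-1)
    return input_ids


out = toy_generate(torch.tensor([[3]]), SketchGreedy(), steps=4)
```

Because the loop only depends on the LogitsProcessor interface, swapping the sampling strategy means passing a different processor instance and leaving the loop unchanged.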

Code Reference

Source Location

Signature

class LogitsProcessor(abc.ABC):
    """Logits processors for sampling."""

    def __init__(self, temperature: float):
        self.temperature = temperature

    def __call__(self, logits: Tensor) -> Tensor:
        proc = self._process(logits)
        return F.softmax(proc / self.temperature, dim=-1)

    @abc.abstractmethod
    def _process(self, logits: Tensor) -> Tensor:
        pass

    @abc.abstractmethod
    def sample(self, probs: Tensor) -> Tensor:
        pass


class GreedyProcessor(LogitsProcessor):
    def __init__(self, temperature: float = 1): ...
    def _process(self, logits: Tensor) -> Tensor: ...
    def sample(self, probs: Tensor) -> Tensor: ...  # argmax


class MultinomialProcessor(LogitsProcessor):
    def __init__(self, temperature: float): ...
    def sample(self, probs: Tensor) -> Tensor: ...  # multinomial


class TopKProcessor(MultinomialProcessor):
    def __init__(self, temperature: float, top_k: int): ...
    def _process(self, logits: Tensor) -> Tensor: ...  # top-k filter


class NucleusProcessor(MultinomialProcessor):
    def __init__(self, temperature: float, top_p: float): ...
    def _process(self, logits: Tensor) -> Tensor: ...  # nucleus filter


class TopKNucleusProcessor(MultinomialProcessor):
    def __init__(self, temperature: float, top_k: int, top_p: float): ...
    def _process(self, logits: Tensor) -> Tensor: ...  # top-k then nucleus

Import

from utils.logits_processor import (
    LogitsProcessor,
    GreedyProcessor,
    MultinomialProcessor,
    TopKProcessor,
    NucleusProcessor,
    TopKNucleusProcessor,
)

I/O Contract

Inputs

Name | Type | Required | Description
logits | torch.Tensor | Yes | Raw model output logits, shape (..., vocab_size)
temperature | float | Yes (constructor); optional for GreedyProcessor, which defaults to 1 | Softmax temperature. Higher values give more diverse samples.
top_k | int | TopK/TopKNucleus only | Number of highest-scoring tokens to retain before sampling
top_p | float | Nucleus/TopKNucleus only | Cumulative probability threshold for nucleus filtering

Outputs

Name | Type | Description
__call__ returns | torch.Tensor | Probability distribution, shape (..., vocab_size), sums to 1 along the last dim
sample returns | torch.Tensor | Selected token ID(s), shape (..., 1). Multinomial sampling returns this shape directly; greedy takes the argmax, of shape (...,), and unsqueezes the last dim.
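The shapes in the contract above can be checked with plain torch calls, without the repository's classes. This is a small sketch: the batch size, vocabulary size, and temperature are arbitrary placeholder values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 10)              # (batch, vocab_size)
probs = F.softmax(logits / 0.7, dim=-1)  # each row sums to 1

greedy = torch.argmax(probs, dim=-1)     # shape (2,)
greedy = greedy.unsqueeze(-1)            # shape (2, 1)

sampled = torch.multinomial(probs, num_samples=1)  # shape (2, 1)
```

Note that torch.multinomial accepts only 1-D or 2-D input, so logits with additional leading dimensions would need to be flattened to (N, vocab_size) before multinomial sampling.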

Usage Examples

Greedy Decoding

from utils.logits_processor import GreedyProcessor

# logits: raw model output of shape (..., vocab_size)
processor = GreedyProcessor()    # temperature=1 by default
probs = processor(logits)        # softmax(logits / 1)
token = processor.sample(probs)  # argmax over the last dimension

Nucleus Sampling

from utils.logits_processor import NucleusProcessor

processor = NucleusProcessor(temperature=0.7, top_p=0.9)
probs = processor(logits)        # nucleus (top-p) filter, then softmax(filtered / 0.7)
token = processor.sample(probs)  # multinomial sampling

Top-K + Nucleus Combined

from utils.logits_processor import TopKNucleusProcessor

processor = TopKNucleusProcessor(temperature=0.8, top_k=50, top_p=0.95)
probs = processor(logits)        # top-k filter, then nucleus filter, then softmax(filtered / 0.8)
token = processor.sample(probs)  # multinomial sampling

Related Pages

Implements Principle

Requires Environment
