Implementation:Romsto Speculative Decoding LogitsProcessor Hierarchy
| Knowledge Sources | |
|---|---|
| Domains | NLP, Sampling |
| Last Updated | 2026-02-14 04:30 GMT |
Overview
Concrete tool for converting model logits to probability distributions and sampling tokens, providing five interchangeable sampling strategies via the Strategy pattern.
Description
The LogitsProcessor hierarchy implements sampling strategies as an abstract base class (ABC) with concrete subclasses. The ABC defines the interface: __call__(logits) runs the strategy's filtering step (_process, which leaves the logits unchanged for strategies that do not filter) and then applies a temperature-scaled softmax; sample(probs) selects a token from the resulting distribution.
The class hierarchy is:
- LogitsProcessor (ABC): defines __call__, _process, sample
  - GreedyProcessor: argmax sampling, no filtering
  - MultinomialProcessor: torch.multinomial sampling, no filtering
    - TopKProcessor: top-k filtering + multinomial
    - NucleusProcessor: top-p filtering + multinomial
    - TopKNucleusProcessor: top-k then top-p filtering + multinomial
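The split between __call__, _process, and sample can be sketched in plain Python, with lists standing in for tensors. This is an illustration of the pattern only, not the repository code: the Toy* class names are hypothetical, and masking excluded logits with negative infinity in the top-k filter is an assumed convention.

```python
import abc
import math
import random


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


class ToyProcessor(abc.ABC):
    """Hypothetical stand-in mirroring the LogitsProcessor interface."""

    def __init__(self, temperature):
        self.temperature = temperature

    def __call__(self, logits):
        # Filter first, then apply the temperature-scaled softmax.
        filtered = self._process(logits)
        return softmax([x / self.temperature for x in filtered])

    @abc.abstractmethod
    def _process(self, logits): ...

    @abc.abstractmethod
    def sample(self, probs): ...


class ToyGreedy(ToyProcessor):
    def _process(self, logits):
        return logits  # no filtering

    def sample(self, probs):
        return max(range(len(probs)), key=probs.__getitem__)  # argmax


class ToyMultinomial(ToyProcessor):
    def _process(self, logits):
        return logits  # no filtering

    def sample(self, probs):
        # Draw one index with probability proportional to probs.
        return random.choices(range(len(probs)), weights=probs, k=1)[0]


class ToyTopK(ToyMultinomial):
    def __init__(self, temperature, top_k):
        super().__init__(temperature)
        self.top_k = top_k

    def _process(self, logits):
        # Assumed convention: mask everything below the k-th largest logit.
        kth = sorted(logits, reverse=True)[min(self.top_k, len(logits)) - 1]
        return [x if x >= kth else -math.inf for x in logits]


logits = [2.0, 1.0, 0.5, -1.0]  # made-up values
greedy = ToyGreedy(temperature=1.0)
top2 = ToyTopK(temperature=0.8, top_k=2)
greedy_token = greedy.sample(greedy(logits))
top2_probs = top2(logits)
top2_token = top2.sample(top2_probs)
```

Because callers only touch __call__ and sample, swapping ToyGreedy for ToyTopK changes the sampling behaviour without changing the calling code, which is the point of the Strategy pattern here.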
Usage
Import the desired processor class when configuring sampling for any generation function in this repository (speculative_generate, ngram_assisted_speculative_generate, autoregressive_generate). The same processor instance is used for both probability computation and token sampling throughout a generation run.
Code Reference
Source Location
- Repository: Speculative-Decoding
- File: utils/logits_processor.py
- Lines: L7-103
Signature
class LogitsProcessor(abc.ABC):
    """Logits processors for sampling."""

    def __init__(self, temperature: float):
        self.temperature = temperature

    def __call__(self, logits: Tensor) -> Tensor:
        proc = self._process(logits)
        return F.softmax(proc / self.temperature, dim=-1)

    @abc.abstractmethod
    def _process(self, logits: Tensor) -> Tensor:
        pass

    @abc.abstractmethod
    def sample(self, probs: Tensor) -> Tensor:
        pass

class GreedyProcessor(LogitsProcessor):
    def __init__(self, temperature: float = 1): ...
    def _process(self, logits: Tensor) -> Tensor: ...
    def sample(self, probs: Tensor) -> Tensor: ...  # argmax

class MultinomialProcessor(LogitsProcessor):
    def __init__(self, temperature: float): ...
    def sample(self, probs: Tensor) -> Tensor: ...  # multinomial

class TopKProcessor(MultinomialProcessor):
    def __init__(self, temperature: float, top_k: int): ...
    def _process(self, logits: Tensor) -> Tensor: ...  # top-k filter

class NucleusProcessor(MultinomialProcessor):
    def __init__(self, temperature: float, top_p: float): ...
    def _process(self, logits: Tensor) -> Tensor: ...  # nucleus filter

class TopKNucleusProcessor(MultinomialProcessor):
    def __init__(self, temperature: float, top_k: int, top_p: float): ...
    def _process(self, logits: Tensor) -> Tensor: ...  # top-k then nucleus
Import
from utils.logits_processor import (
    LogitsProcessor,
    GreedyProcessor,
    MultinomialProcessor,
    TopKProcessor,
    NucleusProcessor,
    TopKNucleusProcessor,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| logits | torch.Tensor | Yes | Raw model output logits, shape (..., vocab_size) |
| temperature | float | Yes (constructor); optional for GreedyProcessor | Softmax temperature; logits are divided by it before the softmax. Higher values give more diverse samples. GreedyProcessor defaults to 1. |
| top_k | int | TopK/TopKNucleus only | Number of top tokens to retain before sampling |
| top_p | float | Nucleus/TopKNucleus only | Cumulative probability threshold for nucleus filtering |
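As a worked illustration of the temperature row, the plain-Python sketch below evaluates softmax(logits / temperature) at three temperatures. The logit values and the entropy helper are made up for the example; only the formula comes from the __call__ signature.

```python
import math


def softmax_with_temperature(logits, temperature):
    """softmax(logits / temperature), computed stably."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats; a rough measure of sample diversity."""
    return -sum(p * math.log(p) for p in probs if p > 0)


logits = [2.0, 1.0, 0.0]  # made-up values
sharp = softmax_with_temperature(logits, 0.5)
plain = softmax_with_temperature(logits, 1.0)
flat = softmax_with_temperature(logits, 2.0)

for name, probs in [("T=0.5", sharp), ("T=1.0", plain), ("T=2.0", flat)]:
    print(name, [round(p, 3) for p in probs], "entropy:", round(entropy(probs), 3))
```

Entropy grows with temperature, while the most likely token stays the same at every positive temperature. For that reason the temperature has no effect on which token greedy (argmax) sampling selects.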
Outputs
| Name | Type | Description |
|---|---|---|
| __call__ returns | torch.Tensor | Probability distribution, shape (..., vocab_size), sums to 1 along last dim |
| sample returns | torch.Tensor | Selected token ID(s) with a trailing dimension of size 1: multinomial draws one sample per distribution, greedy takes the argmax over the last dimension and unsqueezes it |
Usage Examples
Greedy Decoding
import torch
from utils.logits_processor import GreedyProcessor
logits = torch.randn(1, 100)  # stand-in for model logits, shape (batch, vocab_size)
processor = GreedyProcessor()  # temperature=1 by default
probs = processor(logits)  # softmax(logits / 1.0)
token = processor.sample(probs)  # argmax
Nucleus Sampling
import torch
from utils.logits_processor import NucleusProcessor
logits = torch.randn(1, 100)  # stand-in for model logits, shape (batch, vocab_size)
processor = NucleusProcessor(temperature=0.7, top_p=0.9)
probs = processor(logits)  # filter to top-p, then softmax
token = processor.sample(probs)  # multinomial sampling
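What "filter to top-p" means can be sketched in plain Python. The version below is one common formulation: keep the smallest set of highest-probability tokens whose cumulative mass reaches top_p, always keeping the most likely token. It assumes the cumulative mass is computed from a softmax of the unscaled logits, since __call__ applies the temperature only after _process, and it assumes excluded logits are masked with negative infinity. The repository's exact thresholding and masking value may differ, and nucleus_filter is a hypothetical name.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def nucleus_filter(logits, top_p):
    """Mask logits outside the top-p nucleus with -inf (assumed convention)."""
    probs = softmax(logits)
    # Token indices ordered from most to least probable.
    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)  # the most likely token is always kept
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return [x if i in keep else -math.inf for i, x in enumerate(logits)]


logits = [2.0, 1.0, 0.0, -1.0]  # made-up values
filtered = nucleus_filter(logits, top_p=0.9)
# Temperature is applied after filtering, mirroring __call__.
probs = softmax([x / 0.7 for x in filtered])
```

With these values the first two tokens cover roughly 0.88 of the mass, so a third token is needed to reach 0.9 and only the last token is masked; after the softmax its probability is exactly zero.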
Top-K + Nucleus Combined
import torch
from utils.logits_processor import TopKNucleusProcessor
logits = torch.randn(1, 100)  # stand-in for model logits, shape (batch, vocab_size)
processor = TopKNucleusProcessor(temperature=0.8, top_k=50, top_p=0.95)
probs = processor(logits)  # top-k filter, then nucleus filter, then softmax
token = processor.sample(probs)  # multinomial sampling
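The order of the two filters matters: once top-k has removed tokens, the nucleus is computed over the renormalised survivors, so it can be smaller than the nucleus of the full distribution. The plain-Python sketch below illustrates this under stated assumptions: excluded logits are masked with negative infinity, the nucleus keeps the smallest high-probability set whose mass reaches top_p, and top_k_filter and top_p_filter are hypothetical names rather than repository functions.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def top_k_filter(logits, top_k):
    """Keep the top_k largest logits (ties included); mask the rest."""
    kth = sorted(logits, reverse=True)[min(top_k, len(logits)) - 1]
    return [x if x >= kth else -math.inf for x in logits]


def top_p_filter(logits, top_p):
    """Keep the smallest high-probability set whose mass reaches top_p."""
    probs = softmax(logits)
    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return [x if i in keep else -math.inf for i, x in enumerate(logits)]


logits = [2.0, 1.0, 0.0, -1.0]  # made-up values
nucleus_only = top_p_filter(logits, top_p=0.9)
combined = top_p_filter(top_k_filter(logits, top_k=3), top_p=0.9)

kept_nucleus_only = sum(math.isfinite(x) for x in nucleus_only)
kept_combined = sum(math.isfinite(x) for x in combined)
```

Here the nucleus alone keeps three tokens. After top-k drops the fourth token, the first two survivors hold about 0.91 of the renormalised mass, so the combined filter keeps only two.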