Implementation:Romsto Speculative Decoding Max Fn
| Knowledge Sources | |
|---|---|
| Domains | Probability_Theory, Inference_Optimization |
| Last Updated | 2026-02-14 04:30 GMT |
Overview
Concrete tool for computing the normalized positive-part residual distribution used in speculative decoding rejection sampling.
Description
The max_fn function computes norm(max(0, x)), where x is typically the difference between the target and drafter probability distributions (p - q). It clamps all negative values to zero and then normalizes the result so that it forms a valid probability distribution. The output is the adjusted distribution from which a replacement token is sampled when a speculative draft is rejected.
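The computation described above can be sketched as follows. This is a minimal re-implementation written from the docstring (norm(max(0, x))) alone, not a copy of the repository source; the name max_fn_sketch and the choice of torch.clamp are assumptions made for illustration.

```python
import torch


def max_fn_sketch(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation of norm(max(0, x))."""
    # Keep only the positive part of x; negative entries become zero.
    positive = torch.clamp(x, min=0.0)
    # L1-normalize along the last (vocabulary) dimension.
    total = positive.sum(dim=-1, keepdim=True)
    return positive / total


p = torch.tensor([0.4, 0.3, 0.2, 0.1])  # target distribution
q = torch.tensor([0.1, 0.5, 0.3, 0.1])  # drafter distribution
residual = max_fn_sketch(p - q)
print(residual)
```

In this sketch, an input with no positive entry would give a zero sum and therefore NaN values after the division. In the rejection step this should not arise in exact arithmetic: a rejection implies q exceeds p for the drafted token, and since both distributions sum to 1, p must exceed q for at least one other token.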
Usage
Import this function when implementing the rejection step of speculative decoding. It is called internally by speculative_generate when a draft token fails the acceptance test and skip_sample_adjustment is False. It is not needed for NASD (n-gram assisted speculative decoding), which uses greedy matching instead of rejection sampling.
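To show where the function fits, the rejection step can be sketched for a single position. This is an illustrative sketch under stated assumptions, not the code of speculative_generate: the helper name accept_or_resample, the use of torch.multinomial for sampling, and the inline normalization standing in for max_fn are all hypothetical.

```python
import torch


def accept_or_resample(p, q, draft_token, generator=None):
    """Hypothetical single-position rejection step.

    p, q: 1-D probability vectors over the vocabulary (target, drafter).
    Returns (token, accepted).
    """
    # Accept the drafted token with probability min(1, p[token] / q[token]).
    r = torch.rand(1, generator=generator).item()
    ratio = (p[draft_token] / q[draft_token]).item()
    if r < min(1.0, ratio):
        return draft_token, True
    # Rejected: sample a replacement from norm(max(0, p - q)).
    residual = torch.clamp(p - q, min=0.0)
    residual = residual / residual.sum()
    token = torch.multinomial(residual, num_samples=1, generator=generator)
    return int(token.item()), False


p = torch.tensor([0.4, 0.3, 0.2, 0.1])  # target distribution
q = torch.tensor([0.1, 0.5, 0.3, 0.1])  # drafter distribution
token, accepted = accept_or_resample(p, q, draft_token=1)
print(token, accepted)
```

With these example distributions, the only token with p greater than q is token 0, so any rejection of token 1 resamples token 0.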
Code Reference
Source Location
- Repository: Speculative-Decoding
- File: sampling/speculative_decoding.py
- Lines: L10-19
Signature
def max_fn(x: torch.Tensor) -> torch.Tensor:
    """
    Max function.
        x: input tensor.
    Returns:
        tensor norm(max(0, x)).
    """
Import
from sampling.speculative_decoding import max_fn
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Difference distribution (p - q), shape (..., vocab_size) |
Outputs
| Name | Type | Description |
|---|---|---|
| result | torch.Tensor | Normalized positive-part distribution, same shape as input. All negative values clamped to 0, then L1-normalized along the last dimension. |
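The shape and normalization properties in the contract above can be checked on a batched input. The snippet below is a hedged illustration: it does not call max_fn itself but uses a stand-in expression (ReLU followed by torch.nn.functional.normalize with p=1), which is an assumption about equivalent behavior for inputs that have at least one positive entry per row. The batch dimensions are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical batch: 2 sequences, 3 positions, vocabulary of 5 tokens.
target = torch.softmax(torch.randn(2, 3, 5), dim=-1)   # target distributions
drafter = torch.softmax(torch.randn(2, 3, 5), dim=-1)  # drafter distributions

# Stand-in for max_fn (assumption): clamp negatives to zero,
# then L1-normalize along the last dimension.
result = F.normalize(torch.relu(target - drafter), p=1, dim=-1)

print(result.shape)        # same shape as the input
print(result.sum(dim=-1))  # each row sums to approximately 1
```

Unlike a plain division by the sum, F.normalize guards the denominator with a small epsilon, so an all-zero row would come out as zeros rather than NaN. That difference belongs to the stand-in, not necessarily to max_fn.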
Usage Examples
Direct Usage
import torch
from sampling.speculative_decoding import max_fn
# Suppose p and q are probability distributions over vocabulary
p = torch.tensor([0.4, 0.3, 0.2, 0.1]) # target distribution
q = torch.tensor([0.1, 0.5, 0.3, 0.1]) # drafter distribution
# Compute adjusted distribution for rejection sampling
adjusted = max_fn(p - q)
# Result: norm(max(0, [0.3, -0.2, -0.1, 0.0])) = norm([0.3, 0, 0, 0]) = [1.0, 0, 0, 0]
print(adjusted) # tensor([1., 0., 0., 0.])
Within Speculative Decoding Context
# Called internally by speculative_generate when draft n is rejected:
# p[..., n, :] is the target distribution at rejected position
# q[0, n, :] is the drafter distribution at rejected position
p_p = max_fn(p[..., n, :] - q[0, n, :])
x = logits_processor.sample(p_p) # sample replacement token