Principle:Romsto Speculative Decoding Rejection Sampling Adjustment
| Knowledge Sources | |
|---|---|
| Domains | Probability_Theory, Inference_Optimization, Sampling |
| Last Updated | 2026-02-14 04:30 GMT |
Overview
A distribution correction mechanism that normalizes the positive part of the difference between target and drafter probability distributions to preserve the exact target distribution upon draft rejection.
Description
Rejection Sampling Adjustment (also called the max function or residual distribution) is the mathematical correction applied when a speculative decoding draft token is rejected. When the drafter model proposes a token that fails the acceptance test (i.e., a uniform random sample exceeds the probability ratio p/q), the replacement token must be sampled from a carefully constructed distribution rather than simply from the target distribution.
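The correction ensures that the marginal distribution over tokens at any position exactly matches what the target model would produce under standard autoregressive generation. Without this correction (for example, if the replacement were drawn directly from the target distribution), the output would be biased: tokens to which the drafter assigns at least as much probability as the target would be over-represented, because they already receive their full target probability through accepted drafts.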
The adjusted distribution is computed as: normalize the element-wise maximum of zero and (p - q), where p is the target distribution and q is the drafter distribution at the rejected position.
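The rejection step described above can be sketched for a single position as follows. This is a minimal NumPy illustration, not the repository's code: the function name `verify_draft_token`, its signature, and the use of 1-D probability arrays are assumptions made for the example.

```python
import numpy as np

def verify_draft_token(p, q, draft_token, rng):
    """Accept or replace one draft token (hypothetical helper).

    p, q: 1-D arrays with the target and drafter distributions at one position.
    draft_token: index proposed by the drafter (assumed sampled from q, so q > 0 there).
    Returns (token, accepted).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Acceptance test: keep the draft with probability min(1, p/q).
    accept_prob = min(1.0, p[draft_token] / q[draft_token])
    if rng.random() < accept_prob:
        return draft_token, True
    # Rejected: resample from the normalized positive part of (p - q).
    # A rejection implies p < q at the draft token, so the residual sum is > 0.
    residual = np.maximum(p - q, 0.0)
    residual /= residual.sum()
    return int(rng.choice(len(p), p=residual)), False
```

With `rng = np.random.default_rng()`, a call such as `verify_draft_token(p, q, draft_token, rng)` returns the token to emit and whether the draft was kept.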
Usage
Use this principle whenever implementing the rejection step of speculative decoding. It is essential for maintaining the theoretical guarantee that speculative decoding produces an identical output distribution to the target model. Skipping this adjustment (as allowed by the skip_sample_adjustment flag) sacrifices distributional correctness for slightly simpler computation.
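To make the cost of skipping the adjustment concrete, the sketch below computes the exact one-step output distribution with and without it. It assumes that skipping means resampling directly from the target distribution, which this page does not spell out. The function `output_distribution` and the example values are hypothetical.

```python
import numpy as np

def output_distribution(p, q, adjust=True):
    """Exact output distribution of one draft-then-verify step (illustrative)."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    accepted = np.minimum(p, q)        # q(x) * min(1, p(x)/q(x))
    residual = np.maximum(p - q, 0.0)  # positive part of (p - q)
    reject_mass = residual.sum()       # equals 1 - sum(min(p, q))
    if reject_mass == 0.0:
        return accepted                # p == q: drafts are never rejected
    # Assumption: without the adjustment, the replacement is drawn from p itself.
    replacement = residual / reject_mass if adjust else p
    return accepted + reject_mass * replacement

p = np.array([0.8, 0.2])  # target
q = np.array([0.2, 0.8])  # drafter
with_adjustment = output_distribution(p, q, adjust=True)      # matches p
without_adjustment = output_distribution(p, q, adjust=False)  # approx. [0.68, 0.32]
```

In this example the unadjusted variant shifts probability toward the second token, the one the drafter favors, while the adjusted variant reproduces the target distribution.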
Theoretical Basis
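When a draft token at position n is rejected, the replacement token is sampled from the adjusted distribution:
$$p'_n(x) = \frac{\max\bigl(0,\; p_n(x) - q_n(x)\bigr)}{\sum_{x'} \max\bigl(0,\; p_n(x') - q_n(x')\bigr)}$$
Where:
- $p_n(x)$ is the target model's probability for token x at position n
- $q_n(x)$ is the drafter model's probability for token x at position n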
Pseudo-code:
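# Abstract rejection sampling adjustment (NumPy sketch)
import numpy as np

def adjusted_distribution(p, q):
    """Compute the adjusted distribution for rejection sampling."""
    diff = np.asarray(p) - np.asarray(q)
    positive_part = np.maximum(diff, 0.0)  # element-wise max(0, p - q)
    # Normalize; the sum is positive whenever a rejection can occur (p != q).
    return positive_part / positive_part.sum()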
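Correctness proof sketch: The probability that the drafter proposes token x and it is accepted is $q_n(x) \cdot \min(1, p_n(x)/q_n(x)) = \min(p_n(x), q_n(x))$. The probability of rejecting a draft and then sampling x from the adjusted distribution fills in the remaining probability mass, $\max(0, p_n(x) - q_n(x))$. Together, the total probability of outputting x equals $\min(p_n(x), q_n(x)) + \max(0, p_n(x) - q_n(x)) = p_n(x)$.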
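The sketch can be written out step by step. The derivation below is a standard argument, given here for a single position under the assumption that the draft token is sampled from $q_n$ and accepted with probability $\min(1, p_n(x)/q_n(x))$. The symbol $\beta$ (the overall acceptance probability) is introduced for this derivation only.

```latex
\begin{align*}
\beta &= \sum_{x'} q_n(x') \min\!\left(1, \frac{p_n(x')}{q_n(x')}\right)
       = \sum_{x'} \min\bigl(p_n(x'), q_n(x')\bigr) \\
1 - \beta &= \sum_{x'} \Bigl(p_n(x') - \min\bigl(p_n(x'), q_n(x')\bigr)\Bigr)
           = \sum_{x'} \max\bigl(0,\, p_n(x') - q_n(x')\bigr) \\
P(\text{output} = x)
  &= \min\bigl(p_n(x), q_n(x)\bigr)
     + (1 - \beta)\,
       \frac{\max\bigl(0,\, p_n(x) - q_n(x)\bigr)}
            {\sum_{x'} \max\bigl(0,\, p_n(x') - q_n(x')\bigr)} \\
  &= \min\bigl(p_n(x), q_n(x)\bigr) + \max\bigl(0,\, p_n(x) - q_n(x)\bigr)
   = p_n(x)
\end{align*}
```

The key step is the second line: the rejection probability $1 - \beta$ is exactly the normalizing constant of the adjusted distribution, so the two cancel.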