Implementation:Romsto Speculative Decoding Max Fn

From Leeroopedia
Knowledge Sources
Domains Probability_Theory, Inference_Optimization
Last Updated 2026-02-14 04:30 GMT

Overview

Concrete tool for computing the normalized positive-part residual distribution used in speculative decoding rejection sampling.

Description

The max_fn function computes norm(max(0, x)), where x is typically the element-wise difference between the target and drafter probability distributions (p - q). The result is the adjusted distribution from which a replacement token is sampled when a speculative draft is rejected. The function clamps all negative values to zero and then normalizes what remains so that it sums to 1 along the last dimension, giving a valid probability distribution.
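A minimal sketch of the computation described above, assuming the normalization is a plain sum over the last dimension. The name max_fn_sketch is hypothetical and the body is an illustration, not the repository's source.

```python
import torch

def max_fn_sketch(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation of norm(max(0, x)), for illustration only."""
    # Keep the positive part; negative entries become zero.
    positive = torch.clamp(x, min=0.0)
    # Normalize along the vocabulary (last) dimension so each row sums to 1.
    total = positive.sum(dim=-1, keepdim=True)
    return positive / total

p = torch.tensor([0.5, 0.25, 0.25])   # target distribution
q = torch.tensor([0.25, 0.5, 0.25])   # drafter distribution
residual = max_fn_sketch(p - q)       # all mass lands on index 0
```

In this sketch, an input with no positive entry (for example p equal to q) has a zero sum, so the division produces NaN. A caller that can hit that case may want a guard before sampling.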

Usage

Import this function when implementing the rejection step of speculative decoding. It is called internally by speculative_generate when a draft token fails the acceptance test and skip_sample_adjustment is False. It is not needed for NASD (n-gram assisted speculative decoding), which uses greedy matching instead of rejection sampling.
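The rejection step mentioned above can be sketched as follows. This is an illustration under stated assumptions, not the body of speculative_generate: the helper names residual and accept_or_resample are hypothetical, residual stands in for max_fn, and the acceptance test is assumed to be the usual min(1, p/q) rule for a single drafted position.

```python
import torch

def residual(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for max_fn: norm(max(0, x)) along the last dimension.
    pos = torch.clamp(x, min=0.0)
    return pos / pos.sum(dim=-1, keepdim=True)

def accept_or_resample(p, q, draft_token, generator=None):
    """Return (token, accepted) for one drafted position (illustrative helper)."""
    # Accept the draft with probability min(1, p[token] / q[token]).
    r = torch.rand(1, generator=generator).item()
    if r < torch.clamp(p[draft_token] / q[draft_token], max=1.0):
        return draft_token, True
    # Rejected: draw a replacement from the residual distribution.
    new_token = torch.multinomial(residual(p - q), num_samples=1, generator=generator)
    return new_token.item(), False

p = torch.tensor([0.4, 0.3, 0.2, 0.1])  # target distribution
q = torch.tensor([0.1, 0.5, 0.3, 0.1])  # drafter distribution
token, accepted = accept_or_resample(p, q, draft_token=1)
```

With these values, token 1 is accepted with probability 0.6. On rejection the replacement is always token 0, the only index where p exceeds q.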

Code Reference

Source Location

Signature

def max_fn(x: torch.Tensor) -> torch.Tensor:
    """
    Max function.
        x: input tensor.
    Returns:
        tensor norm(max(0, x)).
    """

Import

from sampling.speculative_decoding import max_fn

I/O Contract

Inputs

Name Type Required Description
x torch.Tensor Yes Element-wise difference of the target and drafter distributions (p - q); entries may be negative. Shape (..., vocab_size)

Outputs

Name Type Description
result torch.Tensor Normalized positive-part distribution, same shape as input. All negative values clamped to 0, then L1-normalized along the last dimension.
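The contract above can be checked on a batched input. The snippet below is a sketch that uses a hypothetical stand-in (residual) rather than the library function, and random distributions built with softmax purely for illustration.

```python
import torch

def residual(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for max_fn: norm(max(0, x)) along the last dimension.
    pos = torch.clamp(x, min=0.0)
    return pos / pos.sum(dim=-1, keepdim=True)

torch.manual_seed(0)
# Two batches, three positions, vocabulary of eight tokens.
p = torch.softmax(torch.randn(2, 3, 8), dim=-1)  # target distributions
q = torch.softmax(torch.randn(2, 3, 8), dim=-1)  # drafter distributions

out = residual(p - q)
same_shape = out.shape == p.shape                  # shape is preserved
non_negative = bool((out >= 0).all())              # no negative probabilities
row_sums = out.sum(dim=-1)                         # each row should sum to 1
support_matches = bool(((out > 0) == (p > q)).all())  # mass only where p > q
```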

Usage Examples

Direct Usage

import torch
from sampling.speculative_decoding import max_fn

# Suppose p and q are probability distributions over vocabulary
p = torch.tensor([0.4, 0.3, 0.2, 0.1])  # target distribution
q = torch.tensor([0.1, 0.5, 0.3, 0.1])  # drafter distribution

# Compute adjusted distribution for rejection sampling
adjusted = max_fn(p - q)
# Result: norm(max(0, [0.3, -0.2, -0.1, 0.0])) = norm([0.3, 0, 0, 0]) = [1.0, 0, 0, 0]
print(adjusted)  # tensor([1., 0., 0., 0.])

Within Speculative Decoding Context

# Called internally by speculative_generate when the draft token at position n is rejected:
# p[..., n, :] is the target distribution at the rejected position
# q[0, n, :] is the drafter distribution at the rejected position
p_p = max_fn(p[..., n, :] - q[0, n, :])
x = logits_processor.sample(p_p)  # sample replacement token
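As a sanity check on why the replacement token is drawn from norm(max(0, p - q)), the marginal distribution of the emitted token can be computed exactly for a small example. This sketch assumes the usual acceptance probability min(1, p/q) and writes the residual inline instead of calling max_fn.

```python
import torch

p = torch.tensor([0.4, 0.3, 0.2, 0.1], dtype=torch.float64)  # target distribution
q = torch.tensor([0.1, 0.5, 0.3, 0.1], dtype=torch.float64)  # drafter distribution

# Probability that token i is drafted and accepted:
# q_i * min(1, p_i / q_i) = min(p_i, q_i).
accepted_mass = torch.minimum(p, q)
# Total probability that the draft is rejected.
reject_prob = 1.0 - accepted_mass.sum()
# Residual distribution norm(max(0, p - q)), written inline.
positive = torch.clamp(p - q, min=0.0)
residual = positive / positive.sum()
# Marginal distribution of the token that is finally emitted.
emitted = accepted_mass + reject_prob * residual
```

Here emitted matches p up to floating-point error, which is the property that makes the rejection step preserve the target model's distribution.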

Related Pages

Implements Principle
