Implementation:Romsto Speculative Decoding Speculative Generate

From Leeroopedia
Knowledge Sources
Domains NLP, Inference_Optimization
Last Updated 2026-02-14 04:30 GMT

Overview

Concrete tool for accelerating LLM inference via speculative decoding with rejection sampling, provided by this repository.

Description

The speculative_generate function implements the full speculative decoding loop described by Leviathan et al. (2022) and Chen et al. (2023). It takes a small drafter model and a large target model. In each round it generates gamma draft tokens from the drafter, then scores all of them with the target in a single forward pass. Each draft token is accepted or rejected by rejection sampling on the probability ratio p/q, where p is the target's probability for that token and q is the drafter's. When a draft is rejected, the max_fn helper adjusts the distribution and a correction token is sampled from norm(max(0, p - q)).
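
The accept/reject rule above can be illustrated with a minimal pure-Python sketch. It is not the repository's implementation: it works on plain probability lists for a single draft token instead of batched tensors, the name verify_draft is hypothetical, and the fallback used when p and q are identical is an assumption made only to keep the sketch well defined.

```python
import random
from typing import List, Sequence, Tuple


def max_fn(p: Sequence[float], q: Sequence[float]) -> List[float]:
    """Return norm(max(0, p - q)) for target probs p and drafter probs q."""
    diff = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(diff)
    if total == 0.0:
        # Assumption for this sketch: if p == q everywhere, fall back to p.
        return list(p)
    return [d / total for d in diff]


def verify_draft(
    token: int, p: Sequence[float], q: Sequence[float], rng: random.Random
) -> Tuple[int, bool]:
    """Hypothetical helper: accept `token` with probability min(1, p/q).

    On rejection, a correction token is sampled from max_fn(p, q).
    Returns (token_to_emit, was_accepted).
    """
    accept_prob = min(1.0, p[token] / q[token])
    if rng.random() < accept_prob:
        return token, True
    adjusted = max_fn(p, q)
    corrected = rng.choices(range(len(p)), weights=adjusted, k=1)[0]
    return corrected, False


rng = random.Random(0)
p = [0.5, 0.3, 0.2]  # target distribution (illustrative values)
q = [0.2, 0.3, 0.5]  # drafter distribution (illustrative values)

# Token 0 has p/q = 2.5, so it is always accepted.
print(verify_draft(0, p, q, rng))
# Token 2 has p/q = 0.4, so it is accepted about 40% of the time;
# otherwise the correction token comes from norm(max(0, p - q)).
print(verify_draft(2, p, q, rng))
```

With these illustrative values, only token 0 has p greater than q, so every rejection is corrected to token 0.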

The function supports optional KV-cache for faster sequential decoding, configurable sampling strategies via the LogitsProcessor interface, and a debug mode for visualizing accepted/rejected drafts.

Usage

Import this function when you need to generate text from a large language model and have a compatible smaller drafter model available. Both models must be decoder-only and share the same vocabulary size. This is the primary generation function for the standard speculative decoding workflow.
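
The compatibility requirements above can be checked before generation with a small guard. The sketch below is an assumption, not part of the repository: check_compatible is a hypothetical helper, and it relies on the models exposing config.vocab_size and config.is_encoder_decoder, as Hugging Face transformers models do. Stand-in objects with illustrative values replace real models so that it runs without downloading anything.

```python
from types import SimpleNamespace


def check_compatible(drafter, target) -> None:
    """Hypothetical pre-flight check: decoder-only models, equal vocab size."""
    for name, model in (("drafter", drafter), ("target", target)):
        if getattr(model.config, "is_encoder_decoder", False):
            raise ValueError(f"{name} must be a decoder-only model")
    d_vocab = drafter.config.vocab_size
    t_vocab = target.config.vocab_size
    if d_vocab != t_vocab:
        raise ValueError(
            f"vocabulary size mismatch: drafter={d_vocab}, target={t_vocab}"
        )


# Stand-in objects (illustrative values, not real model configs).
small = SimpleNamespace(
    config=SimpleNamespace(vocab_size=32000, is_encoder_decoder=False)
)
large = SimpleNamespace(
    config=SimpleNamespace(vocab_size=32000, is_encoder_decoder=False)
)

check_compatible(small, large)  # passes silently when the models match
```

Equal vocabulary sizes are a necessary condition only. The check does not confirm that both models map token IDs to the same strings, so it is still worth confirming that they were trained with the same tokenizer.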

Code Reference

Source Location

Signature

@torch.no_grad()
def speculative_generate(
    inputs: List[int],
    drafter: Module,
    target: Module,
    tokenizer = None,
    gamma: int = 5,
    logits_processor: LogitsProcessor = GreedyProcessor(),
    max_gen_len: int = 40,
    eos_tokens_id: int | List[int] = 1,
    pad_token_id: int = 0,
    use_cache: bool = False,
    skip_sample_adjustment: bool = False,
    first_target: bool = True,
    debug: bool = False,
) -> Tuple[List[int], float]:
    """
    Generate text sequence using the speculative decoding algorithm.

    Args:
        inputs (List[int]): input sequence of batch size 1.
        drafter (Module): drafter model.
        target (Module): target model.
        tokenizer: tokenizer (used for debugging).
        gamma (int): number of drafts generated by the drafter at each step.
        logits_processor (LogitsProcessor): logits processor for sampling.
        max_gen_len (int): maximum length of the generated sequence.
        eos_tokens_id (int or List[int]): end token id (could be multiple).
        pad_token_id (int): pad token id.
        use_cache (bool): whether to use cache.
        skip_sample_adjustment (bool): whether to skip the max_fn adjustment step.
        first_target (bool): whether to run target model before speculative loop.
        debug (bool): debug mode.

    Returns:
        List[int]: generated token sequence.
        float: acceptance rate (accepted_drafts / total_drafts).
    """

Import

from sampling import speculative_generate

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| inputs | List[int] | Yes | Tokenized input prompt (batch size 1) |
| drafter | torch.nn.Module | Yes | Small draft model (decoder-only, same vocab as target) |
| target | torch.nn.Module | Yes | Large target model (decoder-only) |
| tokenizer | PreTrainedTokenizer | No | Tokenizer for debug printing only |
| gamma | int | No | Number of draft tokens per speculative round (default: 5) |
| logits_processor | LogitsProcessor | No | Sampling strategy (default: GreedyProcessor()) |
| max_gen_len | int | No | Maximum new tokens to generate (default: 40) |
| eos_tokens_id | int or List[int] | No | End-of-sequence token ID(s) (default: 1) |
| pad_token_id | int | No | Padding token ID (default: 0) |
| use_cache | bool | No | Enable KV-cache for faster decoding (default: False) |
| skip_sample_adjustment | bool | No | Skip max_fn correction on rejection (default: False) |
| first_target | bool | No | Run target prefill before speculative loop (default: True) |
| debug | bool | No | Enable debug visualization of accepted/rejected tokens (default: False) |

Outputs

| Name | Type | Description |
|------|------|-------------|
| generated_ids | List[int] | Generated token IDs (excludes the prompt) |
| acceptance_rate | float | Ratio of accepted draft tokens to total speculated tokens |
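
To help interpret the returned acceptance rate, here is a minimal sketch. The ratio follows the docstring (accepted_drafts / total_drafts). The second function uses the closed-form estimate from Leviathan et al., which assumes each draft token is accepted independently with the same probability alpha. Both function names are hypothetical, and the estimate is a rough guide rather than something the function itself reports.

```python
def acceptance_rate(accepted_drafts: int, total_drafts: int) -> float:
    """Ratio of accepted draft tokens to all speculated tokens."""
    return accepted_drafts / total_drafts if total_drafts else 0.0


def expected_tokens_per_round(alpha: float, gamma: int) -> float:
    """Estimated tokens produced per target forward pass.

    Uses (1 - alpha^(gamma + 1)) / (1 - alpha), assuming i.i.d.
    acceptance with probability alpha for each draft token.
    """
    if alpha >= 1.0:
        # Every draft is accepted, plus one token from the target.
        return float(gamma + 1)
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)


rate = acceptance_rate(accepted_drafts=30, total_drafts=50)
print(f"acceptance rate: {rate:.2f}")
print(f"estimated tokens per round: {expected_tokens_per_round(rate, gamma=5):.2f}")
```

Under this assumption, a low acceptance rate means most rounds yield little more than the single token the target contributes, which suggests trying a smaller gamma or a drafter that is better aligned with the target.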

Usage Examples

Basic Speculative Generation

from transformers import AutoModelForCausalLM, AutoTokenizer
from sampling import speculative_generate
from utils.logits_processor import GreedyProcessor

# 1. Load target and drafter models
target = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct", device_map="cuda"
)
target.eval()

drafter = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", device_map="cuda"
)
drafter.eval()

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

# 2. Prepare input
prompt = "Explain quantum computing in simple terms."
chat = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
inputs = tokenizer(text, return_tensors="pt").input_ids[0].tolist()

# 3. Run speculative generation
output_ids, accept_rate = speculative_generate(
    inputs,
    drafter,
    target,
    tokenizer=tokenizer,
    gamma=5,
    logits_processor=GreedyProcessor(),
    max_gen_len=100,
    eos_tokens_id=[tokenizer.eos_token_id],
)

# 4. Decode output
output_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(f"Output: {output_text}")
print(f"Acceptance rate: {accept_rate:.3f}")

With KV-Cache and Nucleus Sampling

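# Reuses drafter, target, tokenizer, and inputs from the basic example above.
from sampling import speculative_generate
from utils.logits_processor import NucleusProcessor

output_ids, accept_rate = speculative_generate(
    inputs,
    drafter,
    target,
    gamma=4,  # draft tokens per speculative round
    logits_processor=NucleusProcessor(temperature=0.7, top_p=0.9),
    max_gen_len=50,
    eos_tokens_id=[tokenizer.eos_token_id],
    use_cache=True,  # enable the KV-cache
    first_target=True,  # run the target prefill before the speculative loop
)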

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
