Implementation:Romsto Speculative Decoding Speculative Generate Encoder Decoder

From Leeroopedia
Knowledge Sources
Domains NLP, Inference_Optimization, Encoder_Decoder
Last Updated 2026-02-14 05:00 GMT

Overview

Concrete tool for accelerating encoder-decoder model inference via speculative decoding with rejection sampling, using a smaller drafter encoder-decoder model.

Description

The speculative_generate_encoder_decoder function implements the speculative decoding algorithm adapted for encoder-decoder architectures (e.g., T5, BART). Both the drafter and target models receive the same encoder input (input_ids), and all drafting and verification occurs on the decoder side via decoder_input_ids.

The drafter generates gamma draft decoder tokens sequentially, storing the probability distribution at each position. The target model then verifies all drafts in a single forward pass. Acceptance uses rejection sampling: for each draft position, a uniform random value r in [0, 1) is compared against p(x)/q(x), the target probability of the drafted token x divided by its drafter probability. The first position where r > p/q triggers rejection, and that draft and every draft after it are discarded.
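The acceptance test described above can be sketched in pure Python on plain lists. This is a minimal illustration, not the repository's code: the function name count_accepted_drafts, the list-based distributions, and the explicit rng argument are all assumptions made for the example.

```python
import random
from typing import Sequence


def count_accepted_drafts(
    draft_tokens: Sequence[int],
    q_dists: Sequence[Sequence[float]],
    p_dists: Sequence[Sequence[float]],
    rng: random.Random,
) -> int:
    """Return how many leading draft tokens pass the rejection-sampling test.

    draft_tokens[i] is the token the drafter proposed at position i,
    q_dists[i] is the drafter distribution there, p_dists[i] the target's.
    (Hypothetical helper for illustration only.)
    """
    accepted = 0
    for token, q, p in zip(draft_tokens, q_dists, p_dists):
        r = rng.random()  # uniform in [0, 1)
        # q[token] is assumed > 0 because the drafter proposed this token.
        ratio = p[token] / q[token]
        if r > ratio:
            break  # first rejection: this draft and all later ones are dropped
        accepted += 1
    return accepted


# When drafter and target agree exactly, p/q == 1 and every draft is accepted.
same = [[0.1, 0.2, 0.7], [0.5, 0.25, 0.25], [0.3, 0.4, 0.3]]
n_accepted = count_accepted_drafts([2, 0, 1], same, same, random.Random(0))
```

Because r is always below 1, a draft whose target probability is at least its drafter probability is never rejected; rejection only becomes possible where the drafter is more confident than the target.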

When drafts are rejected and skip_sample_adjustment is False, the next token is sampled from max_fn(p - q) (the normalized positive part of the target-minus-drafter distribution difference) to preserve the target model's output distribution. KV-cache pruning is applied on rejection via prune_cache.
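The pruning step can be pictured as trimming the decoder self-attention entries for the discarded draft positions while leaving cross-attention entries alone, since the encoder input does not change between rounds. The sketch below assumes a made-up cache layout and a hypothetical prune_decoder_cache helper; it is not the repository's prune_cache, whose signature is not shown on this page.

```python
from typing import List, Tuple

# Hypothetical cache layout, for illustration only: one tuple per decoder
# layer holding (self_keys, self_values, cross_keys, cross_values), where
# each entry is a list with one item per sequence position.
LayerCache = Tuple[List[float], List[float], List[float], List[float]]


def prune_decoder_cache(cache: List[LayerCache], num_discard: int) -> List[LayerCache]:
    """Drop the last num_discard decoder positions from the self-attention cache.

    Cross-attention entries depend only on the encoder input, so this sketch
    returns them unchanged.
    """
    if num_discard <= 0:
        return cache
    pruned: List[LayerCache] = []
    for self_k, self_v, cross_k, cross_v in cache:
        keep = max(len(self_k) - num_discard, 0)  # positions that survive
        pruned.append((self_k[:keep], self_v[:keep], cross_k, cross_v))
    return pruned


# One layer, four cached decoder positions, two encoder positions.
toy_cache = [([1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 9.0], [9.0, 9.0])]
after_rejection = prune_decoder_cache(toy_cache, num_discard=2)
```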

The function also supports an optional first_target prefill step, which runs the target model once before the speculative loop to seed the decoder with a first token and warm the KV-cache.

A local max_fn helper is included in this file, identical to the one in speculative_decoding.py, computing norm(max(0, x)) for the adjusted distribution.
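As a rough illustration of norm(max(0, x)) and of how the adjusted distribution is sampled on rejection, here is a pure-Python sketch on plain lists. The repository's max_fn operates on torch.Tensor; the list representation, the sample_adjusted helper, and the zero-mass check are assumptions of this sketch.

```python
import random
from typing import List, Sequence


def max_fn(x: Sequence[float]) -> List[float]:
    """Return norm(max(0, x)): clamp negatives to zero, then rescale to sum to 1."""
    positive = [max(0.0, v) for v in x]
    total = sum(positive)
    if total == 0.0:
        # Choice made for this sketch: fail loudly instead of dividing by zero.
        raise ValueError("input has no positive mass to normalize")
    return [v / total for v in positive]


def sample_adjusted(p: Sequence[float], q: Sequence[float], rng: random.Random) -> int:
    """Sample a replacement token from max_fn(p - q) (hypothetical helper)."""
    diff = [pi - qi for pi, qi in zip(p, q)]
    adjusted = max_fn(diff)
    return rng.choices(range(len(adjusted)), weights=adjusted, k=1)[0]


# The target favors token 0 more than the drafter does, so all of the
# adjusted mass lands on token 0.
p = [0.5, 0.3, 0.2]  # target distribution
q = [0.2, 0.3, 0.5]  # drafter distribution
replacement = sample_adjusted(p, q, random.Random(0))
```

The adjusted distribution only keeps tokens that the target rates higher than the drafter, which is what compensates for the drafts that were accepted with probability p/q.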

Usage

Import this function when you need to accelerate generation from an encoder-decoder model and have a compatible, smaller encoder-decoder drafter. Both models must share the same vocabulary size. This is the encoder-decoder counterpart of speculative_generate, which handles decoder-only models. N-gram models are not supported in this variant.
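A pre-flight check for the shared-vocabulary requirement might look like the sketch below. It assumes both models expose config.vocab_size, as Hugging Face Transformers models do; check_shared_vocab is a hypothetical helper, not part of the repository, and the SimpleNamespace objects are stand-ins so the snippet runs without downloading weights.

```python
from types import SimpleNamespace


def check_shared_vocab(drafter, target) -> int:
    """Raise if the two models disagree on vocabulary size; return the size otherwise.

    Hypothetical helper: assumes each model has a `config.vocab_size` attribute.
    """
    drafter_vocab = drafter.config.vocab_size
    target_vocab = target.config.vocab_size
    if drafter_vocab != target_vocab:
        raise ValueError(
            f"vocabulary size mismatch: drafter={drafter_vocab}, target={target_vocab}"
        )
    return target_vocab


# Stand-in objects that mimic the `model.config.vocab_size` attribute.
fake_drafter = SimpleNamespace(config=SimpleNamespace(vocab_size=100))
fake_target = SimpleNamespace(config=SimpleNamespace(vocab_size=100))
vocab_size = check_shared_vocab(fake_drafter, fake_target)
```

Equal vocabulary sizes do not by themselves guarantee that both models map token IDs to the same strings, so pairing models from the same family with the same tokenizer is the safer choice.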

Code Reference

Source Location

Signature

def max_fn(x: torch.Tensor) -> torch.Tensor:
    """
    Max function.

    Args:
        x: input tensor.
    Returns:
        tensor norm(max(0, x)).
    """

@torch.no_grad()
def speculative_generate_encoder_decoder(
    inputs: List[int],
    drafter: Module,
    target: Module,
    tokenizer = None,
    gamma: int = 5,
    logits_processor: LogitsProcessor = GreedyProcessor(),
    max_gen_len: int = 40,
    eos_tokens_id: int | List[int] = 1,
    pad_token_id: int = 0,
    use_cache: bool = False,
    skip_sample_adjustment: bool = False,
    first_target: bool = True,
    debug: bool = False,
) -> Tuple[List[int], float]:
    """
    Generate text sequence using speculative decoding for encoder-decoder models.

    Args:
        inputs (List[int]): input sequence of batch size 1.
        drafter (Module): drafter encoder-decoder model.
        target (Module): target encoder-decoder model.
        tokenizer: tokenizer (used for debugging).
        gamma (int): number of drafts generated by the drafter at each step.
        logits_processor (LogitsProcessor): logits processor for sampling.
        max_gen_len (int): maximum length of the generated sequence.
        eos_tokens_id (int or List[int]): end token id (could be multiple).
        pad_token_id (int): pad token id.
        use_cache (bool): whether to use cache.
        skip_sample_adjustment (bool): skip the max_fn adjustment step.
        first_target (bool): run target model before the speculative loop.
        debug (bool): debug mode.

    Returns:
        List[int]: generated sequence.
        float: acceptance rate (accepted drafts / total drafts).
    """

Import

from sampling.codec_speculative_decoding import speculative_generate_encoder_decoder

I/O Contract

Inputs

Name | Type | Required | Description
inputs | List[int] | Yes | Tokenized encoder input sequence (batch size 1)
drafter | torch.nn.Module | Yes | Small encoder-decoder drafter model (same vocab as target)
target | torch.nn.Module | Yes | Large encoder-decoder target model
tokenizer | PreTrainedTokenizer | No | Tokenizer for debug printing only
gamma | int | No | Number of draft decoder tokens per speculative round (default: 5)
logits_processor | LogitsProcessor | No | Sampling strategy (default: GreedyProcessor())
max_gen_len | int | No | Maximum decoder tokens to generate (default: 40)
eos_tokens_id | int or List[int] | No | End-of-sequence token ID(s) (default: 1)
pad_token_id | int | No | Padding token ID (default: 0)
use_cache | bool | No | Enable KV-cache for cross and self-attention (default: False)
skip_sample_adjustment | bool | No | Skip max_fn correction on rejection (default: False)
first_target | bool | No | Run target prefill before speculative loop (default: True)
debug | bool | No | Enable debug visualization of accepted/rejected tokens (default: False)

Outputs

Name | Type | Description
generated_ids | List[int] | Decoder token IDs, from decoder_start_token_id through the last generated token
acceptance_rate | float | Ratio of accepted draft tokens to total drafted tokens

Usage Examples

Basic Encoder-Decoder Speculative Generation

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from sampling.codec_speculative_decoding import speculative_generate_encoder_decoder
from utils.logits_processor import GreedyProcessor

# 1. Load target and drafter encoder-decoder models
target = AutoModelForSeq2SeqLM.from_pretrained("t5-large", device_map="cuda")
target.eval()

drafter = AutoModelForSeq2SeqLM.from_pretrained("t5-small", device_map="cuda")
drafter.eval()

tokenizer = AutoTokenizer.from_pretrained("t5-large")

# 2. Prepare encoder input
prompt = "translate English to French: The weather is beautiful today."
inputs = tokenizer(prompt, return_tensors="pt").input_ids[0].tolist()

# 3. Run speculative generation
output_ids, accept_rate = speculative_generate_encoder_decoder(
    inputs,
    drafter,
    target,
    tokenizer=tokenizer,
    gamma=5,
    logits_processor=GreedyProcessor(),
    max_gen_len=50,
    eos_tokens_id=[tokenizer.eos_token_id],
    pad_token_id=tokenizer.pad_token_id,
    use_cache=True,
)

# 4. Decode output
output_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(f"Output: {output_text}")
print(f"Acceptance rate: {accept_rate:.3f}")

With Sample Adjustment Disabled

# Reuses target, drafter, tokenizer, and inputs from the basic example above.
from utils.logits_processor import NucleusProcessor

output_ids, accept_rate = speculative_generate_encoder_decoder(
    inputs,
    drafter,
    target,
    gamma=4,
    logits_processor=NucleusProcessor(temperature=0.7, top_p=0.9),
    max_gen_len=100,
    eos_tokens_id=[tokenizer.eos_token_id],
    pad_token_id=tokenizer.pad_token_id,
    use_cache=True,
    skip_sample_adjustment=True,  # Faster but does not preserve exact target distribution
    first_target=True,
)

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
