Implementation:Romsto Speculative Decoding Speculative Generate Encoder Decoder
| Knowledge Sources | |
|---|---|
| Domains | NLP, Inference_Optimization, Encoder_Decoder |
| Last Updated | 2026-02-14 05:00 GMT |
Overview
Concrete tool for accelerating encoder-decoder model inference via speculative decoding with rejection sampling, using a smaller drafter encoder-decoder model.
Description
The speculative_generate_encoder_decoder function implements the speculative decoding algorithm adapted for encoder-decoder architectures (e.g., T5, BART). Both the drafter and target models receive the same encoder input (input_ids), and all drafting and verification occur on the decoder side via decoder_input_ids.
The drafter generates gamma draft decoder tokens sequentially, storing the probability distribution for each position. The target model then verifies all drafts in a single forward pass. Acceptance uses rejection sampling: for each draft position, a random value r drawn uniformly from [0, 1) is compared against p(x)/q(x), the target probability of the drafted token divided by its drafter probability. The first position where r > p/q triggers rejection, and the drafts from that position onward are discarded.
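The acceptance rule can be sketched in plain Python as follows. This is a minimal illustration, not the repository's code: `count_accepted`, `p_target`, and `q_draft` are hypothetical names, and nested lists stand in for the probability tensors the real function works with.

```python
import random

def count_accepted(p_target, q_draft, draft_tokens, rng):
    """Return how many leading draft tokens survive rejection sampling.

    p_target[i][tok] and q_draft[i][tok] are the target and drafter
    probabilities of token `tok` at draft position i (hypothetical layout).
    """
    for i, tok in enumerate(draft_tokens):
        r = rng.random()  # uniform in [0, 1)
        # q_draft[i][tok] is positive because the drafter sampled tok from it.
        ratio = p_target[i][tok] / q_draft[i][tok]
        if r > ratio:
            return i  # first rejection: drafts i, i+1, ... are discarded
    return len(draft_tokens)  # every draft was accepted
```

When the two distributions are identical the ratio is 1 at every position, so no draft can be rejected; a drafted token to which the target assigns zero probability is rejected for any positive r.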
When drafts are rejected and skip_sample_adjustment is False, the next token is sampled from max_fn(p - q) (the normalized positive part of the target-minus-drafter distribution difference) to preserve the target model's output distribution. KV-cache pruning is applied on rejection via prune_cache.
The function also supports an optional first_target prefill step, which runs the target model once before the speculative loop to seed the decoder with a first token and warm the KV-cache.
A local max_fn helper is included in this file, identical to the one in speculative_decoding.py, computing norm(max(0, x)) for the adjusted distribution.
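As an illustration of norm(max(0, x)) and of sampling from the adjusted distribution, here is a dependency-free sketch. It is an assumption-laden stand-in rather than the file's implementation: `max_fn_sketch` and `sample_adjusted` are hypothetical names, and Python lists replace torch tensors.

```python
import random

def max_fn_sketch(x):
    """Clamp negative entries to zero, then renormalize so the result sums to 1."""
    positive = [max(0.0, v) for v in x]
    total = sum(positive)
    if total == 0.0:
        # Cannot happen after a genuine rejection, where p < q for the
        # rejected token forces p > q for some other token.
        raise ValueError("no positive mass in p - q")
    return [v / total for v in positive]

def sample_adjusted(p, q, rng):
    """Sample a replacement token index from max_fn_sketch(p - q)."""
    adjusted = max_fn_sketch([pi - qi for pi, qi in zip(p, q)])
    return rng.choices(range(len(adjusted)), weights=adjusted, k=1)[0]

p = [0.5, 0.4, 0.1]  # target distribution at the rejected position
q = [0.2, 0.2, 0.6]  # drafter distribution at the same position
adjusted = max_fn_sketch([pi - qi for pi, qi in zip(p, q)])
# adjusted is approximately [0.6, 0.4, 0.0]: token 2, which the drafter
# over-weighted relative to the target, can no longer be chosen.
token = sample_adjusted(p, q, random.Random(0))
```

Only tokens that the target favors more than the drafter keep any probability mass, which is what lets the combined accept-or-resample procedure reproduce the target distribution.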
Usage
Import this function when you need to accelerate generation from an encoder-decoder model and have a compatible smaller encoder-decoder drafter. Both models must share the same vocabulary size. This is the encoder-decoder counterpart of speculative_generate, which handles decoder-only models. N-gram models are not supported in this variant.
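A quick pre-flight check for the shared-vocabulary requirement might look like the sketch below. The helper `check_same_vocab` is hypothetical and not part of the repository; it assumes both models expose a Hugging Face-style `config.vocab_size` attribute, and `SimpleNamespace` objects stand in for real models so the snippet runs without downloading weights.

```python
from types import SimpleNamespace

def check_same_vocab(drafter, target):
    """Raise ValueError if the two models report different vocabulary sizes."""
    drafter_vocab = drafter.config.vocab_size
    target_vocab = target.config.vocab_size
    if drafter_vocab != target_vocab:
        raise ValueError(
            f"Vocabulary mismatch: drafter={drafter_vocab}, target={target_vocab}"
        )
    return drafter_vocab

# Stand-in objects for illustration only; real code would pass the loaded models.
fake_drafter = SimpleNamespace(config=SimpleNamespace(vocab_size=32128))
fake_target = SimpleNamespace(config=SimpleNamespace(vocab_size=32128))
shared_vocab = check_same_vocab(fake_drafter, fake_target)
```

Equal sizes are a necessary condition rather than a guarantee: the two models also need the same token-to-ID mapping, which in practice means using checkpoints from the same tokenizer family.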
Code Reference
Source Location
- Repository: Speculative-Decoding
- File: sampling/codec_speculative_decoding.py
- Lines: 1-193
Signature
def max_fn(x: torch.Tensor) -> torch.Tensor:
    """
    Max function.
        x: input tensor.
    Returns:
        tensor norm(max(0, x)).
    """

@torch.no_grad()
def speculative_generate_encoder_decoder(
    inputs: List[int],
    drafter: Module,
    target: Module,
    tokenizer = None,
    gamma: int = 5,
    logits_processor: LogitsProcessor = GreedyProcessor(),
    max_gen_len: int = 40,
    eos_tokens_id: int | List[int] = 1,
    pad_token_id: int = 0,
    use_cache: bool = False,
    skip_sample_adjustment: bool = False,
    first_target: bool = True,
    debug: bool = False,
) -> Tuple[List[int], float]:
    """
    Generate text sequence using speculative decoding for encoder-decoder models.
    Args:
        inputs (List[int]): input sequence of batch size 1.
        drafter (Module): drafter encoder-decoder model.
        target (Module): target encoder-decoder model.
        tokenizer: tokenizer (used for debugging).
        gamma (int): number of drafts generated by the drafter at each step.
        logits_processor (LogitsProcessor): logits processor for sampling.
        max_gen_len (int): maximum length of the generated sequence.
        eos_tokens_id (int or List[int]): end token id (could be multiple).
        pad_token_id (int): pad token id.
        use_cache (bool): whether to use cache.
        skip_sample_adjustment (bool): skip the max_fn adjustment step.
        first_target (bool): run target model before the speculative loop.
        debug (bool): debug mode.
    Returns:
        List[int]: generated sequence.
        float: acceptance rate (accepted drafts / total drafts).
    """
Import
from sampling.codec_speculative_decoding import speculative_generate_encoder_decoder
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs | List[int] | Yes | Tokenized encoder input sequence (batch size 1) |
| drafter | torch.nn.Module | Yes | Small encoder-decoder drafter model (same vocab as target) |
| target | torch.nn.Module | Yes | Large encoder-decoder target model |
| tokenizer | PreTrainedTokenizer | No | Tokenizer for debug printing only |
| gamma | int | No | Number of draft decoder tokens per speculative round (default: 5) |
| logits_processor | LogitsProcessor | No | Sampling strategy (default: GreedyProcessor()) |
| max_gen_len | int | No | Maximum decoder tokens to generate (default: 40) |
| eos_tokens_id | int or List[int] | No | End-of-sequence token ID(s) (default: 1) |
| pad_token_id | int | No | Padding token ID (default: 0) |
| use_cache | bool | No | Enable KV-cache for cross- and self-attention (default: False) |
| skip_sample_adjustment | bool | No | Skip max_fn correction on rejection (default: False) |
| first_target | bool | No | Run target prefill before speculative loop (default: True) |
| debug | bool | No | Enable debug visualization of accepted/rejected tokens (default: False) |
Outputs
| Name | Type | Description |
|---|---|---|
| generated_ids | List[int] | Decoder token IDs, starting with decoder_start_token_id and followed by the generated tokens |
| acceptance_rate | float | Ratio of accepted draft tokens to total speculated tokens |
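To make the acceptance_rate output concrete, the arithmetic can be sketched as below. The helper `acceptance_rate` and its per-round bookkeeping are assumptions for illustration; in particular it assumes every round drafts exactly gamma tokens, whereas the real function may draft fewer when it approaches max_gen_len.

```python
def acceptance_rate(accepted_per_round, gamma):
    """Accepted drafts divided by total drafted tokens across all rounds."""
    speculated = gamma * len(accepted_per_round)
    if speculated == 0:
        return 0.0  # no speculative round was run
    return sum(accepted_per_round) / speculated

# Three rounds with gamma=5 in which 5, 2, and 4 drafts were accepted:
# 11 accepted out of 15 drafted.
rate = acceptance_rate([5, 2, 4], gamma=5)
```

A rate close to 1 means the drafter agrees with the target most of the time, so most target forward passes confirm several tokens at once.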
Usage Examples
Basic Encoder-Decoder Speculative Generation
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from sampling.codec_speculative_decoding import speculative_generate_encoder_decoder
from utils.logits_processor import GreedyProcessor

# 1. Load target and drafter encoder-decoder models
target = AutoModelForSeq2SeqLM.from_pretrained("t5-large", device_map="cuda")
target.eval()
drafter = AutoModelForSeq2SeqLM.from_pretrained("t5-small", device_map="cuda")
drafter.eval()
tokenizer = AutoTokenizer.from_pretrained("t5-large")

# 2. Prepare encoder input
prompt = "translate English to French: The weather is beautiful today."
inputs = tokenizer(prompt, return_tensors="pt").input_ids[0].tolist()

# 3. Run speculative generation
output_ids, accept_rate = speculative_generate_encoder_decoder(
    inputs,
    drafter,
    target,
    tokenizer=tokenizer,
    gamma=5,
    logits_processor=GreedyProcessor(),
    max_gen_len=50,
    eos_tokens_id=[tokenizer.eos_token_id],
    pad_token_id=tokenizer.pad_token_id,
    use_cache=True,
)

# 4. Decode output
output_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(f"Output: {output_text}")
print(f"Acceptance rate: {accept_rate:.3f}")
With Sample Adjustment Disabled
from sampling.codec_speculative_decoding import speculative_generate_encoder_decoder
from utils.logits_processor import NucleusProcessor

# Reuses inputs, drafter, target, and tokenizer from the basic example above.
output_ids, accept_rate = speculative_generate_encoder_decoder(
    inputs,
    drafter,
    target,
    gamma=4,
    logits_processor=NucleusProcessor(temperature=0.7, top_p=0.9),
    max_gen_len=100,
    eos_tokens_id=[tokenizer.eos_token_id],
    pad_token_id=tokenizer.pad_token_id,
    use_cache=True,
    skip_sample_adjustment=True,  # Faster, but does not preserve the exact target distribution
    first_target=True,
)