Implementation:Romsto Speculative Decoding Autoregressive Generate Encoder Decoder
| Knowledge Sources | |
|---|---|
| Domains | NLP, Inference, Encoder_Decoder |
| Last Updated | 2026-02-14 05:00 GMT |
Overview
Concrete tool for standard sequential (token-by-token) text generation with encoder-decoder models such as T5 and BART.
Description
The autoregressive_generate_encoder_decoder function implements autoregressive text generation for encoder-decoder architectures. Unlike the decoder-only variant, it passes a fixed encoder prompt via input_ids and generates tokens on the decoder side via decoder_input_ids. The decoder sequence is initialized with decoder_start_token_id from the model config. At each step, the encoder prompt and the current decoder prefix are both fed to the model, the logits at the last decoder position are extracted, a token is sampled via the LogitsProcessor, and the loop continues until an EOS token is generated or the maximum length is reached.
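The loop described above can be illustrated with a minimal pure-Python sketch. It is not the repository's implementation: the toy model interface (a callable that takes the encoder prompt and the decoder prefix and returns one logits row per decoder position), the helper name greedy_encoder_decoder_loop, and the convention that max_gen_len counts generated tokens excluding the start token are all assumptions made for illustration, and a greedy arg-max stands in for the LogitsProcessor.

```python
from typing import Callable, List, Sequence, Union

# Assumed toy interface: model(encoder_ids, decoder_ids) -> one logits row per decoder position.
ToyModel = Callable[[Sequence[int], Sequence[int]], List[List[float]]]


def greedy_encoder_decoder_loop(
    inputs: Sequence[int],
    model: ToyModel,
    decoder_start_token_id: int,
    max_gen_len: int = 40,
    eos_tokens_id: Union[int, Sequence[int]] = 1,
) -> List[int]:
    """Greedy, cache-free encoder-decoder generation (illustrative sketch)."""
    eos = {eos_tokens_id} if isinstance(eos_tokens_id, int) else set(eos_tokens_id)
    # The decoder prefix starts from the start token; the encoder prompt never changes.
    decoder_ids = [decoder_start_token_id]
    # In this sketch max_gen_len counts generated tokens, excluding the start token.
    while len(decoder_ids) - 1 < max_gen_len:
        # Feed the full encoder prompt and the full decoder prefix every step.
        logits = model(inputs, decoder_ids)
        last = logits[-1]  # scores at the last decoder position
        # Greedy arg-max stands in for the LogitsProcessor.
        next_token = max(range(len(last)), key=last.__getitem__)
        decoder_ids.append(next_token)
        if next_token in eos:
            break
    return decoder_ids


# Toy "copy" model: decoder position i predicts inputs[i], then EOS (id 1).
def copy_model(inputs, decoder_ids, vocab_size=10):
    rows = []
    for i in range(len(decoder_ids)):
        target = inputs[i] if i < len(inputs) else 1
        row = [0.0] * vocab_size
        row[target] = 1.0
        rows.append(row)
    return rows


print(greedy_encoder_decoder_loop([5, 7, 3], copy_model, decoder_start_token_id=0))
# [0, 5, 7, 3, 1]
```

The encoder prompt is passed unchanged on every step while only the decoder prefix grows, which is the defining difference from the decoder-only loop.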
This function serves as the baseline for comparing throughput against the encoder-decoder speculative decoding variant.
Usage
Import this function when you need autoregressive generation from an encoder-decoder model (e.g., T5, BART, mBART). It is the encoder-decoder counterpart to autoregressive_generate, which handles decoder-only models. Use it as the baseline when benchmarking encoder-decoder speculative decoding.
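When using this function as a throughput baseline, a small timing harness helps keep the comparison consistent across generators. The sketch below is a hypothetical helper (tokens_per_second is not part of the repository); it assumes each generator returns a token list that begins with the decoder start token, performs one untimed warm-up run, and averages over several timed runs. On a GPU, timings are only meaningful once device work has finished, so calling torch.cuda.synchronize() inside the timed callable is a common precaution.

```python
import time
from typing import Callable, List


def tokens_per_second(
    generate: Callable[[], List[int]],
    repeats: int = 3,
    skip_leading: int = 1,
) -> float:
    """Average generated tokens per second over `repeats` timed runs.

    `skip_leading` drops the decoder start token from the count, assuming
    the output begins with it.
    """
    generate()  # untimed warm-up run (lazy initialization, kernel setup, caches)
    total_tokens = 0
    total_time = 0.0
    for _ in range(repeats):
        start = time.perf_counter()
        output = generate()
        total_time += time.perf_counter() - start
        total_tokens += max(len(output) - skip_leading, 0)
    return total_tokens / total_time if total_time > 0 else float("inf")
```

For example, tokens_per_second(lambda: autoregressive_generate_encoder_decoder(inputs, model, use_cache=True)) could be compared against the same wrapper around the speculative variant, using identical prompts and max_gen_len.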
Code Reference
Source Location
- Repository: Speculative-Decoding
- File: sampling/codec_base_decoding.py
- Lines: 8-73
Signature
@torch.no_grad()
def autoregressive_generate_encoder_decoder(
    inputs: List[int],
    model: Module,
    max_gen_len: int = 40,
    logits_processor: LogitsProcessor = GreedyProcessor(),
    eos_tokens_id: int | List[int] = 1,
    pad_token_id: int = 0,
    use_cache: bool = False,
    debug: bool = False,
) -> List[int]:
    """
    Generate text sequence autoregressively based on the input sequence.
    Args:
        inputs (List[int]): input sequence of batch size 1.
        model (Module): model to use for inference.
        max_gen_len (int): maximum length of the generated sequence.
        logits_processor (LogitsProcessor): logits processor for sampling.
        eos_tokens_id (int): end token id.
        pad_token_id (int): pad token id.
        use_cache (bool): whether to use cache.
        debug (bool): debug mode.
    Returns:
        List[int]: generated sequence.
    Note:
        This generation method only works for encoder-decoder models.
    """
Import
from sampling.codec_base_decoding import autoregressive_generate_encoder_decoder
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs | List[int] | Yes | Tokenized encoder input sequence (batch size 1) |
| model | torch.nn.Module | Yes | Encoder-decoder model (e.g., T5, BART) |
| max_gen_len | int | No | Maximum decoder tokens to generate (default: 40) |
| logits_processor | LogitsProcessor | No | Sampling strategy (default: GreedyProcessor()) |
| eos_tokens_id | int or List[int] | No | End-of-sequence token ID(s) (default: 1) |
| pad_token_id | int | No | Padding token ID (default: 0) |
| use_cache | bool | No | Enable KV-cache for cross-attention and self-attention (default: False) |
| debug | bool | No | Enable debug output (default: False) |
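The use_cache flag mainly changes how much decoder work each step repeats. A back-of-the-envelope count, assuming the uncached path re-feeds the whole decoder prefix every step (as the Description states) and the cached path feeds only the newest decoder token, as is usual with Hugging Face past_key_values: after n steps the uncached decoder has processed 1 + 2 + ... + n = n(n + 1)/2 positions, versus n with the cache. Encoder cost is left out, since whether the encoder pass is repeated under use_cache depends on implementation details not shown here. The helper name is hypothetical.

```python
def decoder_positions_processed(n_steps: int, use_cache: bool) -> int:
    """Decoder positions pushed through the model over n_steps generation steps.

    Step t (1-based) predicts from a prefix of t tokens (start token + t-1 generated).
    Without a cache the whole prefix is re-fed each step: 1 + 2 + ... + n.
    With a KV-cache only the newest token is fed each step: n in total.
    """
    if n_steps < 0:
        raise ValueError("n_steps must be non-negative")
    return n_steps if use_cache else n_steps * (n_steps + 1) // 2


for cached in (False, True):
    print(cached, decoder_positions_processed(50, use_cache=cached))
# False 1275
# True 50
```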
Outputs
| Name | Type | Description |
|---|---|---|
| generated_ids | List[int] | Decoder token IDs including decoder_start_token_id up to and including EOS (or max length) |
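Because the returned list starts with decoder_start_token_id and may end with EOS, counting or comparing generated tokens (for example, against the speculative variant) is easier after trimming them. The sketch below uses a hypothetical helper name and stops at the first EOS it finds. For T5, decoder_start_token_id is 0 (the pad token) and the EOS ID is 1; for display, tokenizer.decode(..., skip_special_tokens=True) already hides both.

```python
from typing import Iterable, List, Union


def strip_decoder_output(
    output_ids: List[int],
    decoder_start_token_id: int,
    eos_tokens_id: Union[int, Iterable[int]],
) -> List[int]:
    """Drop a leading decoder start token and everything from the first EOS on."""
    eos = {eos_tokens_id} if isinstance(eos_tokens_id, int) else set(eos_tokens_id)
    if output_ids and output_ids[0] == decoder_start_token_id:
        ids = output_ids[1:]
    else:
        ids = list(output_ids)
    for i, tok in enumerate(ids):
        if tok in eos:
            return ids[:i]
    return ids


print(strip_decoder_output([0, 5, 7, 3, 1], decoder_start_token_id=0, eos_tokens_id=1))
# [5, 7, 3]
```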
Usage Examples
Basic Encoder-Decoder Generation
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from sampling.codec_base_decoding import autoregressive_generate_encoder_decoder
from utils.logits_processor import GreedyProcessor
# 1. Load an encoder-decoder model
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", device_map="cuda")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("t5-small")
# 2. Prepare encoder input
prompt = "translate English to French: Hello, how are you?"
inputs = tokenizer(prompt, return_tensors="pt").input_ids[0].tolist()
# 3. Generate decoder output
output_ids = autoregressive_generate_encoder_decoder(
    inputs,
    model,
    max_gen_len=50,
    logits_processor=GreedyProcessor(),
    eos_tokens_id=[tokenizer.eos_token_id],
    pad_token_id=tokenizer.pad_token_id,
    use_cache=True,
)
# 4. Decode
print(tokenizer.decode(output_ids, skip_special_tokens=True))