Implementation:Romsto Speculative Decoding Speculative Generate
| Knowledge Sources | |
|---|---|
| Domains | NLP, Inference_Optimization |
| Last Updated | 2026-02-14 04:30 GMT |
Overview
Concrete tool for accelerating LLM inference via speculative decoding with rejection sampling, provided by this repository.
Description
The speculative_generate function implements the full speculative decoding loop described by Leviathan et al. (2022) and Chen et al. (2023). It takes a small drafter model and a large target model. In each round the drafter generates gamma draft tokens, and the target scores all of them in a single forward pass. Each draft token is then accepted with probability min(1, p/q), where p and q are the probabilities the target and the drafter assign to that token. At the first rejection, the max_fn helper adjusts the distribution and a correction token is sampled from norm(max(0, p - q)).
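The accept/reject rule can be illustrated with a minimal sketch in plain Python. This is not the repository's code: verify_drafts and its list-based distributions are hypothetical stand-ins for the tensor operations, and only the max_fn name is borrowed from the description. The extra token sampled when every draft is accepted follows the published algorithm and is assumed here.

```python
import random
from typing import List, Sequence, Tuple


def max_fn(diff: Sequence[float]) -> List[float]:
    """Return norm(max(0, x)): clamp negatives to zero, then renormalize."""
    clamped = [max(0.0, x) for x in diff]
    total = sum(clamped)
    return [x / total for x in clamped]


def verify_drafts(
    drafts: Sequence[int],
    q: Sequence[Sequence[float]],
    p: Sequence[Sequence[float]],
    rng: random.Random,
) -> Tuple[List[int], int]:
    """Check drafts left to right; return (emitted tokens, accepted draft count).

    q[i] and p[i] are the drafter and target distributions at draft position i.
    p holds one extra row, used when every draft is accepted (assumed layout).
    """
    out: List[int] = []
    for i, token in enumerate(drafts):
        # Accept the draft with probability min(1, p/q).
        if rng.random() < min(1.0, p[i][token] / q[i][token]):
            out.append(token)
            continue
        # Rejected: sample a correction token from norm(max(0, p - q)) and stop.
        adjusted = max_fn([pi - qi for pi, qi in zip(p[i], q[i])])
        out.append(rng.choices(range(len(adjusted)), weights=adjusted)[0])
        return out, i
    # Every draft was accepted: sample one more token from the target.
    out.append(rng.choices(range(len(p[-1])), weights=p[-1])[0])
    return out, len(drafts)
```

With one-hot (greedy) distributions the rule reduces to comparing argmax tokens: a draft survives only if the target would have picked the same token, and the correction token is the target's own choice.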
The function supports optional KV-cache for faster sequential decoding, configurable sampling strategies via the LogitsProcessor interface, and a debug mode for visualizing accepted/rejected drafts.
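To make the idea of a pluggable sampling strategy concrete, the sketch below shows two strategies as plain functions over a probability list. The names greedy_filter and top_p_filter are hypothetical and are not the repository's LogitsProcessor classes, whose exact interface is not shown here. The sketch assumes the same strategy is applied to both the drafter and the target distributions so that p and q stay comparable.

```python
from typing import List, Sequence


def greedy_filter(probs: Sequence[float]) -> List[float]:
    """Put all probability mass on the most likely token."""
    best = max(range(len(probs)), key=lambda i: probs[i])
    return [1.0 if i == best else 0.0 for i in range(len(probs))]


def top_p_filter(probs: Sequence[float], top_p: float) -> List[float]:
    """Keep the smallest set of most likely tokens whose mass reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: List[int] = []
    mass = 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    # Zero out everything outside the nucleus and renormalize the rest.
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / mass
    return filtered
```

Greedy filtering makes generation deterministic, while nucleus (top-p) filtering keeps some randomness but removes the low-probability tail.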
Usage
Import this function when you need to generate text from a large language model and have a compatible smaller drafter model available. Both models must be decoder-only and share the same vocabulary, so that a token ID means the same thing to the drafter and the target. This is the primary generation function for the standard speculative decoding workflow.
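One way to check the shared-vocabulary requirement before loading both models is to compare the token-to-ID mappings of the two tokenizers. The helper below is a hypothetical sketch, not part of the repository. It works on plain dictionaries, such as those returned by get_vocab() on Hugging Face tokenizers.

```python
from typing import Dict, List


def vocab_mismatches(
    target_vocab: Dict[str, int], drafter_vocab: Dict[str, int]
) -> List[str]:
    """Return tokens missing from one vocabulary or mapped to different IDs."""
    tokens = set(target_vocab) | set(drafter_vocab)
    return sorted(
        t for t in tokens if target_vocab.get(t) != drafter_vocab.get(t)
    )
```

An empty result suggests the pair is compatible at the tokenizer level. For example, vocab_mismatches(target_tok.get_vocab(), drafter_tok.get_vocab()) could be run with two loaded tokenizers, where target_tok and drafter_tok are placeholder names.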
Code Reference
Source Location
- Repository: Speculative-Decoding
- File: sampling/speculative_decoding.py
- Lines: L22-189
Signature
@torch.no_grad()
def speculative_generate(
    inputs: List[int],
    drafter: Module,
    target: Module,
    tokenizer = None,
    gamma: int = 5,
    logits_processor: LogitsProcessor = GreedyProcessor(),
    max_gen_len: int = 40,
    eos_tokens_id: int | List[int] = 1,
    pad_token_id: int = 0,
    use_cache: bool = False,
    skip_sample_adjustment: bool = False,
    first_target: bool = True,
    debug: bool = False,
) -> Tuple[List[int], float]:
    """
    Generate text sequence using the speculative decoding algorithm.

    Args:
        inputs (List[int]): input sequence of batch size 1.
        drafter (Module): drafter model.
        target (Module): target model.
        tokenizer: tokenizer (used for debugging).
        gamma (int): number of drafts generated by the drafter at each step.
        logits_processor (LogitsProcessor): logits processor for sampling.
        max_gen_len (int): maximum length of the generated sequence.
        eos_tokens_id (int or List[int]): end token id (could be multiple).
        pad_token_id (int): pad token id.
        use_cache (bool): whether to use cache.
        skip_sample_adjustment (bool): whether to skip the max_fn adjustment step.
        first_target (bool): whether to run target model before speculative loop.
        debug (bool): debug mode.

    Returns:
        List[int]: generated token sequence.
        float: acceptance rate (accepted_drafts / total_drafts).
    """
Import
from sampling import speculative_generate
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs | List[int] | Yes | Tokenized input prompt (batch size 1) |
| drafter | torch.nn.Module | Yes | Small draft model (decoder-only, same vocab as target) |
| target | torch.nn.Module | Yes | Large target model (decoder-only) |
| tokenizer | PreTrainedTokenizer | No | Tokenizer for debug printing only |
| gamma | int | No | Number of draft tokens per speculative round (default: 5) |
| logits_processor | LogitsProcessor | No | Sampling strategy (default: GreedyProcessor()) |
| max_gen_len | int | No | Maximum new tokens to generate (default: 40) |
| eos_tokens_id | int or List[int] | No | End-of-sequence token ID(s) (default: 1) |
| pad_token_id | int | No | Padding token ID (default: 0) |
| use_cache | bool | No | Enable KV-cache for faster decoding (default: False) |
| skip_sample_adjustment | bool | No | Skip max_fn correction on rejection (default: False) |
| first_target | bool | No | Run target prefill before speculative loop (default: True) |
| debug | bool | No | Enable debug visualization of accepted/rejected tokens (default: False) |
Outputs
| Name | Type | Description |
|---|---|---|
| generated_ids | List[int] | Generated token IDs (excludes the prompt) |
| acceptance_rate | float | Ratio of accepted draft tokens to total speculated tokens |
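The acceptance rate can be turned into a rough estimate of how many tokens each speculative round produces. The sketch below is hypothetical and not part of the repository. The first function is the estimate from Leviathan et al., which assumes every draft is accepted independently with probability alpha; the returned aggregate ratio only approximates that alpha. The second function assumes every round speculates exactly gamma drafts and emits one additional token from the target.

```python
def expected_tokens_per_round(alpha: float, gamma: int) -> float:
    """Estimate (1 - alpha^(gamma+1)) / (1 - alpha), assuming i.i.d. acceptance."""
    if alpha >= 1.0:
        # Limit of the formula when every draft is accepted.
        return float(gamma + 1)
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)


def observed_tokens_per_round(acceptance_rate: float, gamma: int) -> float:
    """Average accepted drafts per round plus one token from the target.

    Assumption: every round speculates exactly gamma drafts.
    """
    return acceptance_rate * gamma + 1.0
```

Both values are tokens per target forward pass, so they bound the achievable speedup only before the drafter's own cost is taken into account.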
Usage Examples
Basic Speculative Generation
from transformers import AutoModelForCausalLM, AutoTokenizer
from sampling import speculative_generate
from utils.logits_processor import GreedyProcessor

# 1. Load target and drafter models
target = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct", device_map="cuda"
)
target.eval()
drafter = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", device_map="cuda"
)
drafter.eval()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

# 2. Prepare input
prompt = "Explain quantum computing in simple terms."
chat = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
inputs = tokenizer(text, return_tensors="pt").input_ids[0].tolist()

# 3. Run speculative generation
output_ids, accept_rate = speculative_generate(
    inputs,
    drafter,
    target,
    tokenizer=tokenizer,
    gamma=5,
    logits_processor=GreedyProcessor(),
    max_gen_len=100,
    eos_tokens_id=[tokenizer.eos_token_id],
)

# 4. Decode output
output_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(f"Output: {output_text}")
print(f"Acceptance rate: {accept_rate:.3f}")
With KV-Cache and Nucleus Sampling
from utils.logits_processor import NucleusProcessor

# Reuses inputs, drafter, target, and tokenizer from the basic example above
output_ids, accept_rate = speculative_generate(
    inputs,
    drafter,
    target,
    gamma=4,
    logits_processor=NucleusProcessor(temperature=0.7, top_p=0.9),
    max_gen_len=50,
    eos_tokens_id=[tokenizer.eos_token_id],
    use_cache=True,
    first_target=True,
)