Principle:Bigscience workshop Petals Speculative Decoding
| Knowledge Sources | |
|---|---|
| Domains | NLP, Text_Generation, Inference_Optimization, Distributed_Computing |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
An inference acceleration technique in which a small draft model proposes several tokens ahead and a large target model validates all of them in a single forward pass, reducing the number of sequential calls to the large model.
Description
Speculative Decoding addresses the fundamental latency bottleneck of autoregressive generation: each token requires a full sequential forward pass through the model. In a distributed setting like Petals, where each forward pass incurs network round-trip latency, this bottleneck is amplified.
The key insight is that a small, fast draft model (running locally) can often predict what a large target model (running on remote servers) would generate. By drafting several tokens cheaply and then validating the whole draft in a single forward pass of the target model, the cost of each target call is amortized over every token accepted from that draft.
The algorithm:
- The draft model generates k tokens greedily from the current context
- The target model runs a single forward pass on all k draft tokens simultaneously
- Starting from the first draft token, each is compared against the target model's greedy (argmax) prediction at that position
- The longest run of matching tokens from the start of the draft is accepted
- On the first mismatch, the target model's token replaces the draft token
- The process repeats from the new context
In the distributed Petals setting, this is especially beneficial: one batched validation pass through the remote servers commits between 1 and k tokens, each of which would otherwise cost its own sequential network round trip.
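The effect on round trips can be seen in a minimal, self-contained simulation. Everything here is illustrative: `target_next` and `draft_next` are hypothetical deterministic toy rules over integer token ids (not Petals or Transformers APIs), and the list comprehension emulates the single batched verification pass by computing the target's greedy prediction at every draft position. Because every committed token is the target's own argmax, the result matches plain greedy decoding of the target (in exact arithmetic), while the number of target calls drops.

```python
from typing import Callable, List, Tuple

Token = int
NextTokenFn = Callable[[List[Token]], Token]


def target_next(ctx: List[Token]) -> Token:
    # Hypothetical "large" model: a deterministic toy rule standing in for argmax(logits)
    return (ctx[-1] * 3 + len(ctx)) % 7


def draft_next(ctx: List[Token]) -> Token:
    # Hypothetical "small" model: agrees with the target except when len(ctx) % 4 == 0
    guess = target_next(ctx)
    return guess if len(ctx) % 4 else (guess + 1) % 7


def greedy_decode(next_fn: NextTokenFn, ctx: List[Token], n: int) -> List[Token]:
    out = list(ctx)
    for _ in range(n):
        out.append(next_fn(out))
    return out


def speculative_decode(ctx: List[Token], n: int, k: int) -> Tuple[List[Token], int]:
    out, target_calls = list(ctx), 0
    while len(out) - len(ctx) < n:
        draft = greedy_decode(draft_next, out, k)[len(out):]  # local, cheap
        # One batched "remote" call: the target's prediction at every draft position
        target_calls += 1
        preds = [target_next(out + draft[:i]) for i in range(k)]
        accepted = []
        for d, y in zip(draft, preds):
            accepted.append(y)  # equals d on a match
            if y != d:
                break  # first mismatch: keep the target's token, drop the rest
        out.extend(accepted)
    return out[: len(ctx) + n], target_calls


prompt = [1, 2]
reference = greedy_decode(target_next, prompt, 12)  # 12 sequential target calls
spec, calls = speculative_decode(prompt, n=12, k=4)
print(spec == reference, calls)  # True 4
```

Here the toy draft disagrees with the target once every four positions, so 12 tokens cost 4 target calls instead of 12; a draft that disagrees more often would commit fewer tokens per call.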
Usage
Use this principle when generating text from a distributed model and network latency is the primary bottleneck. It requires a smaller model from the same family available locally. Most effective when the draft model has high agreement with the target model (e.g., Llama-7B drafting for Llama-70B).
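Since draft and target tokens are compared by id, a cheap sanity check before pairing two models is that their vocabularies map the same strings to the same ids. The helpers below (`vocab_mismatches`, `can_draft_for`) are hypothetical names, not part of Petals; they operate on plain token-to-id dicts such as those returned by a Hugging Face tokenizer's `get_vocab()` method. Matching vocabularies are necessary but not sufficient, since merge rules and special-token handling must agree too.

```python
from typing import Dict, List


def vocab_mismatches(draft_vocab: Dict[str, int], target_vocab: Dict[str, int]) -> List[str]:
    """Tokens missing from one vocabulary or mapped to different ids (sorted)."""
    tokens = set(draft_vocab) | set(target_vocab)
    return sorted(t for t in tokens if draft_vocab.get(t) != target_vocab.get(t))


def can_draft_for(draft_vocab: Dict[str, int], target_vocab: Dict[str, int]) -> bool:
    # Draft and target ids are compared one by one, so each id must denote the same string.
    # Necessary but not sufficient: tokenization rules must also agree.
    return not vocab_mismatches(draft_vocab, target_vocab)


# Toy vocabularies (illustrative only)
shared = {"<s>": 1, "Hello": 2, "world": 3}
renumbered = {"<s>": 1, "Hello": 3, "world": 2}
print(can_draft_for(shared, dict(shared)))   # True
print(vocab_mismatches(shared, renumbered))  # ['Hello', 'world']
```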
Theoretical Basis
Speculative decoding with greedy verification:
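Given a draft model $M_d$ and a target model $M_t$, at each iteration:
$$\hat{x}_{t+1}, \ldots, \hat{x}_{t+k} = \text{greedy\_decode}(M_d, x_{1:t}, k)$$
Then validate in one pass:
$$y_{t+i} = \arg\max_{v} P_{M_t}\left(v \mid x_{1:t}, \hat{x}_{t+1:t+i-1}\right), \quad i = 1, \ldots, k$$
Acceptance length:
$$n = \max\left\{\, j \in \{0, \ldots, k\} : \hat{x}_{t+i} = y_{t+i} \text{ for all } i \le j \,\right\}$$
Let $\alpha$ be the probability that a draft token matches the target model's prediction, assumed independent across positions. Then $n$ follows a geometric distribution truncated at $k$, and each iteration commits $\min(n+1, k)$ tokens (the $n$ matches plus the target's correction on a mismatch), with expectation:
$$\mathbb{E}[\text{accepted}] = \frac{1-\alpha^{k}}{1-\alpha}$$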
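Under the same independence assumption (a simplification, since agreement is usually correlated across neighboring positions), the expectation follows from the tail-sum formula. Here $n$ is the number of leading draft tokens that match the target, $\alpha$ is the per-token match probability, and $A = \min(n+1, k)$ is the number of tokens committed in one iteration, counting the target's correction on a mismatch:

```latex
\begin{aligned}
\Pr[n \ge m] &= \alpha^{m}, \qquad 0 \le m \le k, \\
\mathbb{E}[A] &= \sum_{j=1}^{k} \Pr[A \ge j]
              = \sum_{j=1}^{k} \Pr[n \ge j-1]
              = \sum_{j=0}^{k-1} \alpha^{j}
              = \frac{1-\alpha^{k}}{1-\alpha}.
\end{aligned}
```

For example, $\alpha = 0.8$ and $k = 4$ give $\mathbb{E}[A] = (1 - 0.4096)/0.2 = 2.952$. The limits behave as expected: $\alpha \to 0$ gives one token per call (plain autoregressive decoding) and $\alpha \to 1$ gives $k$. The sum stops at $k-1$ because verification inspects only the $k$ draft positions; a variant that also keeps the target's prediction after the last draft token would gain one more term.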
Speedup analysis:
If each distributed forward pass has latency L and the local draft model has negligible latency:
- Standard autoregressive: Cost per token = L
- Speculative (k draft tokens): Cost per token ≈ L / E[accepted]
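The cost model can be checked numerically. The sketch below assumes, as the analysis does, that draft latency is negligible and that each draft token matches the target independently with probability `alpha` (a per-token agreement rate introduced here for illustration). It evaluates the closed form `(1 - alpha**k) / (1 - alpha)`, cross-checks it by Monte Carlo, and reports cost per token in units of L.

```python
import random


def expected_committed(alpha: float, k: int) -> float:
    """Closed form E[min(n + 1, k)] for i.i.d. per-token agreement alpha."""
    if alpha == 1.0:
        return float(k)
    return (1 - alpha**k) / (1 - alpha)


def simulate_committed(alpha: float, k: int, trials: int, seed: int = 0) -> float:
    """Monte Carlo estimate of tokens committed per verification call."""
    rng = random.Random(seed)
    total = 0
    for _ in range(trials):
        committed = 0
        for _ in range(k):
            committed += 1  # a matched draft token, or the target's correction
            if rng.random() >= alpha:
                break  # mismatch: stop after committing the correction
        total += committed
    return total / trials


L = 1.0  # latency of one distributed forward pass (arbitrary units)
for alpha in (0.5, 0.8, 0.9):
    e = expected_committed(alpha, k=4)
    mc = simulate_committed(alpha, k=4, trials=20_000)
    print(f"alpha={alpha}: E[accepted]={e:.3f} (MC {mc:.3f}), "
          f"cost/token={L / e:.3f} L, speedup={e:.2f}x")
```

The speedup equals E[accepted] only under these assumptions; in practice the verification pass over k extra positions is somewhat slower than a single-token step, and the draft model's own latency is not zero.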
Pseudo-code logic:
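```python
# Abstract speculative decoding loop with greedy verification (batch size 1).
# draft_generate(context, k): returns k draft token ids from the small local model.
# target_logits(tokens): returns one row of logits per input position, where
#   row j scores the token that follows tokens[: j + 1] (one remote forward pass).
def speculative_decode(context, draft_generate, target_logits, k, max_new_tokens):
    start = len(context)
    while len(context) - start < max_new_tokens:
        # Draft phase (local, fast)
        draft_tokens = draft_generate(context, k)
        # Verify phase (remote, single batched call)
        logits = target_logits(context + draft_tokens)
        # Accept the matching prefix plus the target's token at the first mismatch
        accepted = []
        for i in range(k):
            row = logits[len(context) - 1 + i]  # predicts draft_tokens[i]
            target_token = max(range(len(row)), key=row.__getitem__)  # argmax
            accepted.append(target_token)  # identical to draft_tokens[i] on a match
            if target_token != draft_tokens[i]:
                break  # first mismatch: keep the target's token, discard the rest
        context = context + accepted
    return context[: start + max_new_tokens]
```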
Constraints in Petals implementation:
- Only greedy decoding (do_sample=False) is supported; stochastic speculative sampling requires a rejection-sampling acceptance step, which is not yet implemented
- Batch size is limited to 1 for token-level comparison
- The draft model must share the same tokenizer as the target model
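For context on the first constraint, the standard speculative sampling rule replaces exact-match checking with a probabilistic accept/reject step: a draft token x drawn from the draft distribution q is accepted with probability min(1, p(x)/q(x)), where p is the target distribution; on rejection, a replacement is drawn from the normalized residual max(0, p - q). This makes each committed token an exact sample from p. The sketch below is illustrative only and is not part of Petals; `speculative_sample_step` is a hypothetical name, and `p` and `q` are small hand-written distributions.

```python
import random
from typing import Sequence, Tuple


def speculative_sample_step(
    p: Sequence[float], q: Sequence[float], x: int, rng: random.Random
) -> Tuple[int, bool]:
    """Verify one draft token x ~ q against target distribution p.

    Returns (token, accepted). A sketch of the standard rule, not a Petals API.
    """
    if q[x] > 0 and rng.random() < min(1.0, p[x] / q[x]):
        return x, True
    # Rejected: resample from the residual distribution normalize(max(0, p - q))
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    token = rng.choices(range(len(p)), weights=residual, k=1)[0]
    return token, False


rng = random.Random(0)
p = [0.6, 0.3, 0.1]  # target distribution (illustrative)
q = [0.2, 0.2, 0.6]  # draft distribution (illustrative)
counts = [0, 0, 0]
for _ in range(100_000):
    x = rng.choices(range(3), weights=q, k=1)[0]  # draft proposes x ~ q
    token, _ = speculative_sample_step(p, q, x, rng)
    counts[token] += 1
print([c / 100_000 for c in counts])  # empirical frequencies, close to p
```

With multiple draft tokens, each position would be checked in turn with this rule, stopping at the first rejection, which keeps the same prefix-acceptance structure as the greedy loop.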