Implementation:Bigscience workshop Petals DistributedLlamaForSpeculativeGeneration
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, NLP, Text_Generation, Inference_Optimization |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
Concrete tool for accelerating distributed text generation by using a small local Llama model to speculatively draft tokens that are batch-validated by the full distributed model.
Description
DistributedLlamaForSpeculativeGeneration pairs a DistributedLlamaForCausalLM (the large target model, whose layers run on remote Petals servers) with a local LlamaForCausalLM (a small draft model). In each iteration, the small model greedily drafts speculative_inference_iteration_size tokens locally, and the full distributed model validates all of them in a single forward pass. The longest prefix of draft tokens on which both models agree is accepted. At the first disagreement, the large model's own prediction replaces the rejected draft token, the remaining draft tokens are discarded, and the next iteration starts from the extended sequence.
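The acceptance rule can be illustrated with a toy, dependency-free sketch of one iteration. It is not the Petals implementation: "models" here are hypothetical greedy functions mapping a token context to the next token ID, and the target is called once per drafted position where Petals uses one batched forward pass.

```python
from typing import Callable, List, Sequence

# Hypothetical interface: a greedy "model" maps a token context to its argmax next token.
NextTokenFn = Callable[[Sequence[int]], int]


def speculative_round(prefix: List[int], draft_next: NextTokenFn,
                      target_next: NextTokenFn, k: int) -> List[int]:
    """One greedy speculative iteration; returns the tokens to append to prefix."""
    # 1. The draft model proposes k tokens greedily, one at a time (cheap, local).
    draft: List[int] = []
    for _ in range(k):
        draft.append(draft_next(prefix + draft))
    # 2. The target model predicts the token at every drafted position. Petals does
    #    this in one batched forward pass; this toy target is called per position.
    target = [target_next(prefix + draft[:i]) for i in range(k)]
    # 3. Keep the agreeing prefix; at the first mismatch, emit the target's token and stop.
    accepted: List[int] = []
    for proposed, verified in zip(draft, target):
        accepted.append(verified)  # equals `proposed` whenever the models agree
        if proposed != verified:
            break
    return accepted


# Toy models over integers: the target always counts up by one,
# while the draft makes a mistake right after seeing token 5.
def target_next(ctx: Sequence[int]) -> int:
    return ctx[-1] + 1


def draft_next(ctx: Sequence[int]) -> int:
    return 0 if ctx[-1] == 5 else ctx[-1] + 1


print(speculative_round([1, 2, 3], draft_next, target_next, k=5))  # [4, 5, 6]
```

The draft proposes [4, 5, 0, 1, 2]; the first two tokens match the target, the third is replaced by the target's 6, and the last two are discarded because they were conditioned on the rejected token.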
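Because several draft tokens are validated per distributed forward pass, the number of expensive passes drops from one per generated token to as few as one per speculative_inference_iteration_size tokens when the draft model agrees with the target model. Over high-latency networks, where each pass costs a round trip through the remote servers, this substantially reduces generation latency.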
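A toy count of distributed passes makes the best and worst cases concrete. This is a sketch under an assumed agreement pattern: the hypothetical `mismatch_positions` marks the output positions where the draft would choose a different token than the target, and each pass emits at most k tokens (the agreeing draft tokens, plus the target's replacement at a mismatch).

```python
from typing import Set


def count_target_passes(n_new: int, k: int, mismatch_positions: Set[int]) -> int:
    """Count distributed validation passes needed to emit n_new tokens (toy model)."""
    generated = 0
    passes = 0
    while generated < n_new:
        passes += 1  # one batched validation pass by the large model
        budget = min(k, n_new - generated)
        for _ in range(budget):
            pos = generated
            generated += 1  # an accepted draft token, or the target's replacement
            if pos in mismatch_positions:
                break  # draft diverged here: the rest of this round is discarded
    return passes


print(count_target_passes(50, 10, set()))            # perfect draft
print(count_target_passes(50, 10, set(range(50))))   # draft always wrong
print(count_target_passes(20, 10, {3}))              # one disagreement
```

With a perfect draft, 50 new tokens at k=10 need 5 passes instead of 50; with a draft that is always wrong, every pass yields a single token and speculation saves nothing.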
Usage
Import this class when you need faster autoregressive generation from a distributed Llama model and have a smaller Llama model available locally as a draft model. This is most effective when network latency is high relative to local compute cost, and when the draft model has reasonable agreement with the target model.
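The trade-off between network latency, local compute, and draft accuracy can be sketched with a back-of-the-envelope cost model. Everything here is an assumption for illustration: one distributed validation pass is taken to cost a fixed round trip regardless of how many tokens it checks, drafting costs a fixed time per token, and the function name and numbers are hypothetical.

```python
def seconds_per_token(round_trip_s: float, draft_step_s: float, k: int,
                      tokens_per_pass: float) -> float:
    """Estimated wall-clock seconds per emitted token with speculative decoding.

    Assumed cost model: one distributed pass costs round_trip_s (independent of k),
    drafting k tokens locally costs k * draft_step_s, and each iteration emits
    tokens_per_pass tokens on average (between 1 and k).
    """
    return (k * draft_step_s + round_trip_s) / tokens_per_pass


baseline = 0.5  # plain distributed greedy decoding: one 0.5 s round trip per token (assumed)
good_draft = seconds_per_token(0.5, 0.02, k=10, tokens_per_pass=4)  # roughly 0.175 s
poor_draft = seconds_per_token(0.5, 0.02, k=10, tokens_per_pass=1)  # roughly 0.7 s
print(f"baseline {baseline:.3f}s, good draft {good_draft:.3f}s, poor draft {poor_draft:.3f}s")
```

Under this model, speculation beats the baseline only when tokens_per_pass exceeds 1 + k * draft_step_s / round_trip_s (1.4 with these numbers), which is why a costly round trip and an accurate draft model both favor this class.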
Code Reference
Source Location
- Repository: Bigscience_workshop_Petals
- File: src/petals/models/llama/speculative_model.py
- Lines: 1-111
Signature
```python
class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
    def __init__(
        self,
        config: DistributedLlamaConfig,
        small_model: LlamaForCausalLM,
    ):
        """
        Args:
            config: Distributed Llama configuration.
            small_model: A small local Llama model used as the draft model
                for speculative token generation.
        """

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList],
        speculative_inference_iteration_size: int = 10,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        """
        Speculative sampling loop. Generates tokens using draft model,
        validates with target model, accepts matching prefix.

        Args:
            input_ids: Initial token sequence.
            logits_processor: Transformations applied to logits.
            stopping_criteria: Conditions for stopping generation.
            generation_config: HuggingFace generation configuration.
            synced_gpus: Must be False (not supported).
            streamer: Optional streamer for token-by-token output.
            logits_warper: Optional logits warper (not used with greedy).
            speculative_inference_iteration_size: Number of tokens to
                draft per iteration (default: 10).
        """
```
Import
```python
from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | DistributedLlamaConfig | Yes | Distributed Llama model configuration |
| small_model | LlamaForCausalLM | Yes | Local draft model for speculative token generation |
| input_ids | torch.LongTensor | Yes | Input token IDs of shape (batch_size, seq_len); batch_size must be 1 |
| speculative_inference_iteration_size | int | No | Number of tokens to draft per iteration (default: 10) |
| generation_config | GenerationConfig | Yes | Must have do_sample=False (greedy only) and return_dict_in_generate=False |
Outputs
| Name | Type | Description |
|---|---|---|
| input_ids | torch.LongTensor | Token IDs of the prompt followed by the generated tokens |
Constraints
- do_sample must be False (only greedy decoding is supported)
- synced_gpus must be False
- return_dict_in_generate must be False
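- Batch size is limited to 1 for token-level comparison
- The draft and target models must share a tokenizer and vocabulary, since their tokens are compared by ID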
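A caller may want to check these constraints before starting a distributed session. The sketch below is a hypothetical pre-flight helper, not part of Petals: `ToyGenerationConfig` stands in for the two GenerationConfig fields that matter here, and the helper does not reproduce how the library itself enforces the constraints.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyGenerationConfig:
    """Hypothetical stand-in holding only the GenerationConfig fields checked here."""
    do_sample: bool = False
    return_dict_in_generate: bool = False


def find_constraint_violations(config: ToyGenerationConfig, synced_gpus: bool,
                               batch_size: int) -> List[str]:
    """Return a message for each documented constraint the planned call would break."""
    problems: List[str] = []
    if config.do_sample:
        problems.append("do_sample must be False (greedy decoding only)")
    if synced_gpus:
        problems.append("synced_gpus must be False")
    if config.return_dict_in_generate:
        problems.append("return_dict_in_generate must be False")
    if batch_size != 1:
        problems.append(f"batch size must be 1, got {batch_size}")
    return problems


print(find_constraint_violations(ToyGenerationConfig(do_sample=True), synced_gpus=False, batch_size=2))
# ['do_sample must be False (greedy decoding only)', 'batch size must be 1, got 2']
```

With real tokenizer output, batch_size would be `inputs["input_ids"].shape[0]`.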
Usage Examples
Speculative Generation with Llama
```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from petals import AutoDistributedModelForCausalLM
from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration

# 1. Load the small draft model locally
small_model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto",
)

# 2. Load the large distributed model (its transformer blocks are served by the swarm)
large_model = AutoDistributedModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
)

# 3. Create the speculative generation wrapper
spec_model = DistributedLlamaForSpeculativeGeneration(
    large_model.config,
    small_model=small_model,
)
# Reuse the large model's remote blocks, embeddings, and LM head
spec_model.model = large_model.model
spec_model.lm_head = large_model.lm_head

# 4. Generate with speculative decoding (greedy only)
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")
inputs = tokenizer("The future of AI is", return_tensors="pt")
# Entering inference_session() already makes it the model's active session,
# so no manual assignment is needed; max_length must cover prompt + new tokens.
with spec_model.inference_session(max_length=100):
    outputs = spec_model.generate(
        **inputs,
        max_new_tokens=50,
        do_sample=False,
    )
print(tokenizer.decode(outputs[0]))
```