Implementation:Bigscience workshop Petals DistributedLlamaForSpeculativeGeneration


Knowledge Sources
Domains Distributed_Computing, NLP, Text_Generation, Inference_Optimization
Last Updated 2026-02-09 14:00 GMT

Overview

Concrete tool for accelerating distributed text generation by using a small local Llama model to speculatively draft tokens that are batch-validated by the full distributed model.

Description

DistributedLlamaForSpeculativeGeneration pairs a DistributedLlamaForCausalLM (the large target model running across remote Petals servers) with a local LlamaForCausalLM (a small draft model). In each iteration, the small model greedily drafts a batch of speculative tokens locally, and the full distributed model validates all of them in a single forward pass. The longest prefix of drafted tokens on which both models agree is accepted. At the first disagreement, the large model's token replaces the draft token, the remaining drafted tokens are discarded, and the next iteration continues from that point.

This reduces the number of expensive distributed forward passes from one per token to approximately one per speculative_inference_iteration_size tokens when the draft model agrees with the target model on most drafted tokens. Because every distributed pass costs at least one network round trip, fewer passes mean lower generation latency over high-latency networks.
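The draft, validate, and accept cycle described above can be sketched with toy stand-ins for the two models. This is a minimal illustration, not the Petals implementation: `speculative_greedy`, `toy_target`, and `toy_draft` are hypothetical names, and each "model" is just a function that maps a token prefix to its greedy next token.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-in type: a "model" maps a token prefix to its greedy next token.
NextToken = Callable[[List[int]], int]


def speculative_greedy(
    prompt: List[int],
    draft_next: NextToken,
    target_next: NextToken,
    max_new_tokens: int,
    draft_size: int = 10,
) -> Tuple[List[int], int]:
    """Return (tokens, number of target validation rounds)."""
    tokens = list(prompt)
    goal = len(prompt) + max_new_tokens
    rounds = 0
    while len(tokens) < goal:
        k = min(draft_size, goal - len(tokens))
        # 1. Draft k tokens greedily with the cheap local model.
        draft: List[int] = []
        for _ in range(k):
            draft.append(draft_next(tokens + draft))
        # 2. One validation round: the target's greedy choice at every drafted
        #    position. A real model would get these from one batched forward pass.
        rounds += 1
        verified = [target_next(tokens + draft[:i]) for i in range(k)]
        # 3. Accept the matching prefix; at the first mismatch keep the
        #    target's token and drop the rest of the draft.
        for drafted, checked in zip(draft, verified):
            tokens.append(checked)
            if drafted != checked:
                break
    return tokens, rounds


def toy_target(seq: List[int]) -> int:
    # Toy "large model": next token depends only on the last token.
    return (3 * seq[-1] + 1) % 7


def toy_draft(seq: List[int]) -> int:
    # Toy "small model": agrees with the target except after token 4.
    return 0 if seq[-1] == 4 else toy_target(seq)


tokens, rounds = speculative_greedy([1], toy_draft, toy_target, max_new_tokens=12, draft_size=4)
print(tokens, rounds)
```

Because a drafted token is kept only when the target model would have chosen it anyway, the output matches plain greedy decoding with the target model. Only the number of validation rounds changes.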

Usage

Import this class when you need faster autoregressive generation from a distributed Llama model and have a smaller Llama model available locally as a draft model. This is most effective when network latency is high relative to local compute cost, and when the draft model has reasonable agreement with the target model.
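A rough back-of-the-envelope model can make this trade-off concrete. The sketch below rests on simplifying assumptions that do not come from Petals itself: each validation round costs one network round trip plus the local drafting time, and plain decoding costs one round trip per token. All function names and numbers are hypothetical.

```python
def seconds_per_token_speculative(
    rtt_s: float,
    draft_step_s: float,
    draft_size: int,
    tokens_per_round: float,
) -> float:
    """Estimated latency per generated token under the assumptions above.

    tokens_per_round is the average number of tokens kept per validation
    round. It lies between 1 (draft always wrong) and draft_size (always right).
    """
    if not 1 <= tokens_per_round <= draft_size:
        raise ValueError("tokens_per_round must be between 1 and draft_size")
    round_cost = rtt_s + draft_size * draft_step_s
    return round_cost / tokens_per_round


def estimated_speedup(
    rtt_s: float,
    draft_step_s: float,
    draft_size: int,
    tokens_per_round: float,
) -> float:
    """Ratio of plain decoding latency (one round trip per token) to speculative latency."""
    return rtt_s / seconds_per_token_speculative(
        rtt_s, draft_step_s, draft_size, tokens_per_round
    )


# Hypothetical numbers: 1 s round trip, 20 ms per local draft step, 10 drafted tokens.
print(estimated_speedup(1.0, 0.02, 10, tokens_per_round=5))  # draft agrees often
print(estimated_speedup(1.0, 0.02, 10, tokens_per_round=1))  # draft never agrees
```

Under this model, a draft that never agrees with the target makes generation slower than plain decoding, since each round still pays for drafting tokens that are discarded.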

Code Reference

Source Location

Signature

class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
    def __init__(
        self,
        config: DistributedLlamaConfig,
        small_model: LlamaForCausalLM,
    ):
        """
        Args:
            config: Distributed Llama configuration.
            small_model: A small local Llama model used as the draft model
                         for speculative token generation.
        """

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList],
        speculative_inference_iteration_size: int = 10,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        """
        Speculative sampling loop. Generates tokens using draft model,
        validates with target model, accepts matching prefix.

        Args:
            input_ids: Initial token sequence.
            logits_processor: Transformations applied to logits.
            stopping_criteria: Conditions for stopping generation.
            generation_config: HuggingFace generation configuration.
            synced_gpus: Must be False (not supported).
            streamer: Optional streamer for token-by-token output.
            logits_warper: Optional logits warper (not used with greedy).
            speculative_inference_iteration_size: Number of tokens to
                draft per iteration (default: 10).
        """

Import

from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration

I/O Contract

Inputs

| Name | Type | Required | Description |
| --- | --- | --- | --- |
| config | DistributedLlamaConfig | Yes | Distributed Llama model configuration |
| small_model | LlamaForCausalLM | Yes | Local draft model for speculative token generation |
| input_ids | torch.LongTensor | Yes | Input token IDs of shape (batch_size, seq_len); batch_size must be 1 (see Constraints) |
| speculative_inference_iteration_size | int | No | Number of tokens to draft per iteration (default: 10) |
| generation_config | GenerationConfig | Yes | Must have do_sample=False (greedy only) |

Outputs

| Name | Type | Description |
| --- | --- | --- |
| input_ids | torch.LongTensor | Full generated sequence including prompt and generated tokens |

Constraints

  • do_sample must be False (only greedy decoding is supported)
  • synced_gpus must be False
  • return_dict_in_generate must be False
  • Batch size is limited to 1 for token-level comparison
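These constraints can be checked up front, before any remote call is made. The helper below is a hypothetical sketch and not part of Petals; it takes plain values rather than library objects so that it stays self-contained.

```python
def check_speculative_constraints(
    batch_size: int,
    do_sample: bool,
    synced_gpus: bool,
    return_dict_in_generate: bool,
) -> None:
    """Raise ValueError listing every constraint from the list above that is violated."""
    problems = []
    if do_sample:
        problems.append("do_sample must be False (only greedy decoding is supported)")
    if synced_gpus:
        problems.append("synced_gpus must be False")
    if return_dict_in_generate:
        problems.append("return_dict_in_generate must be False")
    if batch_size != 1:
        problems.append(f"batch size must be 1, got {batch_size}")
    if problems:
        raise ValueError("; ".join(problems))


# Passes silently for a supported configuration.
check_speculative_constraints(
    batch_size=1,
    do_sample=False,
    synced_gpus=False,
    return_dict_in_generate=False,
)
```

With a tokenized prompt, the batch size would typically be read from `input_ids.shape[0]`.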

Usage Examples

Speculative Generation with Llama

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from petals import AutoDistributedModelForCausalLM
from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration

# 1. Load the small draft model locally
small_model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto",
)

# 2. Load the large distributed model (its config and remote components are reused below)
large_model = AutoDistributedModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
)

# 3. Create the speculative generation wrapper
spec_model = DistributedLlamaForSpeculativeGeneration(
    large_model.config,
    small_model=small_model,
)
# Copy over remote sequential and other components from the large model
spec_model.model = large_model.model
spec_model.lm_head = large_model.lm_head

# 4. Generate with speculative decoding
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")
inputs = tokenizer("The future of AI is", return_tensors="pt")

# The context manager activates the session for generate() calls made inside it
with spec_model.inference_session(max_length=100) as session:
    outputs = spec_model.generate(
        **inputs,
        max_new_tokens=50,
        do_sample=False,
    )

print(tokenizer.decode(outputs[0]))
