Principle: BigScience Petals Autoregressive Generation
| Knowledge Sources | |
|---|---|
| Domains | NLP, Text_Generation, Distributed_Computing |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
A text generation method where tokens are produced one at a time, with each new token conditioned on all previously generated tokens, adapted for distributed execution across remote transformer block servers.
Description
Autoregressive generation is the standard method for producing text from causal language models. At each step, the model computes a probability distribution over the vocabulary for the next token, conditioned on the full sequence so far. A sampling strategy (greedy, top-k, top-p, temperature) selects the next token, which is appended to the sequence.
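A minimal sketch of this loop, using a hand-written bigram table in place of a real model (all names below are invented for illustration; a true causal LM conditions on the entire prefix, not just the last token, and the sampling step here is simply greedy):

```python
import math

# Toy "language model": a bigram table mapping the previous token to
# logits over a tiny vocabulary. It stands in for a real causal LM;
# only the loop structure mirrors autoregressive generation.
VOCAB = ["<s>", "the", "cat", "sat", "</s>"]
BIGRAM_LOGITS = {
    "<s>":  [-9.0,  2.0,  0.0, -1.0, -9.0],
    "the":  [-9.0, -2.0,  3.0,  0.0, -9.0],
    "cat":  [-9.0, -1.0, -2.0,  3.0, -1.0],
    "sat":  [-9.0,  0.0, -1.0, -2.0,  2.0],
    "</s>": [-9.0, -9.0, -9.0, -9.0,  9.0],
}

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def generate(max_new_tokens=10):
    """Greedy autoregressive generation: each step appends one token
    conditioned on the sequence produced so far."""
    seq = ["<s>"]
    for _ in range(max_new_tokens):
        probs = softmax(BIGRAM_LOGITS[seq[-1]])   # distribution over next token
        next_tok = VOCAB[probs.index(max(probs))] # greedy selection
        seq.append(next_tok)
        if next_tok == "</s>":                    # stop at end-of-sequence
            break
    return seq
```

Running `generate()` walks the table greedily until the end-of-sequence token is chosen.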
In Petals' distributed setting, this process is orchestrated by RemoteGenerationMixin, which bridges HuggingFace's GenerationMixin.generate() API with Petals' InferenceSession. The key adaptation is that the forward pass through transformer blocks happens across remote servers rather than locally, while the embedding lookup, LM head projection, and sampling logic remain on the client.
Usage
Use this principle when generating text completions, dialogue responses, or any form of open-ended text from a distributed large language model. This is the primary output-producing step in all Petals client workflows.
Theoretical Basis
Autoregressive factorization:

P(x_1, ..., x_T) = ∏_{t=1}^{T} P(x_t | x_1, ..., x_{t-1})

At each step t, the model produces logits z_t over vocabulary V, normalized by a softmax:

P(x_t = v | x_{<t}) = softmax(z_t)_v,  z_t ∈ R^{|V|}
Sampling strategies:
- Greedy: Pick the highest-probability token, x_t = argmax_v P(v | x_{<t})
- Temperature: Divide the logits by a temperature T before the softmax; T < 1 sharpens the distribution, T > 1 flattens it
- Top-k: Sample from the k highest-probability tokens
- Top-p (nucleus): Sample from the smallest set of tokens whose cumulative probability exceeds p
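These strategies compose into a single sampling function. The sketch below operates on a plain Python list of logits for clarity; the function name and signature are illustrative, not Petals API, and real implementations work on batched tensors:

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def sample_next(logits, temperature=1.0, top_k=0, top_p=1.0, rng=None):
    """Select a token id from `logits` via temperature scaling plus
    top-k and top-p filtering. top_k=0 and top_p=1.0 disable their
    filters; top_k=1 is equivalent to greedy decoding.
    `temperature` must be > 0."""
    rng = rng or random.Random()
    # Temperature: scale logits before the softmax (T < 1 sharpens, T > 1 flattens).
    probs = softmax([z / temperature for z in logits])
    # Rank token ids by probability, highest first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # Top-k: keep only the k highest-probability tokens.
    if top_k > 0:
        order = order[:top_k]
    # Top-p (nucleus): keep the smallest prefix whose cumulative probability exceeds p.
    if top_p < 1.0:
        kept, cum = [], 0.0
        for i in order:
            kept.append(i)
            cum += probs[i]
            if cum >= top_p:
                break
        order = kept
    # Renormalize over the surviving tokens and draw one.
    total = sum(probs[i] for i in order)
    r = rng.random() * total
    for i in order:
        r -= probs[i]
        if r <= 0:
            return i
    return order[-1]
```

For example, `sample_next(logits, top_k=1)` always returns the argmax token, while `top_p` close to 1 leaves the distribution nearly untouched.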
Distributed adaptation:
```python
# Abstract distributed generation loop
with inference_session(max_length) as session:
    hidden = embed(prompt_tokens)
    hidden = session.step(hidden)  # Remote transformer blocks
    for step in range(max_new_tokens):
        logits = lm_head(hidden[:, -1:, :])
        next_token = sample(logits, temperature, top_k, top_p)
        hidden = embed(next_token)
        hidden = session.step(hidden)  # Only processes new token (KV cached)
```
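To make the KV-cache behaviour of `session.step` concrete, here is a runnable sketch with a mock session that records how many positions each call submits. The class and its counters are invented for illustration and are not part of the Petals API:

```python
class MockInferenceSession:
    """Stands in for Petals' InferenceSession: tracks how many
    hidden-state positions each step() call sends to the servers."""
    def __init__(self, max_length):
        self.max_length = max_length
        self.positions_seen = 0   # size of the server-side KV cache
        self.step_sizes = []      # positions submitted per step() call

    def step(self, hidden):
        n = len(hidden)           # hidden: list of per-position vectors
        assert self.positions_seen + n <= self.max_length
        self.positions_seen += n
        self.step_sizes.append(n)
        return hidden             # a real session returns transformed states

prompt = [[0.1], [0.2], [0.3]]    # 3-token prompt (toy 1-d "hidden states")
session = MockInferenceSession(max_length=8)
session.step(prompt)              # prefill: all prompt positions at once
for _ in range(4):                # decoding: one new position per step
    session.step([[0.0]])
```

After this run, the prefill call has submitted 3 positions and each decode call exactly 1, which is precisely why the per-token cost of distributed generation stays flat once the prompt is cached.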