Implementation: bigscience-workshop/petals RemoteGenerationMixin.generate
| Knowledge Sources | |
|---|---|
| Domains | NLP, Text_Generation, Distributed_Computing |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
Concrete tool for generating text from distributed models by bridging HuggingFace's generate() API with Petals' distributed inference sessions.
Description
RemoteGenerationMixin is a mixin class that overrides HuggingFace's GenerationMixin.generate() to work with distributed inference. It:
- Creates an InferenceSession if one is not provided
- Wraps the session in a RemotePastKeyValues object that mimics HuggingFace's cache interface
- Delegates to the parent generate() method which handles all sampling strategies (greedy, beam search, top-k, top-p, etc.)
- Fixes generation kwargs to ensure compatibility with the distributed setup (e.g., adjusts max_length based on session position)
The mixin is inherited by all distributed model classes (DistributedLlamaForCausalLM, DistributedBloomForCausalLM, etc.).
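The shape of this override is easiest to see in isolation. The toy classes below mirror the same control flow (auto-create a session, wrap it like a cache, delegate to the parent, advance the session); every name here is a hypothetical stand-in for illustration, not the Petals source:
class _ToySession:
    """Hypothetical stand-in for InferenceSession: tracks tokens processed."""
    def __init__(self):
        self.position = 0

class _ToyHFGeneration:
    """Hypothetical stand-in for HuggingFace's GenerationMixin."""
    def generate(self, inputs, past_key_values=None, max_new_tokens=3, **kwargs):
        # Pretend each new "token" is the string "tok"
        return list(inputs) + ["tok"] * max_new_tokens

class ToyRemoteGenerationMixin(_ToyHFGeneration):
    """Mirrors the shape of RemoteGenerationMixin.generate (illustrative only)."""
    def generate(self, inputs, *args, session=None, **kwargs):
        if session is None:
            session = _ToySession()      # step 1: auto-create a session
        past = {"session": session}      # step 2: wrap it, like RemotePastKeyValues
        out = super().generate(inputs, *args, past_key_values=past, **kwargs)
        session.position += len(out)     # step 3: session advances, enabling reuse
        return out

model = ToyRemoteGenerationMixin()
print(model.generate(["hello"], max_new_tokens=2))  # ['hello', 'tok', 'tok']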
Usage
Use this implementation to generate text from any Petals distributed model. It supports all standard HuggingFace generation parameters. For best performance, pass an existing InferenceSession to avoid creating a new session per call.
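Because the mixin delegates sampling to the parent generate(), the standard decoding strategies pass straight through. A short sketch (parameter values are arbitrary examples; model and inputs are defined as in the Usage Examples below):
# Greedy decoding
out_greedy = model.generate(inputs, max_new_tokens=30, do_sample=False)
# Nucleus sampling with temperature
out_sample = model.generate(inputs, max_new_tokens=30, do_sample=True, top_p=0.9, temperature=0.8)
# Beam search (handled by the parent HuggingFace generate())
out_beams = model.generate(inputs, max_new_tokens=30, num_beams=4)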
Code Reference
Source Location
- Repository: petals
- File: src/petals/client/remote_generation.py (L83-149)
Signature
class RemoteGenerationMixin(_SkipTokensMixin):
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        *args,
        session: Optional[InferenceSession] = None,
        **kwargs,
    ) -> Union[torch.LongTensor, ModelOutput]:
        """
        Generate sequences using the distributed model.

        Args:
            inputs: Input token IDs tensor [batch_size, seq_len], or None for continuation
            session: Optional InferenceSession for KV cache reuse across calls.
                If None, a new session is created automatically.
            max_new_tokens: Number of new tokens to generate
            max_length: Maximum total sequence length
            do_sample: Whether to use sampling (True) or greedy decoding (False)
            temperature: Sampling temperature (higher = more random)
            top_k: Top-k sampling parameter
            top_p: Top-p (nucleus) sampling parameter
            repetition_penalty: Penalty for repeating tokens

        Returns:
            Generated token IDs tensor [batch_size, total_seq_len]
        """
Import
# Used via the distributed model classes:
from petals import AutoDistributedModelForCausalLM

model_name = "petals-team/StableBeluga2"  # any Petals-supported checkpoint
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
model.generate(...)  # dispatches to RemoteGenerationMixin.generate
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs | Optional[torch.Tensor] | No | Input token IDs [batch_size, seq_len]; None for continuation from session state |
| session | Optional[InferenceSession] | No | Existing session for KV cache reuse; if None, one is created automatically |
| max_new_tokens | int | No | Number of new tokens to generate |
| max_length | int | No | Maximum total sequence length (prompt + generated) |
| do_sample | bool | No | True for sampling, False for greedy decoding |
| temperature | float | No | Sampling temperature (default 1.0) |
| top_k | int | No | Top-k sampling (default 50) |
| top_p | float | No | Nucleus sampling threshold (default 1.0) |
| repetition_penalty | float | No | Penalty for repeated tokens (default 1.0, i.e. no penalty) |
Outputs
| Name | Type | Description |
|---|---|---|
| sequences | torch.LongTensor | Generated token IDs [batch_size, total_seq_len] including prompt |
| session state | InferenceSession | Session updated with new position and output_ids for multi-turn continuation |
Usage Examples
Basic Text Generation
from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer
model_name = "petals-team/StableBeluga2"
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("The meaning of life is", return_tensors="pt")["input_ids"]
# Simple generation (session created automatically)
outputs = model.generate(inputs, max_new_tokens=50, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
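Per the I/O Contract above, the returned tensor includes the prompt tokens. To show only the continuation, slice off the prompt length first:
# Decode only the newly generated tokens (output includes the prompt)
new_tokens = outputs[:, inputs.shape[1]:]
print(tokenizer.decode(new_tokens[0], skip_special_tokens=True))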
Generation with Explicit Session
# Reuse one session across turns; the server-side KV cache persists between calls
with model.inference_session(max_length=512) as session:
    # First turn
    inputs1 = tokenizer("User: Hello!\nAssistant:", return_tensors="pt")["input_ids"]
    out1 = model.generate(inputs1, session=session, max_new_tokens=50)
    response1 = tokenizer.decode(out1[0])
    # Second turn: pass only the new tokens; the session already holds the history
    inputs2 = tokenizer("\nUser: Tell me more.\nAssistant:", return_tensors="pt")["input_ids"]
    out2 = model.generate(inputs2, session=session, max_new_tokens=50)
    response2 = tokenizer.decode(out2[0])
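Note that max_length=512 caps the total tokens the session can hold across all turns (prompt plus generated). A hedged sketch of guarding later turns against overflow, assuming the session exposes its current offset as session.position (the value the Description says max_length is adjusted against), and using inputs=None to continue from session state per the I/O Contract:
# Inside the same inference_session block:
remaining = 512 - session.position  # assumes session.position tracks tokens consumed
if remaining > 0:
    # inputs=None continues generation from the session's current state
    out3 = model.generate(session=session, max_new_tokens=min(50, remaining))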