
Implementation:Bigscience workshop Petals RemoteGenerationMixin Generate

From Leeroopedia


Knowledge Sources
Domains NLP, Text_Generation, Distributed_Computing
Last Updated 2026-02-09 14:00 GMT

Overview

A concrete tool for generating text from distributed models: it bridges HuggingFace's generate() API with Petals' distributed inference sessions.

Description

RemoteGenerationMixin is a mixin class that overrides HuggingFace's GenerationMixin.generate() to work with distributed inference. It:

  • Creates an InferenceSession if one is not provided
  • Wraps the session in a RemotePastKeyValues object that mimics HuggingFace's cache interface
  • Delegates to the parent generate() method which handles all sampling strategies (greedy, beam search, top-k, top-p, etc.)
  • Fixes generation kwargs to ensure compatibility with the distributed setup (e.g., adjusts max_length based on session position)

The mixin is inherited by all distributed model classes (DistributedLlamaForCausalLM, DistributedBloomForCausalLM, etc.).
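The delegation pattern described above can be sketched in plain Python. This is a simplified illustration, not the actual Petals source; the class bodies and attribute names below are stand-ins chosen for the sketch:

```python
# Simplified sketch of the mixin pattern (illustrative stand-ins, not Petals code).

class GenerationMixin:
    """Stands in for HuggingFace's GenerationMixin."""

    def generate(self, inputs, **kwargs):
        # Pretend "generation" just appends max_new_tokens placeholder tokens.
        return inputs + [0] * kwargs.get("max_new_tokens", 1)


class InferenceSession:
    """Stands in for a Petals distributed inference session."""

    def __init__(self):
        self.position = 0


class RemoteGenerationMixin(GenerationMixin):
    """Overrides generate() to wire in a distributed session, then delegates."""

    def generate(self, inputs, *, session=None, **kwargs):
        if session is None:
            session = InferenceSession()  # create a session if none was provided
        self._active_session = session    # exposed so forward passes can use it
        # Delegate to the parent generate(), which owns the sampling strategies.
        return super().generate(inputs, **kwargs)


class DistributedModel(RemoteGenerationMixin):
    pass


model = DistributedModel()
out = model.generate([1, 2, 3], max_new_tokens=2)  # session created automatically
```

The key design choice mirrored here is that the mixin only handles session plumbing; all decoding logic (greedy, beam search, top-k, top-p) stays in the parent class, so every HuggingFace generation feature works unchanged.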

Usage

Use this implementation to generate text from any Petals distributed model. It supports all standard HuggingFace generation parameters. For best performance, pass an existing InferenceSession to avoid creating a new session per call.

Code Reference

Source Location

  • Repository: petals
  • File: src/petals/client/remote_generation.py (L83-149)

Signature

class RemoteGenerationMixin(_SkipTokensMixin):
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        *args,
        session: Optional[InferenceSession] = None,
        **kwargs,
    ) -> Union[torch.LongTensor, ModelOutput]:
        """
        Generate sequences using the distributed model.

        Args:
            inputs: Input token IDs tensor [batch_size, seq_len], or None for continuation
            session: Optional InferenceSession for KV cache reuse across calls.
                     If None, a new session is created automatically.
            max_new_tokens: Number of new tokens to generate
            max_length: Maximum total sequence length
            do_sample: Whether to use sampling (True) or greedy decoding (False)
            temperature: Sampling temperature (higher = more random)
            top_k: Top-k sampling parameter
            top_p: Top-p (nucleus) sampling parameter
            repetition_penalty: Penalty for repeating tokens
        Returns:
            Generated token IDs tensor [batch_size, total_seq_len]
        """

Import

# Used via the distributed model classes:
from petals import AutoDistributedModelForCausalLM
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
model.generate(...)  # RemoteGenerationMixin.generate is called

I/O Contract

Inputs

Name Type Required Description
inputs Optional[torch.Tensor] No Input token IDs [batch_size, seq_len]; None for continuation from session state
session Optional[InferenceSession] No Existing session for KV cache reuse; if None, one is created automatically
max_new_tokens int No Number of new tokens to generate
max_length int No Maximum total sequence length (prompt + generated)
do_sample bool No True for sampling, False for greedy decoding
temperature float No Sampling temperature (default 1.0)
top_k int No Top-k sampling (default 50)
top_p float No Nucleus sampling threshold (default 1.0)
repetition_penalty float No Penalty for repeating tokens (default 1.0)
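To make the sampling parameters concrete, here is a minimal sketch of what top-k and top-p (nucleus) filtering do to the candidate distribution before a token is drawn. This is illustrative only; in practice the filtering happens inside HuggingFace's logits processors, not in Petals code:

```python
def top_k_top_p_filter(probs, top_k=50, top_p=1.0):
    """Keep the top_k most likely tokens, then the smallest prefix of them
    whose cumulative probability reaches top_p; renormalize the survivors."""
    ranked = sorted(enumerate(probs), key=lambda kv: kv[1], reverse=True)[:top_k]
    kept, cumulative = [], 0.0
    for token_id, p in ranked:
        kept.append((token_id, p))
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(p for _, p in kept)
    return {token_id: p / total for token_id, p in kept}


# Distribution over a toy 4-token vocabulary.
probs = [0.5, 0.3, 0.15, 0.05]
filtered = top_k_top_p_filter(probs, top_k=3, top_p=0.8)
# Tokens 0 and 1 survive (0.5 + 0.3 reaches the 0.8 nucleus); the rest are cut.
```

Lower top_p or top_k narrows the distribution (more conservative output); higher temperature flattens the logits before this step (more random output).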

Outputs

Name Type Description
sequences torch.LongTensor Generated token IDs [batch_size, total_seq_len] including prompt
session state InferenceSession Session updated with new position and output_ids for multi-turn continuation
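Because the returned sequence includes the prompt, callers usually slice off the prompt length to recover only the newly generated tokens. A minimal sketch with plain lists (the token IDs are made up; with a real model you would slice the returned tensor the same way, e.g. `outputs[0, inputs.shape[1]:]`):

```python
# Suppose the prompt had 4 tokens and generate() returned 7 tokens total.
prompt_ids = [101, 7592, 2088, 102]
sequences = prompt_ids + [2023, 2003, 2307]  # prompt + 3 newly generated tokens

# Drop the prompt prefix to keep only what the model produced.
new_tokens = sequences[len(prompt_ids):]
```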

Usage Examples

Basic Text Generation

from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer

model_name = "petals-team/StableBeluga2"
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("The meaning of life is", return_tensors="pt")["input_ids"]

# Simple generation (session created automatically)
outputs = model.generate(inputs, max_new_tokens=50, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Generation with Explicit Session

# Reuse session for multi-turn generation
with model.inference_session(max_length=512) as session:
    # First turn
    inputs1 = tokenizer("User: Hello!\nAssistant:", return_tensors="pt")["input_ids"]
    out1 = model.generate(inputs1, session=session, max_new_tokens=50)
    response1 = tokenizer.decode(out1[0])

    # Second turn: the session already holds the KV cache for the first turn,
    # so feed only the new tokens rather than re-tokenizing the full history
    inputs2 = tokenizer("\nUser: Tell me more.\nAssistant:", return_tensors="pt")["input_ids"]
    out2 = model.generate(inputs2, session=session, max_new_tokens=50)
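Note that the max_length passed to inference_session is a hard budget for the whole conversation: prompt tokens and generated tokens from every turn all consume positions in the same session. A quick sketch of the bookkeeping, with hypothetical token counts:

```python
session_max_length = 512
position = 0  # mirrors session.position, which advances as tokens are processed

# Turn 1: 12 prompt tokens fed in, 50 tokens generated.
position += 12 + 50

# Turn 2: 9 new prompt tokens fed in, 50 tokens generated.
position += 9 + 50

# Positions still available for further turns before the session is exhausted.
remaining = session_max_length - position
```

When the budget runs out, generation in that session can no longer proceed, so size max_length to the longest conversation you expect.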

Related Pages

Implements Principle

Requires Environment
