# Implementation: BigScience Workshop Petals RemoteGenerationMixin Generate With Session
| Knowledge Sources | |
|---|---|
| Domains | NLP, Dialogue, Text_Generation |
| Last Updated | 2026-02-09 14:00 GMT |
## Overview

Concrete tool for multi-turn conversational text generation using persistent inference sessions with KV cache, as provided by Petals' RemoteGenerationMixin.
## Description

This is a session-oriented usage of RemoteGenerationMixin.generate() in which the same InferenceSession is reused across multiple generation calls for multi-turn dialogue. The key differences from standard single-shot generation:
- Explicit session: The session is created once via model.inference_session(max_length=...) and passed to each generate() call
- Position tracking: The session tracks position across turns, so each new generate call only processes new tokens
- KV cache persistence: Server-side KV caches persist between generate calls, storing attention states from all previous turns
- Prompt tuning integration: If the model was trained with prompt tuning, the prompt embeddings are included in the session context
The underlying implementation uses the same RemoteGenerationMixin.generate() method as single-shot generation but with the session= parameter explicitly passed.
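The position-tracking behavior can be sketched with a toy stand-in (this ToySession class is illustrative only, not Petals' actual InferenceSession): the session records how many tokens the servers have already processed, so each generate call only forwards the unseen suffix.

```python
# Toy illustration of session position tracking; NOT Petals' real
# InferenceSession -- just a sketch of the bookkeeping it performs.
class ToySession:
    def __init__(self, max_length: int):
        self.max_length = max_length
        self.position = 0          # tokens already processed server-side
        self.kv_cache = []         # stand-in for per-server attention caches

    def step(self, tokens: list[int]) -> int:
        """Process only the tokens the servers have not seen yet."""
        new_tokens = tokens[self.position:]       # skip the cached prefix
        if self.position + len(new_tokens) > self.max_length:
            raise ValueError("conversation exceeds max_length")
        self.kv_cache.extend(new_tokens)          # caches grow with every turn
        self.position += len(new_tokens)
        return len(new_tokens)                    # work done this call

session = ToySession(max_length=2048)
turn1 = list(range(10))              # 10-token prompt
assert session.step(turn1) == 10     # all 10 tokens processed

turn2 = turn1 + list(range(10, 16))  # full history plus 6 new tokens
assert session.step(turn2) == 6      # only the 6 new tokens are processed
assert session.position == 16
```

This is why multi-turn dialogue with a session is cheap: turn 2 pays only for its new tokens, not for re-encoding the whole history.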
## Usage

Use this implementation for building chatbots and interactive dialogue systems with prompt-tuned distributed BLOOM or Llama models. Set max_length large enough to accommodate the full conversation (the prompt plus every turn's input and generated tokens).
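A back-of-the-envelope way to pick max_length, using assumed per-turn figures (the numbers below are hypothetical, not Petals defaults):

```python
# Rough max_length budget for a chat session. All per-turn numbers
# here are illustrative assumptions, not values taken from Petals.
avg_user_tokens = 40        # assumed tokens per user message
max_new_tokens = 200        # response budget per turn (matches generate())
system_prompt_tokens = 60   # assumed fixed prefix
turns = 8                   # planned conversation depth

budget = system_prompt_tokens + turns * (avg_user_tokens + max_new_tokens)
print(budget)  # 1980 -> round up to a comfortable max_length, e.g. 2048
```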
## Code Reference

### Source Location

- Repository: petals
- File: src/petals/client/remote_generation.py (L83-149, RemoteGenerationMixin.generate)
- File: src/petals/client/remote_sequential.py (L85-95, RemoteSequential.inference_session)
- File: src/petals/client/inference_session.py (L220-362, InferenceSession)
### Signature

```python
# Session creation (context manager)
model.inference_session(max_length: int) -> ContextManager[InferenceSession]

# Generation with session
model.generate(
    inputs: Optional[torch.Tensor] = None,
    *args,
    session: Optional[InferenceSession] = None,  # Pass the persistent session
    max_new_tokens: int = ...,
    do_sample: bool = True,
    temperature: float = ...,
    top_p: float = ...,
    repetition_penalty: float = ...,
    **kwargs,
) -> torch.LongTensor
```
### Import

```python
# Used via the distributed model classes:
from petals.models.bloom.model import DistributedBloomForCausalLM

model = DistributedBloomForCausalLM.from_pretrained(model_name)
# model.inference_session() and model.generate() are inherited from mixins
```
## I/O Contract

### Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| max_length | int | Yes | Maximum total conversation length (all turns combined) |
| inputs | torch.Tensor | Yes (per turn) | New input token IDs for the current dialogue turn |
| session | InferenceSession | Yes | Persistent session reused across turns |
| max_new_tokens | int | No | Tokens to generate for this turn's response |
| do_sample | bool | No | Use sampling for more natural dialogue |
| temperature | float | No | Sampling temperature (0.7-0.9 typical for dialogue) |
| top_p | float | No | Nucleus sampling threshold (0.9 typical) |
| repetition_penalty | float | No | Penalty for repeating tokens (1.2 typical) |
### Outputs
| Name | Type | Description |
|---|---|---|
| generated_ids | torch.LongTensor | Generated response token IDs for this turn |
| session state | InferenceSession | Updated session with new position and KV cache state |
## Usage Examples

### Multi-Turn Chatbot

```python
from transformers import AutoTokenizer
from petals.models.bloom.model import DistributedBloomForCausalLM

model_name = "bigscience/bloom-7b1-petals"
model = DistributedBloomForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

with model.inference_session(max_length=2048) as session:
    # Turn 1
    prompt = "Human: What is machine learning?\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
    outputs = model.generate(
        inputs,
        session=session,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(response)

    # Turn 2 (the session preserves the KV cache from turn 1)
    followup = response + "\nHuman: Can you give an example?\nAssistant:"
    inputs2 = tokenizer(followup, return_tensors="pt")["input_ids"]
    outputs2 = model.generate(
        inputs2,
        session=session,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.8,
    )
    response2 = tokenizer.decode(outputs2[0], skip_special_tokens=True)
    print(response2)
```
### Interactive REPL Loop

```python
conversation = ""
with model.inference_session(max_length=4096) as session:
    while True:
        user_input = input("You: ")
        if user_input.lower() == "quit":
            break
        conversation += f"Human: {user_input}\nAssistant:"
        inputs = tokenizer(conversation, return_tensors="pt")["input_ids"]
        outputs = model.generate(
            inputs,
            session=session,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            repetition_penalty=1.2,
        )
        full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Extract just the new response
        assistant_response = full_text[len(conversation):]
        conversation = full_text + "\n"
        print(f"Assistant: {assistant_response}")
```
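One limitation of the loop above: once the cumulative token count reaches max_length, further calls will fail, and the usual recovery is to open a fresh session with a truncated history (at the cost of losing the cached prefix). A sketch of such a truncation step, using a plain split() word count as a stand-in for the real tokenizer:

```python
def truncate_history(turns: list[str], count_tokens, budget: int) -> list[str]:
    """Keep the most recent turns whose total token count fits in budget."""
    kept: list[str] = []
    total = 0
    for turn in reversed(turns):        # walk from newest to oldest
        cost = count_tokens(turn)
        if total + cost > budget:
            break                        # oldest turns are dropped first
        kept.append(turn)
        total += cost
    return list(reversed(kept))          # restore chronological order

# Whitespace word count stands in for len(tokenizer(text)["input_ids"][0])
count = lambda text: len(text.split())

history = [
    "Human: hi Assistant: hello there",                # 5 "tokens"
    "Human: what is ML Assistant: a field",            # 7 "tokens"
    "Human: example please Assistant: spam filters",   # 6 "tokens"
]
assert truncate_history(history, count, budget=14) == history[1:]
assert truncate_history(history, count, budget=100) == history
```

In a real chatbot you would pass the truncated history as the first input to a new model.inference_session(), leaving headroom in the budget for the upcoming turn's max_new_tokens.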