Implementation: BigScience Workshop Petals InferenceSession
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, NLP, Inference |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
Concrete tool from the Petals library for managing multi-step distributed inference sessions with persistent server-side KV caches.
Description
The InferenceSession class provides an interface for multi-step autoregressive inference across a chain of remote transformer block servers. It manages:
- Server connection lifecycle: Opens and closes bidirectional gRPC streams with each server in the route
- KV cache allocation: Requests cache token allocation on each server for the specified max_length
- Step execution: The step() method sends a hidden-state tensor through all servers in sequence; each server runs its hosted transformer blocks and updates its local KV cache
- Fault recovery: If a server fails, _update_sequence() re-routes through alternative servers
- Position tracking: Tracks how many tokens have been processed via the position property
The session is typically obtained via model.inference_session(max_length=...) which returns a context manager.
Usage
Use this class when performing autoregressive text generation with Petals. Always use it as a context manager so that server-side KV caches are freed on exit. The max_length parameter sets the maximum total sequence length (prompt plus generated tokens); it must be chosen before generation begins, because each server allocates its KV cache for that many tokens when the session opens.
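For example, a session opened with max_length=100 budgets KV cache for at most 100 total tokens, prompt plus continuation. A minimal sketch, assuming model, tokenizer, and inputs are prepared as in the Usage Examples below:
# Servers along the route allocate KV cache for up to 100 tokens at open time
with model.inference_session(max_length=100) as session:
    # prompt length + max_new_tokens must fit within the 100-token budget
    outputs = model.generate(inputs, session=session, max_new_tokens=20)
# Exiting the context manager calls close(), freeing every server's KV cache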
Code Reference
Source Location
- Repository: petals
- File: src/petals/client/inference_session.py (L220-414)
Signature
class InferenceSession:
    """An interface to a multi-step inference session for a sequence of remote transformer blocks"""

    def __init__(
        self,
        sequence_manager: RemoteSequenceManager,
        max_length: int,
    ):
        """
        Args:
            sequence_manager: RemoteSequenceManager managing server discovery and routing
            max_length: Maximum total sequence length for KV cache allocation
        """

    def step(
        self,
        inputs: torch.Tensor,
        prompts: Optional[torch.Tensor] = None,
        hypo_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run a single inference step through all remote transformer blocks.

        Args:
            inputs: Hidden state tensor [batch_size, seq_len, hidden_size]
            prompts: Optional prompt tuning embeddings
            hypo_ids: Optional hypothesis IDs for beam search reordering

        Returns:
            Output hidden states after all transformer blocks
        """

    def close(self, *exc_details):
        """Close all server sessions and free KV caches."""

    @property
    def position(self) -> int:
        """Current position in the sequence (number of tokens processed)."""

    @position.setter
    def position(self, start_from_position: int) -> None:
        """Set position for resuming from a specific point."""

    @property
    def last_token_id(self) -> Optional[torch.Tensor]:
        """Last generated token ID (used for continuation)."""
Import
from petals.client.inference_session import InferenceSession
# Or via model context manager:
# with model.inference_session(max_length=512) as session:
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| sequence_manager | RemoteSequenceManager | Yes | Manages server discovery and Dijkstra-based routing |
| max_length | int | Yes | Maximum total sequence length for KV cache allocation on servers |
| inputs (step) | torch.Tensor | Yes | Hidden state tensor [batch_size, seq_len, hidden_size] |
| prompts (step) | Optional[torch.Tensor] | No | Prompt tuning embeddings for prefix injection |
| hypo_ids (step) | Optional[torch.Tensor] | No | Hypothesis IDs for beam search cache reordering |
Outputs
| Name | Type | Description |
|---|---|---|
| step() returns | torch.Tensor | Output hidden states after processing through all remote blocks [batch_size, seq_len, hidden_size] |
| position | int | Current sequence position (updated after each step) |
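An illustrative check of this shape contract, assuming model and tokenizer are loaded as in the Usage Examples below:
token_ids = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
embs = model.transformer.word_embeddings(token_ids)  # [1, seq_len, hidden_size]
with model.inference_session(max_length=32) as session:
    hidden = session.step(embs)
    assert hidden.shape == embs.shape              # output keeps the input shape
    assert session.position == token_ids.shape[1]  # position advanced by seq_len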
Usage Examples
Basic Generation with Session
from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer
model_name = "petals-team/StableBeluga2"
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
# Use inference session for efficient autoregressive generation
with model.inference_session(max_length=100) as session:
    outputs = model.generate(
        inputs,
        session=session,
        max_new_tokens=20,
        do_sample=True,
        temperature=0.7,
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
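Multi-Call Continuation in One Session
Because the KV cache persists for the lifetime of the session, consecutive generate() calls can continue from the cached prefix instead of re-processing it. A sketch of this conversational pattern, assuming the same model and tokenizer as above:
with model.inference_session(max_length=512) as session:
    for user_turn in ["Human: Hi!\nAssistant:", "Human: Tell me about cats.\nAssistant:"]:
        prompt = tokenizer(user_turn, return_tensors="pt")["input_ids"]
        outputs = model.generate(prompt, session=session, max_new_tokens=30)
        print(tokenizer.decode(outputs[0], skip_special_tokens=True))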
Manual Step-by-Step Generation
import torch
from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer
model_name = "petals-team/StableBeluga2"
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Once upon a time", return_tensors="pt")["input_ids"]
with model.inference_session(max_length=50) as session:
    # Process the whole prompt in a single step
    embs = model.transformer.word_embeddings(inputs)
    hidden = session.step(embs)
    # Generate tokens one at a time
    generated = []
    for _ in range(20):
        # Apply the final layer norm before projecting to vocabulary logits
        logits = model.lm_head(model.transformer.ln_f(hidden[:, -1:, :]))
        next_token = torch.argmax(logits, dim=-1)
        generated.append(next_token)
        embs = model.transformer.word_embeddings(next_token)
        hidden = session.step(embs)
print(tokenizer.decode(torch.cat(generated, dim=1)[0]))
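Resuming with the position Setter
The position setter and last_token_id property documented above support resuming or rewinding within the allocated cache (e.g., for speculative decoding). A hedged sketch with illustrative values; exact rewind behavior depends on the Petals version:
with model.inference_session(max_length=50) as session:
    embs = model.transformer.word_embeddings(inputs)
    hidden = session.step(embs)
    checkpoint = session.position  # number of tokens processed so far
    # ... run a few speculative steps via session.step(...) ...
    session.position = checkpoint  # rewind; later steps overwrite the
                                   # discarded tail of the KV cache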