
Implementation:Bigscience workshop Petals InferenceSession

From Leeroopedia


Knowledge Sources
Domains Distributed_Computing, NLP, Inference
Last Updated 2026-02-09 14:00 GMT

Overview

A concrete tool from the Petals library for managing multi-step distributed inference sessions with persistent server-side KV caches.

Description

The InferenceSession class provides an interface for multi-step autoregressive inference across a chain of remote transformer block servers. It manages:

  • Server connection lifecycle: Opens and closes bidirectional gRPC streams with each server in the route
  • KV cache allocation: Requests cache token allocation on each server for the specified max_length
  • Step execution: The step() method sends a hidden state tensor through all servers sequentially, each server computing its transformer block and updating its local KV cache
  • Fault recovery: If a server fails, _update_sequence() re-routes through alternative servers
  • Position tracking: Tracks how many tokens have been processed via the position property
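
The lifecycle above can be illustrated with a self-contained toy mock (illustrative only, not the Petals implementation): each "server" keeps a local cache that grows as steps arrive, `position` counts processed tokens, and closing the session frees the caches.

```python
class MockInferenceSession:
    """Toy stand-in for InferenceSession: tracks position and per-server KV caches."""

    def __init__(self, num_servers, max_length):
        self.max_length = max_length
        self.position = 0
        # One KV cache per remote server in the route
        self.caches = [[] for _ in range(num_servers)]

    def step(self, tokens):
        # Stepping past the allocation set at session creation is an error
        if self.position + len(tokens) > self.max_length:
            raise ValueError("max_length exceeded")
        # Each "server" appends the new keys/values to its local cache
        for cache in self.caches:
            cache.extend(tokens)
        self.position += len(tokens)
        return tokens  # a real session returns transformed hidden states

    def close(self):
        # Free the server-side caches
        self.caches = [[] for _ in self.caches]


session = MockInferenceSession(num_servers=3, max_length=10)
session.step([1, 2, 3])   # process the prompt
session.step([4])         # one generated token
print(session.position)   # 4
```

The real class additionally transforms hidden states on each server and re-routes around failures; the mock only captures the cache-and-position bookkeeping.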

The session is typically obtained via model.inference_session(max_length=...) which returns a context manager.

Usage

Use this class when performing autoregressive text generation with Petals. Always use it as a context manager so that server-side KV caches are freed on exit. The max_length parameter fixes the maximum total sequence length (prompt tokens plus generated tokens); it is set when the session is created and cannot grow afterwards, so size it to cover the longest output you expect.
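
Since max_length must cover both the prompt and everything generated afterwards, a common sizing pattern is (a sketch; the token IDs below are hypothetical):

```python
prompt_ids = [101, 2054, 2003]  # hypothetical tokenized prompt
max_new_tokens = 20

# max_length must cover the prompt plus all tokens to be generated;
# allocating less makes the session run out of KV cache space mid-generation.
max_length = len(prompt_ids) + max_new_tokens
print(max_length)  # 23
```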

Code Reference

Source Location

  • Repository: petals
  • File: src/petals/client/inference_session.py (L220-414)

Signature

class InferenceSession:
    """An interface to a multi-step inference session for a sequence of remote transformer blocks"""

    def __init__(
        self,
        sequence_manager: RemoteSequenceManager,
        max_length: int,
    ):
        """
        Args:
            sequence_manager: RemoteSequenceManager managing server discovery and routing
            max_length: Maximum total sequence length for KV cache allocation
        """

    def step(
        self,
        inputs: torch.Tensor,
        prompts: Optional[torch.Tensor] = None,
        hypo_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run a single inference step through all remote transformer blocks.

        Args:
            inputs: Hidden state tensor [batch_size, seq_len, hidden_size]
            prompts: Optional prompt tuning embeddings
            hypo_ids: Optional hypothesis IDs for beam search reordering
        Returns:
            Output hidden states after all transformer blocks
        """

    def close(self, *exc_details):
        """Close all server sessions and free KV caches."""

    @property
    def position(self) -> int:
        """Current position in the sequence (number of tokens processed)."""

    @position.setter
    def position(self, start_from_position: int) -> None:
        """Set position for resuming from a specific point."""

    @property
    def last_token_id(self) -> Optional[torch.Tensor]:
        """Last generated token ID (used for continuation)."""

Import

from petals.client.inference_session import InferenceSession
# Or via model context manager:
# with model.inference_session(max_length=512) as session:

I/O Contract

Inputs

Name Type Required Description
sequence_manager RemoteSequenceManager Yes Manages server discovery and Dijkstra-based routing
max_length int Yes Maximum total sequence length for KV cache allocation on servers
inputs (step) torch.Tensor Yes Hidden state tensor [batch_size, seq_len, hidden_size]
prompts (step) Optional[torch.Tensor] No Prompt tuning embeddings for prefix injection
hypo_ids (step) Optional[torch.Tensor] No Hypothesis IDs for beam search cache reordering
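
The hypo_ids mechanism can be sketched with a toy per-server cache (illustrative only, not Petals internals): during beam search, row i of the reordered cache continues the hypothesis hypo_ids[i], so surviving beams keep the right history.

```python
# Toy per-server KV cache: one row of cached keys/values per hypothesis (beam)
kv_cache = [
    ["k0_a", "k0_b"],  # hypothesis 0
    ["k1_a", "k1_b"],  # hypothesis 1
    ["k2_a", "k2_b"],  # hypothesis 2
]

def reorder_cache(cache, hypo_ids):
    """Select cache rows so that row i continues hypothesis hypo_ids[i]."""
    return [list(cache[h]) for h in hypo_ids]

# After a beam-search step, beam 0 descends from old hypothesis 2,
# and beams 1 and 2 both descend from old hypothesis 1:
hypo_ids = [2, 1, 1]
kv_cache = reorder_cache(kv_cache, hypo_ids)
print([row[0] for row in kv_cache])  # ['k2_a', 'k1_a', 'k1_a']
```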

Outputs

Name Type Description
step() returns torch.Tensor Output hidden states after processing through all remote blocks [batch_size, seq_len, hidden_size]
position int Current sequence position (updated after each step)

Usage Examples

Basic Generation with Session

from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer

model_name = "petals-team/StableBeluga2"
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

# Use inference session for efficient autoregressive generation
with model.inference_session(max_length=100) as session:
    outputs = model.generate(
        inputs,
        session=session,
        max_new_tokens=20,
        do_sample=True,
        temperature=0.7,
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Manual Step-by-Step Generation

import torch
from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer

model_name = "petals-team/StableBeluga2"
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("Once upon a time", return_tensors="pt")["input_ids"]

with model.inference_session(max_length=50) as session:
    # Process the full prompt in a single step
    embs = model.transformer.word_embeddings(inputs)
    hidden = session.step(embs)

    # Generate tokens one at a time, reusing the server-side KV caches
    generated = []
    for _ in range(20):
        # Apply the final layer norm before projecting to logits
        last_hidden = model.transformer.ln_f(hidden[:, -1:, :])
        logits = model.lm_head(last_hidden)
        next_token = torch.argmax(logits, dim=-1)
        generated.append(next_token)
        embs = model.transformer.word_embeddings(next_token)
        hidden = session.step(embs)

print(tokenizer.decode(torch.cat(generated, dim=1)[0]))

Related Pages

Implements Principle

Requires Environment
