
Implementation:Bigscience workshop Petals RemoteGenerationMixin Generate With Session

From Leeroopedia


Knowledge Sources
Domains NLP, Dialogue, Text_Generation
Last Updated 2026-02-09 14:00 GMT

Overview

Concrete tool for multi-turn conversational text generation using persistent inference sessions with KV cache, as provided by Petals' RemoteGenerationMixin.

Description

This is a session-oriented usage of RemoteGenerationMixin.generate() in which the same InferenceSession is reused across multiple generation calls for multi-turn dialogue. The key differences from standard single-shot generation:

  • Explicit session: The session is created once via model.inference_session(max_length=...) and passed to each generate() call
  • Position tracking: The session tracks position across turns, so each new generate call only processes new tokens
  • KV cache persistence: Server-side KV caches persist between generate calls, storing attention states from all previous turns
  • Prompt tuning integration: If the model was trained with prompt tuning, the prompt embeddings are included in the session context

The underlying implementation uses the same RemoteGenerationMixin.generate() method as single-shot generation but with the session= parameter explicitly passed.
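The position-tracking behavior described above can be illustrated with a small stand-in class. This is purely illustrative (ToySession is not the Petals API): a real InferenceSession keeps KV caches on remote servers, while this sketch only models the position arithmetic that lets each turn pay for new tokens alone.

```python
# Illustrative stand-in for session position tracking (NOT the Petals API).
# A real InferenceSession stores server-side KV caches; here we model only
# the bookkeeping that makes each generate() call process new tokens alone.
class ToySession:
    def __init__(self, max_length):
        self.max_length = max_length
        self.position = 0  # number of tokens already cached across turns

    def step(self, new_tokens):
        """Advance the session by the tokens of one generate() call."""
        if self.position + len(new_tokens) > self.max_length:
            raise ValueError("conversation exceeds max_length")
        self.position += len(new_tokens)
        return new_tokens  # only the new tokens are computed this call

session = ToySession(max_length=10)
session.step([1, 2, 3])   # turn 1: three tokens enter the cache
session.step([4, 5])      # turn 2: only two new tokens are processed
print(session.position)   # 5
```

The same budget check is why exceeding max_length mid-conversation fails: the cache cannot grow past the size reserved at session creation.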

Usage

Use this implementation for building chatbots and interactive dialogue systems with prompt-tuned distributed BLOOM or Llama models. Set max_length large enough to accommodate the full conversation.
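A rough way to size max_length is to budget tokens per turn. The numbers below are assumptions for illustration, not Petals defaults:

```python
# Back-of-envelope sizing for max_length (all numbers are assumptions):
# every turn consumes both the user's prompt tokens and the generated
# reply tokens, and all turns share the one session budget.
expected_turns = 8
avg_prompt_tokens = 60    # assumed average length of a user turn
max_new_tokens = 200      # per-turn reply budget (as passed to generate())
safety_margin = 128       # headroom for unusually long turns

max_length = expected_turns * (avg_prompt_tokens + max_new_tokens) + safety_margin
print(max_length)  # 2208
```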

Code Reference

Source Location

  • Repository: petals
  • File: src/petals/client/remote_generation.py (L83-149, RemoteGenerationMixin.generate)
  • File: src/petals/client/remote_sequential.py (L85-95, RemoteSequential.inference_session)
  • File: src/petals/client/inference_session.py (L220-362, InferenceSession)

Signature

# Session creation (context manager)
model.inference_session(max_length: int) -> ContextManager[InferenceSession]

# Generation with session
model.generate(
    inputs: Optional[torch.Tensor] = None,
    *args,
    session: Optional[InferenceSession] = None,  # Pass the persistent session
    max_new_tokens: int = ...,
    do_sample: bool = True,
    temperature: float = ...,
    top_p: float = ...,
    repetition_penalty: float = ...,
    **kwargs,
) -> torch.LongTensor

Import

# Used via the distributed model classes:
from petals.models.bloom.model import DistributedBloomForCausalLM
model = DistributedBloomForCausalLM.from_pretrained(model_name)
# model.inference_session() and model.generate() are inherited from mixins

I/O Contract

Inputs

  • max_length (int, required): Maximum total conversation length in tokens (all turns combined)
  • inputs (torch.Tensor, required per turn): New input token IDs for the current dialogue turn
  • session (InferenceSession, required): Persistent session reused across turns
  • max_new_tokens (int, optional): Tokens to generate for this turn's response
  • do_sample (bool, optional): Use sampling for more natural dialogue
  • temperature (float, optional): Sampling temperature (0.7-0.9 typical for dialogue)
  • repetition_penalty (float, optional): Penalty for repeating tokens (1.2 typical)

Outputs

  • generated_ids (torch.LongTensor): This call's input token IDs followed by the newly generated response tokens
  • session state (InferenceSession): Updated in place with the new position and KV cache state
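Because the returned tensor begins with the input IDs passed in the current call, the reply alone can be recovered by slicing at the input length. A sketch with stand-in token lists (no model or network needed):

```python
# Sketch of isolating one turn's reply from generate() output: the returned
# sequence starts with the input ids passed in this call, followed by the
# newly generated tokens. Stand-in lists replace real tensors here.
input_ids = [101, 102, 103]           # this call's prompt tokens
outputs = [101, 102, 103, 7, 8, 9]    # stand-in generate() result

new_token_ids = outputs[len(input_ids):]  # just this turn's reply
print(new_token_ids)  # [7, 8, 9]
```

With real tensors the equivalent slice is outputs[0, inputs.shape[1]:].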

Usage Examples

Multi-Turn Chatbot

import torch
from transformers import AutoTokenizer
from petals.models.bloom.model import DistributedBloomForCausalLM

model_name = "bigscience/bloom-7b1-petals"
model = DistributedBloomForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

with model.inference_session(max_length=2048) as session:
    # Turn 1
    prompt = "Human: What is machine learning?\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
    outputs = model.generate(
        inputs,
        session=session,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(response)

    # Turn 2 (the session preserves the KV cache from turn 1,
    # so pass only the NEW tokens, not the whole conversation)
    followup = "\nHuman: Can you give an example?\nAssistant:"
    inputs2 = tokenizer(followup, return_tensors="pt")["input_ids"]
    outputs2 = model.generate(
        inputs2,
        session=session,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.8,
    )
    response2 = tokenizer.decode(outputs2[0], skip_special_tokens=True)
    print(response2)

Interactive REPL Loop

with model.inference_session(max_length=4096) as session:
    while True:
        user_input = input("You: ")
        if user_input.lower() == "quit":
            break

        # Pass only the new turn; the session already caches earlier context
        turn = f"Human: {user_input}\nAssistant:"
        inputs = tokenizer(turn, return_tensors="pt")["input_ids"]

        outputs = model.generate(
            inputs,
            session=session,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            repetition_penalty=1.2,
        )

        # Outputs start with this call's input ids; slice them off
        # to keep just the assistant's newly generated tokens
        response = tokenizer.decode(
            outputs[0, inputs.shape[1]:], skip_special_tokens=True
        )
        print(f"Assistant: {response}")

Related Pages

Implements Principle

Requires Environment
