Implementation:Bigscience workshop Petals DistributedBloomForCausalLM From Pretrained

From Leeroopedia


Knowledge Sources
Domains: NLP, Dialogue, Distributed_Computing
Last Updated: 2026-02-09 14:00 GMT

Overview

A concrete tool, provided by Petals, for loading a distributed BLOOM causal language model with prompt-tuning support for dialogue generation.

Description

DistributedBloomForCausalLM extends HuggingFace's BloomForCausalLM with distributed inference and prompt-tuning capabilities via multiple mixins:

  • FromPretrainedMixin: Enforces efficient loading defaults
  • RemoteGenerationMixin: Bridges HuggingFace generate() with distributed sessions
  • PTuneMixin (via DistributedBloomModel): Adds prompt-tuning embeddings

The __init__ method creates:

  • transformer: DistributedBloomModel with RemoteSequential .h layers
  • lm_head: LMHead for token prediction (loaded locally)
  • prompt_embeddings: Trainable nn.Embedding if tuning_mode is configured

The model implements prepare_inputs_for_generation, which integrates with HuggingFace's generation pipeline and the distributed session cache system.
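
A minimal sketch of what this composition looks like after loading. It is not taken from the Petals sources referenced below; the model name mirrors the usage example later on, and the exact attribute path for the prompt embeddings is an assumption.

from transformers import AutoConfig
from petals.models.bloom.model import DistributedBloomForCausalLM

config = AutoConfig.from_pretrained("bigscience/bloom-7b1-petals")
config.tuning_mode = "ptune"   # enables the trainable prompt embeddings
config.pre_seq_len = 16

model = DistributedBloomForCausalLM.from_pretrained(
    "bigscience/bloom-7b1-petals", config=config
)

print(type(model.transformer))    # DistributedBloomModel
print(type(model.transformer.h))  # RemoteSequential: blocks served by remote peers
print(type(model.lm_head))        # LMHead, held and run locally
# Prompt embeddings exist only when tuning_mode is configured; their exact
# location (on the transformer) is an assumption in this sketch.
print(getattr(model.transformer, "prompt_embeddings", None))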

Usage

Use this class for chatbot development with large BLOOM models. It supports both prompt-tuned training on dialogue data and interactive generation with multi-turn session management.

Code Reference

Source Location

  • Repository: petals
  • File: src/petals/models/bloom/model.py (L111-181)
  • File: src/petals/client/from_pretrained.py (L17-39)
  • File: src/petals/client/ptune.py (L24-41)

Signature

class DistributedBloomForCausalLM(
    FromPretrainedMixin,
    RemoteGenerationMixin,
    BloomForCausalLM,
):
    _supports_cache_class = True
    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        """
        Initialize distributed BLOOM for causal LM.

        Creates:
        - transformer: DistributedBloomModel with RemoteSequential .h layers
        - lm_head: LMHead for token prediction
        - prompt_embeddings: nn.Embedding (if tuning_mode configured)

        Args:
            config: DistributedBloomConfig with optional tuning_mode, pre_seq_len
        """

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ) -> dict:
        """Prepare inputs for HuggingFace generation pipeline."""

Import

from petals.models.bloom.model import DistributedBloomForCausalLM

I/O Contract

Inputs

Name               | Type | Required | Description
model_name_or_path | str  | Yes      | HuggingFace BLOOM model name (e.g. "bigscience/bloom-7b1-petals")
tuning_mode        | str  | No       | "ptune" or "deep_ptune" (set in config)
pre_seq_len        | int  | No       | Number of trainable prefix prompt tokens

Outputs

Name              | Type                              | Description
model             | DistributedBloomForCausalLM       | Distributed BLOOM model with RemoteSequential layers, LM head, and optional prompt embeddings
forward() returns | CausalLMOutputWithCrossAttentions | Contains loss (if labels provided) and logits of shape [batch, seq_len, vocab_size]
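
As a quick orientation for this output contract, here is a hedged sketch of a single forward call; it reuses the model and tokenizer loaded in the example below and a placeholder input string.

# Sketch: forward() with labels returns both loss and logits.
batch = tokenizer("Human: Hi!\nAssistant: Hello!", return_tensors="pt")
outputs = model(input_ids=batch["input_ids"], labels=batch["input_ids"])
print(outputs.loss)          # scalar loss, present because labels were given
print(outputs.logits.shape)  # [batch, seq_len, vocab_size]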

Usage Examples

Loading for Chatbot Training

import torch
from transformers import AutoConfig, AutoTokenizer
from petals.models.bloom.model import DistributedBloomForCausalLM

model_name = "bigscience/bloom-7b1-petals"

# Configure with prompt tuning for dialogue
config = AutoConfig.from_pretrained(model_name)
config.tuning_mode = "ptune"
config.pre_seq_len = 16

model = DistributedBloomForCausalLM.from_pretrained(model_name, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Training: only prompt_embeddings are trainable
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-3,
)
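
The optimization step itself is not shown in the Petals sources referenced above; a minimal sketch of one step, using a placeholder dialogue string, might look like this:

# Hypothetical single training step; the dialogue text is a placeholder.
dialogue = "Human: Hi!\nAssistant: Hello, how can I help?"
batch = tokenizer(dialogue, return_tensors="pt")

outputs = model(input_ids=batch["input_ids"], labels=batch["input_ids"])
outputs.loss.backward()   # gradients reach only the trainable prompt embeddings
optimizer.step()
optimizer.zero_grad()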

Interactive Generation

# After training, use for multi-turn dialogue
with model.inference_session(max_length=512) as session:
    user_input = "Hello! What can you help me with?"
    prompt = f"Human: {user_input}\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]

    outputs = model.generate(
        inputs,
        session=session,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
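
    # (Hedged continuation, not taken from the Petals sources above.) A second
    # turn can reuse the still-open session: only the new tokens need to be
    # sent, since earlier turns remain in the distributed attention cache.
    next_prompt = "\nHuman: Can you give an example?\nAssistant:"
    next_inputs = tokenizer(next_prompt, return_tensors="pt")["input_ids"]
    next_outputs = model.generate(
        next_inputs,
        session=session,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
    )
    follow_up = tokenizer.decode(next_outputs[0], skip_special_tokens=True)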

Related Pages

Implements Principle

Requires Environment
