Implementation: bigscience-workshop/petals DistributedBloomForCausalLM.from_pretrained
| Knowledge Sources | |
|---|---|
| Domains | NLP, Dialogue, Distributed_Computing |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
A concrete class, provided by the Petals library, for loading a distributed BLOOM causal language model with optional prompt tuning for dialogue generation.
Description
DistributedBloomForCausalLM extends HuggingFace's BloomForCausalLM with distributed and prompt tuning capabilities via multiple mixins:
- FromPretrainedMixin: forces efficient loading defaults
- RemoteGenerationMixin: bridges HuggingFace's generate() with distributed inference sessions
- PTuneMixin (inherited via DistributedBloomModel): adds trainable prompt-tuning embeddings
The __init__ method creates:
- transformer: DistributedBloomModel with RemoteSequential ``.h`` layers
- lm_head: LMHead for token prediction (loaded locally)
- prompt_embeddings: Trainable nn.Embedding if tuning_mode is configured
The model implements prepare_inputs_for_generation(), which integrates with HuggingFace's generation pipeline and the distributed session cache system.
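To make the prompt-tuning mechanism concrete, here is a minimal, self-contained sketch of the idea behind PTuneMixin: a small trainable embedding of `pre_seq_len` prefix vectors is prepended to the (frozen) token embeddings before they enter the transformer. The class and method names below (`ToyPTune`, `embed_with_prompts`) are hypothetical illustrations, not Petals APIs.

```python
import torch
import torch.nn as nn

class ToyPTune(nn.Module):
    """Hypothetical sketch of prompt tuning: trainable prefix embeddings
    prepended to frozen input token embeddings (not the Petals implementation)."""

    def __init__(self, vocab_size=100, hidden=8, pre_seq_len=4):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden)
        self.word_embeddings.requires_grad_(False)  # backbone stays frozen
        self.prompt_embeddings = nn.Embedding(pre_seq_len, hidden)  # trainable
        self.pre_seq_len = pre_seq_len

    def embed_with_prompts(self, input_ids):
        batch = input_ids.shape[0]
        # Same prefix indices for every example in the batch
        prompt_ids = torch.arange(self.pre_seq_len).repeat(batch, 1)
        prompts = self.prompt_embeddings(prompt_ids)       # [batch, pre_seq_len, hidden]
        tokens = self.word_embeddings(input_ids)           # [batch, seq_len, hidden]
        return torch.cat([prompts, tokens], dim=1)         # [batch, pre_seq_len + seq_len, hidden]

m = ToyPTune()
out = m.embed_with_prompts(torch.tensor([[1, 2, 3]]))
print(out.shape)  # torch.Size([1, 7, 8])
```

Because only `prompt_embeddings` requires gradients, an optimizer built over `p for p in parameters() if p.requires_grad` trains just the prefix, which is what makes prompt tuning cheap on a distributed backbone.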
Usage
Use this class for chatbot development with large BLOOM models. It supports both prompt-tuned training on dialogue data and interactive generation with multi-turn session management.
Code Reference
Source Location
- Repository: petals
- File: src/petals/models/bloom/model.py (L111-181)
- File: src/petals/client/from_pretrained.py (L17-39)
- File: src/petals/client/ptune.py (L24-41)
Signature
class DistributedBloomForCausalLM(
    FromPretrainedMixin,
    RemoteGenerationMixin,
    BloomForCausalLM,
):
    _supports_cache_class = True
    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        """
        Initialize distributed BLOOM for causal LM.

        Creates:
        - transformer: DistributedBloomModel with RemoteSequential .h layers
        - lm_head: LMHead for token prediction
        - prompt_embeddings: nn.Embedding (if tuning_mode configured)

        Args:
            config: DistributedBloomConfig with optional tuning_mode, pre_seq_len
        """

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ) -> dict:
        """Prepare inputs for HuggingFace generation pipeline."""
Import
from petals.models.bloom.model import DistributedBloomForCausalLM
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_name_or_path | str | Yes | HuggingFace BLOOM model name (e.g. "bigscience/bloom-7b1-petals") |
| tuning_mode | str | No | "ptune" or "deep_ptune" (set in config) |
| pre_seq_len | int | No | Number of trainable prefix prompt tokens |
Outputs
| Name | Type | Description |
|---|---|---|
| model | DistributedBloomForCausalLM | Distributed BLOOM model with RemoteSequential layers, LM head, and optional prompt embeddings |
| forward() returns | CausalLMOutputWithCrossAttentions | Contains loss (if labels provided), logits [batch, seq_len, vocab_size] |
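The "loss (if labels provided)" entry follows the standard causal-LM convention: logits at position t are scored against the label at position t+1. A minimal sketch of that shift-and-cross-entropy computation, using random tensors in place of real model outputs (the shapes and the shifting match the convention, not Petals-specific code):

```python
import torch
import torch.nn.functional as F

# Standard causal-LM loss: predict token t+1 from positions <= t.
torch.manual_seed(0)
batch, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab_size)          # stand-in for model output
labels = torch.randint(0, vocab_size, (batch, seq_len))   # stand-in for input_ids

shift_logits = logits[:, :-1, :].contiguous()  # drop prediction for last position
shift_labels = labels[:, 1:].contiguous()      # drop label for first position
loss = F.cross_entropy(
    shift_logits.view(-1, vocab_size),
    shift_labels.view(-1),
)
print(loss.item())
```

Passing `labels=input_ids` to forward() is the usual way to get this loss during prompt-tuned training.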
Usage Examples
Loading for Chatbot Training
import torch
from transformers import AutoConfig, AutoTokenizer
from petals.models.bloom.model import DistributedBloomForCausalLM

model_name = "bigscience/bloom-7b1-petals"

# Configure with prompt tuning for dialogue
config = AutoConfig.from_pretrained(model_name)
config.tuning_mode = "ptune"
config.pre_seq_len = 16

model = DistributedBloomForCausalLM.from_pretrained(model_name, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Training: only prompt_embeddings are trainable
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-3,
)
Interactive Generation
# After training, use for multi-turn dialogue
with model.inference_session(max_length=512) as session:
    user_input = "Hello! What can you help me with?"
    prompt = f"Human: {user_input}\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
    outputs = model.generate(
        inputs,
        session=session,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
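For multi-turn dialogue, the conversation history must be carried in the prompt text itself. A small, model-independent helper for assembling such a prompt is sketched below; the "Human:"/"Assistant:" format follows the example above and is a prompting convention, not something Petals mandates, and `build_prompt` is a hypothetical name.

```python
# Hypothetical helper: assemble a multi-turn prompt from (user, assistant)
# turn pairs plus the newest user message.
def build_prompt(history, user_input):
    turns = [f"Human: {h}\nAssistant: {a}" for h, a in history]
    turns.append(f"Human: {user_input}\nAssistant:")
    return "\n".join(turns)

history = [("Hi!", "Hello! How can I help?")]
prompt = build_prompt(history, "Tell me a joke.")
print(prompt)
```

Inside an open `inference_session`, each new turn's prompt would be tokenized and passed to generate() as in the example above, so the session's distributed cache is reused across turns.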