# Implementation: bigscience-workshop/petals DistributedLlamaForSequenceClassification.from_pretrained
| Knowledge Sources | |
|---|---|
| Domains | NLP, Classification, Distributed Computing |
| Last Updated | 2026-02-09 14:00 GMT |
## Overview
Concrete tool for loading a distributed Llama model configured for sequence classification with prompt tuning support, provided by Petals.
## Description
DistributedLlamaForSequenceClassification extends HuggingFace's LlamaForSequenceClassification with Petals' distributed capabilities via multiple mixins:
- FromPretrainedMixin: forces low_cpu_mem_usage=True and torch_dtype="auto" during loading
- PTuneMixin: adds prompt-tuning embedding layers
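The forced-defaults behavior of FromPretrainedMixin can be illustrated with a minimal, dependency-free sketch. All names here (DummyBase, the `*Sketch` classes) are hypothetical stand-ins, not the actual Petals source; the point is the cooperative `super().from_pretrained` call up the MRO:

```python
class DummyBase:
    """Stand-in for the HuggingFace base class; records the kwargs it receives."""

    @classmethod
    def from_pretrained(cls, name, **kwargs):
        obj = cls.__new__(cls)
        obj.load_kwargs = kwargs  # remember what the base loader was given
        return obj


class FromPretrainedMixinSketch:
    """Sketch of a mixin that forces loading defaults before delegating."""

    @classmethod
    def from_pretrained(cls, name, **kwargs):
        # Always override low_cpu_mem_usage; default torch_dtype only if unset.
        kwargs["low_cpu_mem_usage"] = True
        kwargs.setdefault("torch_dtype", "auto")
        return super().from_pretrained(name, **kwargs)


class DistributedSketch(FromPretrainedMixinSketch, DummyBase):
    pass


m = DistributedSketch.from_pretrained("some/model", torch_dtype="float16")
# m.load_kwargs now contains low_cpu_mem_usage=True and the caller's torch_dtype
```

Because the mixin precedes the base class in the MRO, its `from_pretrained` runs first and the base loader only ever sees the adjusted keyword arguments.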
The __init__ method:
- Calls the parent LlamaForSequenceClassification.__init__, which creates the base Llama model and the classification head
- Relies on DistributedLlamaModel (used internally), which replaces the standard transformer layers with RemoteSequential
- Calls self.init_prompts(config) to set up prompt-tuning embeddings if configured
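The initialization order above can be sketched in pure Python. The classes and tuples below are placeholders (the real implementation builds torch modules and a distributed backbone); the sketch only shows the call sequence and the conditional prompt setup:

```python
class BaseClassifierSketch:
    """Stand-in for the parent __init__: base model plus classification head."""

    def __init__(self, config):
        self.model = "DistributedLlamaModel"  # placeholder for the distributed base model
        self.score = ("Linear", config["hidden_size"], config["num_labels"])


class PTuneSketch:
    """Stand-in for the prompt-tuning mixin's init_prompts."""

    def init_prompts(self, config):
        # Prompt embeddings are created only when a tuning mode is configured.
        if config.get("tuning_mode"):
            self.prompt_embeddings = ("Embedding", config["pre_seq_len"], config["hidden_size"])


class DistributedClassifierSketch(PTuneSketch, BaseClassifierSketch):
    def __init__(self, config):
        super().__init__(config)   # step 1: base model + classification head
        self.init_prompts(config)  # step 3: prompt embeddings, if configured


cfg = {"hidden_size": 8192, "num_labels": 2, "tuning_mode": "ptune", "pre_seq_len": 16}
m = DistributedClassifierSketch(cfg)
```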
## Usage
Use this class for text classification tasks (SST-2, MNLI, etc.) with large Llama-family models distributed across the Petals network. Configure num_labels, tuning_mode, and pre_seq_len in the model config before loading.
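Conceptually, prompt tuning with pre_seq_len trainable prefix tokens means prepending that many learned vectors to the token embeddings before they enter the transformer. A dependency-free sketch of this idea (plain lists standing in for tensors; not the Petals internals):

```python
def prepend_prompts(prompt_embeds, token_embeds):
    """Prepend trainable prefix vectors to a sequence of token embeddings.

    prompt_embeds: [pre_seq_len, hidden], token_embeds: [seq_len, hidden].
    Returns a sequence of length pre_seq_len + seq_len.
    """
    return prompt_embeds + token_embeds


hidden, pre_seq_len, seq_len = 4, 2, 3
prompts = [[0.0] * hidden for _ in range(pre_seq_len)]  # trainable prefix (illustrative zeros)
tokens = [[1.0] * hidden for _ in range(seq_len)]       # frozen token embeddings
combined = prepend_prompts(prompts, tokens)
print(len(combined))  # 5
```

During fine-tuning, only the prefix vectors (and the classification head) receive gradients, while the distributed backbone stays frozen.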
## Code Reference
### Source Location
- Repository: petals
- File: src/petals/models/llama/model.py (L156-174)
- File: src/petals/client/from_pretrained.py (L17-39)
- File: src/petals/client/ptune.py (L24-41)
### Signature

```python
class DistributedLlamaForSequenceClassification(
    FromPretrainedMixin,
    LlamaForSequenceClassification,
):
    config_class = DistributedLlamaConfig

    def __init__(self, config):
        """
        Initialize distributed Llama for classification.

        Creates:
            - model: DistributedLlamaModel with RemoteSequential layers
            - score: nn.Linear(config.hidden_size, config.num_labels)
            - prompt_embeddings: nn.Embedding (if tuning_mode set)
            - intermediate_prompt_embeddings: nn.Embedding (if deep_ptune)

        Args:
            config: DistributedLlamaConfig with num_labels, tuning_mode, pre_seq_len
        """

    @property
    def transformer(self) -> DistributedLlamaModel:
        """Access the underlying distributed model."""
```
### Import

```python
from petals.models.llama.model import DistributedLlamaForSequenceClassification
```
## I/O Contract
### Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_name_or_path | str | Yes | HuggingFace model name (e.g. "enoch/llama-65b-hf") |
| num_labels | int | Yes | Number of classification labels (set via config) |
| tuning_mode | str | No | "ptune" or "deep_ptune" for prompt tuning |
| pre_seq_len | int | No | Number of trainable prefix prompt tokens |
### Outputs
| Name | Type | Description |
|---|---|---|
| model | DistributedLlamaForSequenceClassification | Model with RemoteSequential layers, classification head, and prompt embeddings |
| forward() returns | SequenceClassifierOutputWithPast | Contains loss (if labels provided), logits [batch, num_labels] |
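To make the logits contract concrete, here is a dependency-free sketch of turning the [batch, num_labels] logits from forward() into class probabilities and predicted labels (the logit values are made up for illustration; in practice you would apply torch.softmax/argmax to output.logits):

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    mx = max(row)
    exps = [math.exp(x - mx) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


logits = [[2.0, -1.0], [0.5, 1.5]]  # shape [batch=2, num_labels=2], illustrative values
probs = [softmax(row) for row in logits]
preds = [row.index(max(row)) for row in logits]
print(preds)  # [0, 1]
```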
## Usage Examples
### Loading for SST-2 Binary Classification
```python
from transformers import AutoConfig, AutoTokenizer

from petals.models.llama.model import DistributedLlamaForSequenceClassification

model_name = "enoch/llama-65b-hf"

# Configure for binary classification with prompt tuning
config = AutoConfig.from_pretrained(model_name)
config.num_labels = 2
config.tuning_mode = "ptune"
config.pre_seq_len = 16

# Load distributed classification model
model = DistributedLlamaForSequenceClassification.from_pretrained(
    model_name, config=config
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Check trainable parameters
trainable = {n: p for n, p in model.named_parameters() if p.requires_grad}
for name in trainable:
    print(f"Trainable: {name}")
# Output:
# Trainable: prompt_embeddings.weight
# Trainable: score.weight
```
## Related Pages

- Implements Principle
- Requires Environment