
Implementation:Bigscience workshop Petals DistributedLlamaForSequenceClassification From Pretrained

From Leeroopedia


Knowledge Sources
Domains NLP, Classification, Distributed_Computing
Last Updated 2026-02-09 14:00 GMT

Overview

A concrete tool for loading a distributed Llama model configured for sequence classification, with prompt tuning support, provided by Petals.

Description

DistributedLlamaForSequenceClassification extends HuggingFace's LlamaForSequenceClassification with Petals' distributed capabilities via multiple mixins:

  • FromPretrainedMixin: Forces low_cpu_mem_usage=True and torch_dtype="auto"
  • PTuneMixin: Adds prompt tuning embedding layers
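The forced-kwargs behavior of FromPretrainedMixin can be illustrated in isolation. The following is a minimal, hypothetical sketch, not Petals' actual implementation: the names BaseModel, FromPretrainedMixinSketch, and DistributedModelSketch are invented for this example.

```python
# Hypothetical sketch of a FromPretrainedMixin-style override that forces
# loading kwargs before delegating to the parent class's from_pretrained().
class BaseModel:
    """Stand-in for a HuggingFace PreTrainedModel (invented for this sketch)."""
    @classmethod
    def from_pretrained(cls, model_name, **kwargs):
        # A real implementation would download and load weights here;
        # we just echo the effective kwargs so they can be inspected.
        return {"model": model_name, **kwargs}

class FromPretrainedMixinSketch:
    @classmethod
    def from_pretrained(cls, model_name, **kwargs):
        # Force memory-efficient loading, overriding caller-supplied values
        kwargs["low_cpu_mem_usage"] = True
        kwargs["torch_dtype"] = kwargs.get("torch_dtype") or "auto"
        return super().from_pretrained(model_name, **kwargs)

class DistributedModelSketch(FromPretrainedMixinSketch, BaseModel):
    pass

loaded = DistributedModelSketch.from_pretrained("demo/model", low_cpu_mem_usage=False)
print(loaded["low_cpu_mem_usage"])  # True: the mixin overrode the caller's value
print(loaded["torch_dtype"])        # auto
```

Because the mixin precedes the base class in the method resolution order, its from_pretrained runs first and can rewrite kwargs before the real loader sees them.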

The __init__ method:

  1. Calls the parent LlamaForSequenceClassification.__init__, which creates the base Llama model and the classification head
  2. The base model is a DistributedLlamaModel, which replaces the standard transformer layers with RemoteSequential
  3. Calls self.init_prompts(config) to set up prompt tuning embeddings, if configured
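The three steps above can be sketched with stand-in classes. All names here are invented for illustration; the real classes live in the Petals and transformers source files listed under Code Reference.

```python
# Hypothetical sketch of the __init__ flow: the parent init builds the base
# model and classification head, then init_prompts() adds prompt-tuning state.
class LlamaForSequenceClassificationSketch:
    """Stand-in for HF's LlamaForSequenceClassification (invented)."""
    def __init__(self, config):
        # Steps 1-2: in Petals the base model is a DistributedLlamaModel
        # whose transformer layers are a RemoteSequential over the swarm.
        self.model = "DistributedLlamaModel(RemoteSequential)"
        self.score = ("Linear", config["hidden_size"], config["num_labels"])

class DistributedClassifierSketch(LlamaForSequenceClassificationSketch):
    def __init__(self, config):
        super().__init__(config)   # step 1: base model + classification head
        self.init_prompts(config)  # step 3: prompt tuning, if configured

    def init_prompts(self, config):
        if config.get("tuning_mode"):
            # Trainable prefix of pre_seq_len prompt tokens
            self.prompt_embeddings = ("Embedding", config["pre_seq_len"],
                                      config["hidden_size"])

config = {"hidden_size": 4096, "num_labels": 2,
          "tuning_mode": "ptune", "pre_seq_len": 16}
model = DistributedClassifierSketch(config)
print(model.score)              # ('Linear', 4096, 2)
print(model.prompt_embeddings)  # ('Embedding', 16, 4096)
```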

Usage

Use this class for text classification tasks (SST-2, MNLI, etc.) with large Llama-family models distributed across the Petals network. Configure num_labels, tuning_mode, and pre_seq_len in the model config before loading.

Code Reference

Source Location

  • Repository: petals
  • File: src/petals/models/llama/model.py (L156-174)
  • File: src/petals/client/from_pretrained.py (L17-39)
  • File: src/petals/client/ptune.py (L24-41)

Signature

class DistributedLlamaForSequenceClassification(
    FromPretrainedMixin,
    LlamaForSequenceClassification,
):
    config_class = DistributedLlamaConfig

    def __init__(self, config):
        """
        Initialize distributed Llama for classification.

        Creates:
        - model: DistributedLlamaModel with RemoteSequential layers
        - score: nn.Linear(config.hidden_size, config.num_labels)
        - prompt_embeddings: nn.Embedding (if tuning_mode set)
        - intermediate_prompt_embeddings: nn.Embedding (if deep_ptune)

        Args:
            config: DistributedLlamaConfig with num_labels, tuning_mode, pre_seq_len
        """

    @property
    def transformer(self) -> DistributedLlamaModel:
        """Access the underlying distributed model."""

Import

from petals.models.llama.model import DistributedLlamaForSequenceClassification

I/O Contract

Inputs

  • model_name_or_path (str, required): HuggingFace model name (e.g. "enoch/llama-65b-hf")
  • num_labels (int, required): Number of classification labels (set via config)
  • tuning_mode (str, optional): "ptune" or "deep_ptune" for prompt tuning
  • pre_seq_len (int, optional): Number of trainable prefix prompt tokens

Outputs

  • model (DistributedLlamaForSequenceClassification): Model with RemoteSequential layers, classification head, and prompt embeddings
  • forward() returns (SequenceClassifierOutputWithPast): Contains loss (if labels are provided) and logits of shape [batch, num_labels]
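To illustrate the output contract, here is a small sketch of turning the logits field (shape [batch, num_labels]) into class probabilities and predicted labels. The logits values are made up for the example; a real run would read them from the SequenceClassifierOutputWithPast returned by forward().

```python
import math

# Made-up logits for a batch of 2 examples with num_labels = 2,
# standing in for SequenceClassifierOutputWithPast.logits.
logits = [[1.2, -0.3], [-0.8, 2.1]]

def softmax(row):
    # Subtract the row max for numerical stability before exponentiating
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

probs = [softmax(row) for row in logits]           # per-class probabilities
preds = [row.index(max(row)) for row in logits]    # argmax per example
print(preds)  # [0, 1]
```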

Usage Examples

Loading for SST-2 Binary Classification

from transformers import AutoConfig, AutoTokenizer
from petals.models.llama.model import DistributedLlamaForSequenceClassification

model_name = "enoch/llama-65b-hf"

# Configure for binary classification with prompt tuning
config = AutoConfig.from_pretrained(model_name)
config.num_labels = 2
config.tuning_mode = "ptune"
config.pre_seq_len = 16

# Load distributed classification model
model = DistributedLlamaForSequenceClassification.from_pretrained(
    model_name, config=config
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Check trainable parameters
trainable = {n: p for n, p in model.named_parameters() if p.requires_grad}
for name in trainable:
    print(f"Trainable: {name}")
# Output:
# Trainable: prompt_embeddings.weight
# Trainable: score.weight

Related Pages

Implements Principle

Requires Environment
