
Implementation:Bigscience workshop Petals PTuneMixin

From Leeroopedia


Domains: NLP, Parameter_Efficient_Finetuning
Last Updated: 2026-02-09 14:00 GMT

Overview

Concrete tool for initializing and managing trainable prompt embeddings in distributed Petals models for parameter-efficient fine-tuning.

Description

PTuneMixin is a mixin class that adds prompt tuning capabilities to distributed model classes. It provides two methods:

  • init_prompts(config): Initializes prompt embedding layers based on the model's PTuneConfig settings. Creates prompt_embeddings (nn.Embedding) and optionally intermediate_prompt_embeddings for deep prompt tuning.
  • get_prompt(batch_size): Returns the prompt tensor to prepend to input hidden states, properly expanded for the batch.
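The following minimal PyTorch sketch shows the behavior of the two methods described above. It is not Petals' source code. The class ToyPromptModel and the SimpleNamespace config are hypothetical, and the tensor shapes follow the ones documented in the Signature section of this page. The last lines prepend the prompts to stand-in hidden states with torch.cat along the sequence dimension.

```python
from types import SimpleNamespace
from typing import Optional, Tuple

import torch
import torch.nn as nn


class ToyPromptModel(nn.Module):
    """Hypothetical stand-in for a distributed model with prompt tuning.

    Shapes follow this page's documentation; this is not Petals' implementation.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.init_prompts(config)

    def init_prompts(self, config) -> None:
        # Prompts are created only when a mode is set and there is at least one prefix token.
        if config.tuning_mode and config.pre_seq_len > 0:
            self.pre_seq_len = config.pre_seq_len
            # Ids 0..pre_seq_len-1 index the trainable prompt vectors.
            self.register_buffer("prefix_tokens", torch.arange(config.pre_seq_len), persistent=False)
            self.prompt_embeddings = nn.Embedding(config.pre_seq_len, config.hidden_size)
            if config.tuning_mode == "deep_ptune":
                # Extra prompts for the layers, using the shape documented on this page.
                self.intermediate_prompt_embeddings = nn.Embedding(
                    config.pre_seq_len * config.num_hidden_layers, config.hidden_size
                )

    def get_prompt(self, batch_size: int) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        if not hasattr(self, "prompt_embeddings"):
            return None, None
        # [pre_seq_len] -> [batch_size, pre_seq_len]
        ids = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prompts = self.prompt_embeddings(ids)  # [batch_size, pre_seq_len, hidden_size]
        intermediate = None
        if hasattr(self, "intermediate_prompt_embeddings"):
            n = self.intermediate_prompt_embeddings.num_embeddings
            deep_ids = torch.arange(n, device=ids.device).unsqueeze(0).expand(batch_size, -1)
            intermediate = self.intermediate_prompt_embeddings(deep_ids)
        return prompts, intermediate


config = SimpleNamespace(hidden_size=8, num_hidden_layers=3, pre_seq_len=4, tuning_mode="deep_ptune")
model = ToyPromptModel(config)
prompts, deep = model.get_prompt(batch_size=2)
hidden = torch.zeros(2, 5, 8)  # stand-in input hidden states: [batch, seq_len, hidden]
with_prompt = torch.cat([prompts, hidden], dim=1)  # prepend along the sequence dimension
print(prompts.shape, deep.shape, with_prompt.shape)
# torch.Size([2, 4, 8]) torch.Size([2, 12, 8]) torch.Size([2, 9, 8])
```

Every row of the batch looks up the same prefix ids. As a result, all sequences in a batch share the same prompt vectors, and gradients from every example flow into the single shared embedding table.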

The mixin reads configuration from PTuneConfig, a dataclass with:

  • pre_seq_len: Number of prefix prompt tokens (default 0, meaning no prompt tuning)
  • tuning_mode: None (default, no prompt tuning), "ptune" (shallow: prompts are prepended to the input only), or "deep_ptune" (deep: prompts for all layers)
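The switch logic implied by these two fields can be sketched in plain Python. The dataclass below mirrors PTuneConfig's fields and defaults. The rule that prompt tuning needs both a mode and pre_seq_len > 0 comes from the init_prompts docstring on this page. The class name PTuneConfigSketch, the helpers prompt_tuning_enabled and is_deep, and the validation in __post_init__ are all hypothetical additions.

```python
import dataclasses
from typing import Optional

# Modes listed on this page; None disables prompt tuning.
VALID_MODES = (None, "ptune", "deep_ptune")


@dataclasses.dataclass
class PTuneConfigSketch:
    pre_seq_len: int = 0
    tuning_mode: Optional[str] = None

    def __post_init__(self) -> None:
        # Hypothetical validation, not taken from Petals.
        if self.tuning_mode not in VALID_MODES:
            raise ValueError(f"unsupported tuning_mode: {self.tuning_mode!r}")
        if self.pre_seq_len < 0:
            raise ValueError("pre_seq_len must be non-negative")


def prompt_tuning_enabled(cfg: PTuneConfigSketch) -> bool:
    # Prompts exist only when a mode is chosen and there is at least one prefix token.
    return cfg.tuning_mode is not None and cfg.pre_seq_len > 0


def is_deep(cfg: PTuneConfigSketch) -> bool:
    # Deep prompt tuning adds intermediate prompts on top of the input prompts.
    return prompt_tuning_enabled(cfg) and cfg.tuning_mode == "deep_ptune"


print(prompt_tuning_enabled(PTuneConfigSketch()))  # False
print(is_deep(PTuneConfigSketch(pre_seq_len=16, tuning_mode="deep_ptune")))  # True
```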

Usage

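Distributed model classes already inherit from this mixin, so nothing needs to be added to the class hierarchy. To enable prompt tuning, set tuning_mode and pre_seq_len in the model config before loading. The prompt embeddings then become trainable parameters while the pretrained weights stay frozen. A task head added on top, such as the score layer of a sequence classification model, is trainable as well.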
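The freezing pattern can be illustrated with a small PyTorch sketch. Every existing parameter is frozen first. The prompt embeddings are created afterwards, so they keep the default requires_grad=True. The toy backbone and head modules are assumptions standing in for the pretrained model and a task head, which the classification example below also treats as trainable. This is an illustration, not Petals' code.

```python
import torch.nn as nn

hidden_size, pre_seq_len = 8, 4

# Stand-in for the pretrained base model: freeze all of its weights.
backbone = nn.Sequential(nn.Embedding(100, hidden_size), nn.LayerNorm(hidden_size))
backbone.requires_grad_(False)

# Created after freezing, so these parameters stay trainable (requires_grad=True by default).
prompt_embeddings = nn.Embedding(pre_seq_len, hidden_size)
head = nn.Linear(hidden_size, 2)  # stand-in for a task head such as a classification score layer

modules = {"backbone": backbone, "prompt_embeddings": prompt_embeddings, "head": head}
trainable = {
    name: sum(p.numel() for p in m.parameters() if p.requires_grad) for name, m in modules.items()
}
print(trainable)  # {'backbone': 0, 'prompt_embeddings': 32, 'head': 18}
```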

Code Reference

Source Location

  • Repository: petals
  • File: src/petals/client/ptune.py (L24-62)

Signature

@dataclasses.dataclass
class PTuneConfig:
    pre_seq_len: int = 0
    tuning_mode: Optional[str] = None  # "ptune" or "deep_ptune"

class PTuneMixin:
    _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]

    def init_prompts(self, config: PretrainedConfig) -> None:
        """
        Initialize prompt tuning embeddings based on config.

        If tuning_mode is set and pre_seq_len > 0, creates:
        - prompt_embeddings: nn.Embedding(pre_seq_len, hidden_size)
        - intermediate_prompt_embeddings: nn.Embedding(pre_seq_len * num_layers, hidden_size)
          (only if tuning_mode == "deep_ptune")

        Args:
            config: Model configuration with PTuneConfig fields
        """

    def get_prompt(self, batch_size: int):
        """
        Get prompt embeddings expanded for the given batch size.

        Args:
            batch_size: Number of sequences in the batch
        Returns:
            Tuple of (prompts, intermediate_prompts) tensors or (None, None)
            prompts shape: [batch_size, pre_seq_len, hidden_size]
            intermediate_prompts shape: [batch_size, pre_seq_len * num_layers, hidden_size] or None
        """

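The _keys_to_ignore_on_load_missing pattern lets a pretrained checkpoint, which contains no prompt weights, load without missing-key warnings for the newly created embeddings. The sketch below checks which state-dict keys the pattern covers. It assumes the pattern is applied with re.search, which is how Hugging Face transformers treats this attribute. The example key names are hypothetical.

```python
import re

# Pattern copied from PTuneMixin._keys_to_ignore_on_load_missing
pattern = r"(intermediate_)?prompt_embeddings\.weight$"

# Hypothetical state-dict keys that could be reported as missing
keys = [
    "prompt_embeddings.weight",
    "model.intermediate_prompt_embeddings.weight",
    "model.embed_tokens.weight",
    "score.weight",
]
ignored = [k for k in keys if re.search(pattern, k)]
print(ignored)  # ['prompt_embeddings.weight', 'model.intermediate_prompt_embeddings.weight']
```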
Import

# PTuneMixin is used internally by distributed model classes;
# enable it through the model config before loading.
from transformers import AutoConfig
from petals.models.llama.model import DistributedLlamaForSequenceClassification

model_name = "enoch/llama-65b-hf"  # model used in the examples below

config = AutoConfig.from_pretrained(model_name)
config.tuning_mode = "ptune"
config.pre_seq_len = 16

model = DistributedLlamaForSequenceClassification.from_pretrained(model_name, config=config)

I/O Contract

Inputs

Name | Type | Required | Description
config (init_prompts) | PretrainedConfig | Yes | Model config with tuning_mode and pre_seq_len fields
batch_size (get_prompt) | int | Yes | Number of sequences in the current batch

Outputs

Name | Type | Description
prompt_embeddings | nn.Embedding | Trainable embedding layer of shape [pre_seq_len, hidden_size]
intermediate_prompt_embeddings | Optional[nn.Embedding] | Deep prompt embeddings of shape [pre_seq_len * num_layers, hidden_size] (deep_ptune only)
get_prompt() return value | Tuple[Optional[Tensor], Optional[Tensor]] | (prompts, intermediate_prompts), expanded for the batch
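The shapes in the table translate directly into parameter counts, excluding any task head. Shallow prompts have pre_seq_len × hidden_size parameters, and deep tuning adds pre_seq_len × num_layers × hidden_size more. The helper prompt_param_count is hypothetical. The LLaMA-65B dimensions in the usage lines (hidden size 8192, 80 layers) are the published model sizes, not values stated on this page.

```python
def prompt_param_count(pre_seq_len: int, hidden_size: int, num_layers: int, deep: bool) -> int:
    """Trainable prompt parameters implied by the documented embedding shapes (hypothetical helper)."""
    shallow = pre_seq_len * hidden_size  # prompt_embeddings: [pre_seq_len, hidden_size]
    # intermediate_prompt_embeddings: [pre_seq_len * num_layers, hidden_size], deep_ptune only
    extra = pre_seq_len * num_layers * hidden_size if deep else 0
    return shallow + extra


print(prompt_param_count(16, 8192, 80, deep=False))  # 131072
print(prompt_param_count(16, 8192, 80, deep=True))   # 10616832
```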

Usage Examples

Configuring Prompt Tuning for Classification

import torch
from transformers import AutoTokenizer, AutoConfig
from petals.models.llama.model import DistributedLlamaForSequenceClassification

model_name = "enoch/llama-65b-hf"

# Configure prompt tuning
config = AutoConfig.from_pretrained(model_name)
config.num_labels = 2
config.tuning_mode = "ptune"    # Shallow prompt tuning
config.pre_seq_len = 16         # 16 prefix tokens

# Load model with prompt tuning enabled
model = DistributedLlamaForSequenceClassification.from_pretrained(
    model_name, config=config
)

# Only prompt_embeddings and score head are trainable
trainable_params = [p for p in model.parameters() if p.requires_grad]
print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")

Deep Prompt Tuning

# Continuing the classification example above (reuses model_name and config)
config.tuning_mode = "deep_ptune"  # Prompts at every layer
config.pre_seq_len = 16

model = DistributedLlamaForSequenceClassification.from_pretrained(
    model_name, config=config
)

# Now both prompt_embeddings and intermediate_prompt_embeddings are trainable
