Implementation: BigScience Workshop Petals PTuneMixin
| Knowledge Sources | |
|---|---|
| Domains | NLP, Parameter_Efficient_Finetuning |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
Concrete tool for initializing and managing trainable prompt embeddings in distributed Petals models for parameter-efficient fine-tuning.
Description
PTuneMixin is a mixin class that adds prompt tuning capabilities to distributed model classes. It provides two methods:
- init_prompts(config): Initializes prompt embedding layers based on the model's PTuneConfig settings. Creates prompt_embeddings (nn.Embedding) and optionally intermediate_prompt_embeddings for deep prompt tuning.
- get_prompt(batch_size): Returns the prompt tensor to prepend to input hidden states, properly expanded for the batch.
The mixin reads configuration from PTuneConfig, a dataclass with:
- pre_seq_len: Number of prefix prompt tokens (default 0, meaning no prompt tuning)
- tuning_mode: Either "ptune" (shallow) or "deep_ptune" (deep, all layers)
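The sizing rule implied by these two settings can be sketched in plain Python. This is a hypothetical helper, not part of Petals: shallow "ptune" allocates pre_seq_len × hidden_size trainable values, and "deep_ptune" adds pre_seq_len × num_layers × hidden_size more for the per-layer intermediate prompts.

```python
from typing import Optional

# Hypothetical sketch of how PTuneConfig settings translate into
# trainable prompt-parameter counts; illustrates the sizing rule only.

def prompt_param_count(pre_seq_len: int, hidden_size: int,
                       num_hidden_layers: int,
                       tuning_mode: Optional[str]) -> int:
    """Number of trainable values the prompt embeddings would hold."""
    if not tuning_mode or pre_seq_len <= 0:
        return 0  # prompt tuning disabled
    count = pre_seq_len * hidden_size  # prompt_embeddings
    if tuning_mode == "deep_ptune":
        # intermediate_prompt_embeddings: one prompt block per layer
        count += pre_seq_len * num_hidden_layers * hidden_size
    return count

# e.g. a LLaMA-65B-like config: hidden_size=8192, 80 layers, 16 prefix tokens
print(prompt_param_count(16, 8192, 80, "ptune"))       # 131072
print(prompt_param_count(16, 8192, 80, "deep_ptune"))  # 10616832
```

Either way, the prompt parameters are a vanishingly small fraction of a multi-billion-parameter model, which is the point of the technique.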
Usage
This mixin is automatically mixed into distributed model classes. Configure it by setting tuning_mode and pre_seq_len in the model config before loading. The prompt embeddings become part of the model's trainable parameters while all other weights remain frozen.
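As a rough illustration of which parameter names count as prompt parameters, here is a hypothetical filter built from the same regex the mixin lists in _keys_to_ignore_on_load_missing (the example parameter names are made up):

```python
import re

# Pattern from PTuneMixin._keys_to_ignore_on_load_missing: prompt embedding
# weights are created fresh rather than loaded from the checkpoint.
PROMPT_WEIGHT_RE = re.compile(r"(intermediate_)?prompt_embeddings\.weight$")

def is_prompt_param(name: str) -> bool:
    """Hypothetical helper: does this parameter name match the prompt regex?"""
    return PROMPT_WEIGHT_RE.search(name) is not None

# Hypothetical parameter names, for illustration only
names = [
    "transformer.prompt_embeddings.weight",
    "transformer.intermediate_prompt_embeddings.weight",
    "transformer.h.0.attn.weight",
]
print([n for n in names if is_prompt_param(n)])
# ['transformer.prompt_embeddings.weight',
#  'transformer.intermediate_prompt_embeddings.weight']
```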
Code Reference
Source Location
- Repository: petals
- File: src/petals/client/ptune.py (L24-62)
Signature
@dataclasses.dataclass
class PTuneConfig:
    pre_seq_len: int = 0
    tuning_mode: Optional[str] = None  # "ptune" or "deep_ptune"

class PTuneMixin:
    _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]

    def init_prompts(self, config: PretrainedConfig) -> None:
        """
        Initialize prompt tuning embeddings based on config.

        If tuning_mode is set and pre_seq_len > 0, creates:
        - prompt_embeddings: nn.Embedding(pre_seq_len, hidden_size)
        - intermediate_prompt_embeddings: nn.Embedding(pre_seq_len * num_layers, hidden_size)
          (only if tuning_mode == "deep_ptune")

        Args:
            config: Model configuration with PTuneConfig fields
        """

    def get_prompt(self, batch_size: int):
        """
        Get prompt embeddings expanded for the given batch size.

        Args:
            batch_size: Number of sequences in the batch

        Returns:
            Tuple of (prompts, intermediate_prompts) tensors or (None, None).
            prompts shape: [batch_size, pre_seq_len, hidden_size]
            intermediate_prompts shape: [batch_size, pre_seq_len * num_layers, hidden_size] or None
        """
Import
# PTuneMixin is used internally by distributed model classes.
# Configure via model config:
from transformers import AutoConfig

model_name = "enoch/llama-65b-hf"  # example model; see Usage Examples below
config = AutoConfig.from_pretrained(model_name)
config.tuning_mode = "ptune"
config.pre_seq_len = 16

from petals.models.llama.model import DistributedLlamaForSequenceClassification

model = DistributedLlamaForSequenceClassification.from_pretrained(model_name, config=config)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | PretrainedConfig | Yes | Model config with tuning_mode and pre_seq_len fields |
| batch_size (get_prompt) | int | Yes | Number of sequences in the current batch |
Outputs
| Name | Type | Description |
|---|---|---|
| prompt_embeddings | nn.Embedding | Trainable embedding layer of shape [pre_seq_len, hidden_size] |
| intermediate_prompt_embeddings | Optional[nn.Embedding] | Deep prompt embeddings [pre_seq_len * num_layers, hidden_size] (deep_ptune only) |
| get_prompt() returns | Tuple[Optional[Tensor], Optional[Tensor]] | (prompts, intermediate_prompts) expanded for batch |
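To make the return contract concrete, here is a hypothetical helper that computes the shapes described in the table above; disabled configs yield (None, None), shallow mode fills only the first slot, and deep mode fills both:

```python
# Hypothetical sketch of get_prompt's return contract: returns the shape
# tuples the real tensors would have, or None where the slot is empty.

def get_prompt_shapes(batch_size, pre_seq_len, hidden_size,
                      num_layers, tuning_mode):
    if not tuning_mode or pre_seq_len <= 0:
        return (None, None)  # prompt tuning disabled
    prompts = (batch_size, pre_seq_len, hidden_size)
    intermediate = ((batch_size, pre_seq_len * num_layers, hidden_size)
                    if tuning_mode == "deep_ptune" else None)
    return (prompts, intermediate)

print(get_prompt_shapes(2, 16, 8192, 80, "deep_ptune"))
# ((2, 16, 8192), (2, 1280, 8192))
```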
Usage Examples
Configuring Prompt Tuning for Classification
import torch
from transformers import AutoTokenizer, AutoConfig
from petals.models.llama.model import DistributedLlamaForSequenceClassification
model_name = "enoch/llama-65b-hf"
# Configure prompt tuning
config = AutoConfig.from_pretrained(model_name)
config.num_labels = 2
config.tuning_mode = "ptune" # Shallow prompt tuning
config.pre_seq_len = 16 # 16 prefix tokens
# Load model with prompt tuning enabled
model = DistributedLlamaForSequenceClassification.from_pretrained(
    model_name, config=config
)
# Only prompt_embeddings and score head are trainable
trainable_params = [p for p in model.parameters() if p.requires_grad]
print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
Deep Prompt Tuning
config.tuning_mode = "deep_ptune" # Prompts at every layer
config.pre_seq_len = 16
model = DistributedLlamaForSequenceClassification.from_pretrained(
    model_name, config=config
)
# Now both prompt_embeddings and intermediate_prompt_embeddings are trainable