Principle: BigScience Workshop Petals Prompt Tuning
| Knowledge Sources | |
|---|---|
| Domains | NLP, Parameter_Efficient_Finetuning, Transfer_Learning |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
A parameter-efficient fine-tuning technique that prepends trainable continuous embeddings (soft prompts) to the input sequence, enabling task adaptation of large language models while their weights remain frozen.
Description
Prompt Tuning addresses the challenge of adapting large language models to specific tasks when the model weights are frozen (as in Petals' distributed setting where model weights live on remote servers). Instead of modifying model parameters, prompt tuning introduces a small set of trainable prefix embeddings (soft prompts) that are prepended to the input hidden states before passing through the transformer blocks.
Two modes in Petals (a construction sketch follows the list):
- ptune (shallow): Trainable prompt embeddings are prepended only at the input layer. The prompt_embeddings parameter is an nn.Embedding of shape [pre_seq_len, hidden_size].
- deep_ptune (deep): Additional intermediate_prompt_embeddings are injected at every transformer block layer, providing more expressive adaptation. This is analogous to P-Tuning v2.
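A minimal sketch of selecting a mode when constructing the client-side model, following the Petals prompt-tuning examples. Treat the model name and the exact keyword arguments as assumptions; they may differ between Petals versions.

```python
from petals import DistributedBloomForCausalLM

# tuning_mode selects shallow ("ptune") or deep ("deep_ptune") prompt tuning;
# pre_seq_len is the number of trainable prompt positions.
model = DistributedBloomForCausalLM.from_pretrained(
    "bigscience/bloom-petals",  # assumption: BLOOM-era Petals model name
    tuning_mode="ptune",        # or "deep_ptune" for per-layer prompts
    pre_seq_len=16,
)
```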
Key advantage for distributed inference: Only the prompt embeddings and task-specific head (e.g., classification head) need to be trained locally. All transformer block computation still happens on remote servers, and gradients flow through the distributed autograd mechanism.
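A minimal local training sketch under these assumptions. Everything named here besides the Petals `model` itself (`head`, `loader`, `hidden_size`, `num_classes`) is a hypothetical placeholder, and the forward output shape follows the usual transformers convention.

```python
import torch
import torch.nn.functional as F

head = torch.nn.Linear(hidden_size, num_classes)  # hypothetical task-specific head

# Only locally held parameters are trainable: the prompt embeddings inside
# `model` (all remote-block weights are frozen) plus the task head.
trainable = [p for p in model.parameters() if p.requires_grad]
opt = torch.optim.Adam(trainable + list(head.parameters()), lr=1e-3)

for input_ids, labels in loader:                 # hypothetical data loader
    hidden = model(input_ids).last_hidden_state  # blocks run on remote servers
    logits = head(hidden[:, -1])                 # classify from the last position
    loss = F.cross_entropy(logits, labels)
    loss.backward()                              # gradients flow through distributed autograd
    opt.step()
    opt.zero_grad()
```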
Usage
Use this principle when you need to adapt a distributed large language model to a specific downstream task (classification, dialogue, etc.) without modifying the model weights. It is particularly suited to the Petals setting where model weights are hosted remotely and cannot be directly fine-tuned.
Theoretical Basis
Soft prompt formulation:
Given input token embeddings $X \in \mathbb{R}^{n \times d}$ and trainable prompt embeddings $P \in \mathbb{R}^{k \times d}$ (with $k$ = `pre_seq_len`, $d$ = `hidden_size`), the effective input becomes:

$$H_0 = [P;\, X] \in \mathbb{R}^{(k+n) \times d}$$

For deep prompt tuning, at each layer $l$ the prompt positions are refreshed with layer-specific prompts before the block is applied:

$$H_l = \mathrm{Block}_l\big([P_l;\, H_{l-1}^{(k:)}]\big)$$

where $P_l \in \mathbb{R}^{k \times d}$ are layer-specific prompt embeddings and $H_{l-1}^{(k:)}$ denotes the non-prompt positions of the previous layer's output.
Training: Only $P$ (and $P_1, \dots, P_L$ for deep mode) are optimized via gradient descent. The transformer blocks remain frozen.
```python
import torch
import torch.nn as nn

# Abstract prompt tuning algorithm (shapes assume batch-first tensors)
prompt_embeddings = nn.Embedding(pre_seq_len, hidden_size)  # trainable
# For deep mode: one row per (layer, prompt position) pair
intermediate_prompts = nn.Embedding(pre_seq_len * num_layers, hidden_size)  # trainable

# Forward pass
prompt_ids = torch.arange(pre_seq_len).expand(batch_size, -1)
prompts = prompt_embeddings(prompt_ids)           # [batch, pre_seq_len, hidden]
input_embs = word_embeddings(input_ids)           # [batch, seq_len, hidden]
hidden = torch.cat([prompts, input_embs], dim=1)  # [batch, pre_seq_len + seq_len, hidden]
output = transformer_blocks(hidden)               # remote computation in Petals
```
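For deep mode, a sketch of the per-block injection that the shapes above imply. Whether the layer-specific prompts overwrite or are added to the prompt positions is an implementation detail of Petals; this sketch overwrites them, in the spirit of P-Tuning v2.

```python
# Deep mode sketch: refresh the prompt slots with layer-specific prompts
# before each block. In Petals this step happens on the remote servers.
for l, block in enumerate(transformer_blocks):
    ids = torch.arange(l * pre_seq_len, (l + 1) * pre_seq_len)
    layer_prompts = intermediate_prompts(ids)   # [pre_seq_len, hidden]
    hidden[:, :pre_seq_len] = layer_prompts     # inject at the prompt positions
    hidden = block(hidden)
```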