Heuristic: BigScience Workshop Petals Prompt Embeddings in Float32 Precision
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Prompt_Tuning |
| Last Updated | 2026-02-09 13:00 GMT |
Overview
Prompt tuning embeddings and their optimizer states are stored in float32 precision to increase p-tuning quality, even when the model runs in float16/bfloat16.
Description
In Petals prompt tuning (both standard and deep p-tuning), the learnable prompt embeddings are explicitly initialized as float32 tensors via `nn.Embedding(..., dtype=torch.float32)`. This ensures that the small trainable parameters accumulate gradients and optimizer states (e.g., AdamW momentum) at full precision. The prompts are cast to the model's dtype only when injected into the forward pass.
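The pattern above can be sketched as a minimal module (class and attribute names here are illustrative, not the actual Petals classes): the trainable prompt embedding is created in float32, and its output is cast to the frozen model's compute dtype only when it is injected into the forward pass.

```python
import torch
import torch.nn as nn

class PromptPrefix(nn.Module):
    """Sketch of the float32-prompts pattern (hypothetical names)."""

    def __init__(self, pre_seq_len: int, hidden_size: int):
        super().__init__()
        # Trainable prompts are kept at full precision for stable gradient
        # accumulation and optimizer updates, regardless of the model dtype.
        self.prompt_embeddings = nn.Embedding(pre_seq_len, hidden_size, dtype=torch.float32)
        self.prompt_ids = torch.arange(pre_seq_len)

    def forward(self, model_dtype: torch.dtype) -> torch.Tensor:
        prompts = self.prompt_embeddings(self.prompt_ids)
        # Cast to the model's compute dtype only at injection time.
        return prompts.to(model_dtype)

prefix = PromptPrefix(pre_seq_len=16, hidden_size=64)
out = prefix(torch.float16)   # parameters stay float32, output is float16
```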
Usage
Applied automatically when using `tuning_mode="ptune"` or `tuning_mode="deep_ptune"`. The float32 precision applies to `prompt_embeddings` and `intermediate_prompt_embeddings` (for deep p-tune).
The Insight (Rule of Thumb)
- Action: Keep prompt embeddings in float32, even if the model uses float16/bfloat16. They are cast at forward time.
- Value: float32 for prompt parameters; model dtype (float16/bfloat16) for computation.
- Trade-off: 2x memory for prompt parameters (negligible since prompts are tiny vs model), but significantly better training stability and quality.
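A short sketch of the optimizer side of this rule: when the only trainable parameter is a float32 embedding, AdamW's momentum buffers (`exp_avg`, `exp_avg_sq`) are allocated in float32 as well, even though the forward pass can run through a half-precision cast. The loss below is a toy stand-in, not the Petals training objective.

```python
import torch
import torch.nn as nn

# Only the float32 prompt embedding is trainable; AdamW therefore keeps
# its momentum state in float32 too.
prompts = nn.Embedding(16, 64, dtype=torch.float32)
opt = torch.optim.AdamW(prompts.parameters(), lr=1e-2)

# Toy objective: cast to float16 (as a half-precision model would consume
# the prompts), then back to float32 for a scalar loss.
loss = prompts(torch.arange(16)).to(torch.float16).float().pow(2).mean()
loss.backward()
opt.step()

state = opt.state[prompts.weight]  # AdamW state for the prompt parameters
```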
Reasoning
Prompt embeddings have very few parameters (typically `pre_seq_len * hidden_size`, e.g., 16 * 4096 = 65K params) but they are the only trainable parameters in p-tuning. Using float16 for such small parameter sets leads to gradient underflow and poor convergence. The memory overhead of float32 is negligible (a few hundred KB) compared to the model's billions of parameters.
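Both halves of this argument can be checked directly: float16 flushes values below its smallest subnormal (about 6e-8) to zero, so very small gradient updates would vanish, while the extra memory from storing the example's 65,536 prompt parameters in float32 rather than float16 is only about 128 KB.

```python
import numpy as np

# Gradient underflow: values below float16's smallest subnormal (~6e-8)
# become exactly zero, so tiny updates to float16 parameters are lost.
tiny = 1e-8
underflows_fp16 = (np.float16(tiny) == 0.0)
survives_fp32 = (np.float32(tiny) > 0.0)

# Memory overhead: pre_seq_len * hidden_size parameters, 4 bytes (float32)
# vs 2 bytes (float16) each.
params = 16 * 4096                # 65,536 parameters
extra_bytes = params * (4 - 2)    # cost of float32 over float16
```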
Code Evidence
From `src/petals/client/ptune.py:31-32`:

```python
# Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
```

The cast happens at forward time, from `src/petals/client/ptune.py:61-62`:

```python
dtype = self.word_embeddings.weight.dtype
return prompts.to(dtype), intermediate_prompts.to(dtype)
```