
Heuristic: BigScience Workshop Petals Prompt Embeddings Float32 Precision

From Leeroopedia




Knowledge Sources
Domains Optimization, Prompt_Tuning
Last Updated 2026-02-09 13:00 GMT

Overview

Prompt tuning embeddings and their optimizer states are stored in float32 to improve p-tuning quality, even when the model itself runs in float16/bfloat16.

Description

In Petals prompt tuning (both standard and deep p-tuning), the learnable prompt embeddings are explicitly initialized as float32 tensors via `nn.Embedding(..., dtype=torch.float32)`. This ensures that the small trainable parameters accumulate gradients and optimizer states (e.g., AdamW momentum) at full precision. The prompts are cast to the model's dtype only when injected into the forward pass.
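A minimal sketch of this pattern, using hypothetical sizes (not Petals' actual classes): the trainable prompt table is created in float32, and a lookup helper casts the result to the model's compute dtype only at forward time.

```python
import torch
import torch.nn as nn

# Assumed toy dimensions; real models use e.g. hidden_size=4096
pre_seq_len = 16
hidden_size = 64
model_dtype = torch.float16  # the model's compute dtype

# Trainable prompt embeddings kept in float32 regardless of model dtype
prompt_embeddings = nn.Embedding(pre_seq_len, hidden_size, dtype=torch.float32)

def get_prompt(batch_size: int) -> torch.Tensor:
    # Look up all prefix positions in float32, then cast for the forward pass
    prefix = torch.arange(pre_seq_len)
    prompts = prompt_embeddings(prefix).unsqueeze(0).expand(batch_size, -1, -1)
    return prompts.to(model_dtype)
```

Gradients flow back through the `.to(model_dtype)` cast, so the float32 master copy still receives full-precision optimizer updates.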

Usage

Applied automatically when using `tuning_mode="ptune"` or `tuning_mode="deep_ptune"`. The float32 precision applies to `prompt_embeddings` and `intermediate_prompt_embeddings` (for deep p-tune).

The Insight (Rule of Thumb)

  • Action: Keep prompt embeddings in float32, even if the model uses float16/bfloat16. They are cast at forward time.
  • Value: float32 for prompt parameters; model dtype (float16/bfloat16) for computation.
  • Trade-off: 2x memory for prompt parameters (negligible since prompts are tiny vs model), but significantly better training stability and quality.
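The 2x trade-off is easy to quantify with the example sizes used in this page (16 prefix tokens, 4096 hidden size; illustrative numbers, not a fixed Petals configuration):

```python
# Memory cost of float32 vs float16 prompt parameters
pre_seq_len, hidden_size = 16, 4096
params = pre_seq_len * hidden_size      # 65,536 trainable parameters
fp32_bytes = params * 4                 # 262,144 bytes (256 KB)
fp16_bytes = params * 2                 # 131,072 bytes (128 KB)
overhead = fp32_bytes - fp16_bytes      # 128 KB extra: negligible next to a multi-GB model
```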

Reasoning

Prompt embeddings have very few parameters (typically `pre_seq_len * hidden_size`, e.g., 16 * 4096 = 65K params) but they are the only trainable parameters in p-tuning. Using float16 for such small parameter sets leads to gradient underflow and poor convergence. The memory overhead of float32 is negligible (a few hundred KB) compared to the model's billions of parameters.
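The underflow problem can be demonstrated directly: float16's smallest subnormal is about 6e-8, so the squared gradients that feed an AdamW-style second-moment estimate vanish in float16 while surviving in float32. (A small numeric illustration, not Petals code.)

```python
import numpy as np

# A small but legitimate gradient value
grad = np.float16(1e-4)

# Squaring it, as AdamW's second-moment update does:
sq_fp16 = grad * grad            # float16: 1e-8 underflows to exactly 0
sq_fp32 = np.float32(grad) ** 2  # float32: ~1e-8 is preserved
```

With the second moment stuck at zero, the adaptive step size degenerates, which is why the optimizer state for the prompt parameters is kept in float32 as well.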

Code Evidence

From `src/petals/client/ptune.py:31-32`:

# Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)

Cast at forward time from `src/petals/client/ptune.py:61-62`:

dtype = self.word_embeddings.weight.dtype
return prompts.to(dtype), intermediate_prompts.to(dtype)
