Implementation:Bigscience workshop Petals AdamW And Scheduler
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Optimization, Training |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
Concrete tool for configuring AdamW optimizer and learning rate scheduler for prompt tuning training, using PyTorch and HuggingFace Transformers as used in Petals training workflows.
Description
This page documents the standard AdamW + scheduler setup used in Petals' prompt tuning examples. The setup is straightforward but has a few Petals-specific considerations:
- Parameter filtering: only parameters with requires_grad=True are passed to the optimizer
- Higher learning rate: prompt tuning benefits from a higher learning rate (around 1e-3) than full fine-tuning (around 1e-5)
- Minimal weight decay: often set to 0 for prompt embeddings, since they are small and not prone to overfitting in the same way as full model weights
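The per-tensor weight decay policy above can be expressed with AdamW parameter groups: zero decay for the prompts, default decay for the head. This is a minimal sketch; `prompt_embeddings` and `score_head` are illustrative stand-ins, not Petals attribute names.

```python
import torch
from torch import nn
from torch.optim import AdamW

# Stand-ins for the real trainable pieces of a prompt-tuned Petals model
prompt_embeddings = nn.Parameter(torch.randn(16, 64))  # soft prompt tokens
score_head = nn.Linear(64, 2)                          # small task head

# Parameter groups let each piece get its own weight decay
optimizer = AdamW(
    [
        {"params": [prompt_embeddings], "weight_decay": 0.0},   # no decay on prompts
        {"params": score_head.parameters(), "weight_decay": 1e-2},
    ],
    lr=1e-3,
)
```

Options set in a group override the optimizer-wide defaults for that group only.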
Usage
Use these external APIs after loading a distributed model with prompt tuning enabled, filtering the model's parameters down to only the trainable ones.
Code Reference
Source Location
- Repository: External (torch, transformers)
- APIs: torch.optim.AdamW, transformers.get_scheduler
Signature
```python
class torch.optim.AdamW:
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        amsgrad: bool = False,
    ):
        """AdamW optimizer with decoupled weight decay."""

def get_scheduler(  # transformers.get_scheduler
    name: str,
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
) -> torch.optim.lr_scheduler.LambdaLR:
    """
    Get a learning rate scheduler.

    Args:
        name: "linear", "cosine", "constant", etc.
    """
```
Import
```python
from torch.optim import AdamW
from transformers import get_scheduler
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| params | Iterable[Parameter] | Yes | Trainable parameters (prompt_embeddings, score head) |
| lr | float | No | Learning rate (default 1e-3) |
| weight_decay | float | No | Weight decay coefficient |
| name | str | Yes (scheduler) | Scheduler type ("linear", "cosine", etc.) |
| num_warmup_steps | int | Yes (scheduler) | Steps before full learning rate |
| num_training_steps | int | Yes (scheduler) | Total training steps |
Outputs
| Name | Type | Description |
|---|---|---|
| optimizer | AdamW | Configured optimizer for gradient updates |
| scheduler | LambdaLR | Learning rate scheduler |
Usage Examples
Prompt Tuning Optimizer Setup
```python
from torch.optim import AdamW
from transformers import get_scheduler

# Only optimize trainable parameters (prompt embeddings, score head)
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=1e-3, weight_decay=0.0)

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=num_training_steps,
)

# Training loop
for epoch in range(num_epochs):
    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```
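The loop above assumes a `model` and `train_dataloader` already exist. A self-contained toy version shows the same wiring end to end, with a plain nn.Linear standing in for the distributed Petals model (illustrative only):

```python
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from transformers import get_scheduler

torch.manual_seed(0)
X = torch.randn(64, 8)
y = X @ torch.randn(8, 1)
train_dataloader = DataLoader(TensorDataset(X, y), batch_size=16)

model = nn.Linear(8, 1)  # stand-in for the prompt-tuned distributed model
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.0)

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)  # 3 * 4 = 12 steps
scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=2,
    num_training_steps=num_training_steps,
)

loss_fn = nn.MSELoss()
first_loss = last_loss = None
for epoch in range(num_epochs):
    for xb, yb in train_dataloader:
        loss = loss_fn(model(xb), yb)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if first_loss is None:
            first_loss = loss.item()
        last_loss = loss.item()
```

After the final step the linear schedule has decayed the learning rate all the way to zero, since every training step has been consumed.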
Related Pages
Implements Principle
Requires Environment