Implementation:Mlfoundations Open flamingo AdamW cosine schedule
Overview
Wrapper pattern combining PyTorch AdamW optimizer with HuggingFace learning rate schedulers for OpenFlamingo training.
Description
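This is a Wrapper Doc. When fsdp_use_orig_params is True, the training script creates AdamW with two parameter groups:
- Params with weight decay: the gated_cross_attn params
- Params without weight decay: all other trainable params
Otherwise, AdamW receives a single group and applies weight_decay to every trainable param, as in the signature below.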
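The two-group split described above can be sketched as a small helper. This is a minimal sketch, not the code in train.py: the function name group_params_for_adamw is hypothetical, and the exclude_from_optimizer attribute used to mark "optimizer-excluded params" is an assumed name. The helper only reads .requires_grad and that optional attribute, so it accepts the (name, param) pairs yielded by model.named_parameters(); the stand-in parameters below are made up so the sketch runs without PyTorch.

```python
from types import SimpleNamespace
from typing import Any, Dict, Iterable, List, Tuple


def group_params_for_adamw(
    named_params: Iterable[Tuple[str, Any]],
    weight_decay: float,
    decay_key: str = "gated_cross_attn",
) -> List[Dict[str, Any]]:
    """Return [decayed group, non-decayed group] for torch.optim.AdamW (hypothetical helper)."""
    with_decay: List[Any] = []
    without_decay: List[Any] = []
    for name, p in named_params:
        # Drop frozen params and params flagged for exclusion
        # (the exclude_from_optimizer attribute name is an assumption).
        if not p.requires_grad or getattr(p, "exclude_from_optimizer", False):
            continue
        # Only params whose name contains decay_key receive weight decay.
        if decay_key in name:
            with_decay.append(p)
        else:
            without_decay.append(p)
    return [
        {"params": with_decay, "weight_decay": weight_decay},
        {"params": without_decay, "weight_decay": 0.0},
    ]


# Made-up stand-ins; a real call would pass model.named_parameters().
named = [
    ("layers.0.gated_cross_attn_layer.attn.weight", SimpleNamespace(requires_grad=True)),
    ("perceiver.norm.weight", SimpleNamespace(requires_grad=True)),
    ("vision_encoder.conv1.weight", SimpleNamespace(requires_grad=False)),
    ("lang_encoder.embed_tokens.weight",
     SimpleNamespace(requires_grad=True, exclude_from_optimizer=True)),
]
groups = group_params_for_adamw(named, weight_decay=0.1)
print([len(g["params"]) for g in groups])  # prints [1, 1]
```

Because each dict carries its own weight_decay, the list can be passed straight to torch.optim.AdamW(groups, lr=...), and the per-group values override the optimizer-level default.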
The scheduler is one of three HuggingFace transformers variants (constant, linear, or cosine), each starting with a linear warmup phase.
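To make the three variants concrete, here is a plain-Python sketch of the per-step learning-rate multiplier each one applies, assuming the formulas used by the transformers implementations (cosine with its default num_cycles=0.5). The effective learning rate at a given scheduler step is the peak learning rate times this multiplier. The function names are hypothetical, and the sketch needs no PyTorch.

```python
import math


def constant_with_warmup(step: int, warmup: int) -> float:
    # Linear ramp from 0 to 1 over `warmup` steps, then flat at 1.0.
    if step < warmup:
        return step / max(1.0, warmup)
    return 1.0


def linear_with_warmup(step: int, warmup: int, total: int) -> float:
    # Linear ramp up, then linear decay to 0.0 at `total`.
    if step < warmup:
        return step / max(1, warmup)
    return max(0.0, (total - step) / max(1, total - warmup))


def cosine_with_warmup(step: int, warmup: int, total: int, num_cycles: float = 0.5) -> float:
    # Linear ramp up, then half a cosine wave from 1.0 down to 0.0 at `total`.
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))


def lr_multiplier(name: str, step: int, warmup: int, total: int) -> float:
    """Dispatch on the scheduler name: "constant", "linear", or "cosine"."""
    if name == "constant":
        return constant_with_warmup(step, warmup)
    if name == "linear":
        return linear_with_warmup(step, warmup, total)
    if name == "cosine":
        return cosine_with_warmup(step, warmup, total)
    raise ValueError(f"unknown lr_scheduler: {name!r}")


# Documented defaults for peak LR and warmup; total is an illustrative value.
peak_lr, warmup, total = 1e-4, 5000, 100_000
for name in ("constant", "linear", "cosine"):
    lrs = [peak_lr * lr_multiplier(name, s, warmup, total)
           for s in (0, 2500, 5000, 52500, 100_000)]
    print(name, lrs)
```

The HuggingFace helpers wrap equivalent functions in torch.optim.lr_scheduler.LambdaLR, so a sketch like this is mainly useful for plotting or sanity-checking a schedule before a run. All three variants agree during warmup (half the peak LR at step 2500 with these values) and differ only afterwards.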
Usage
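Create the optimizer and scheduler after the model has been wrapped with FSDP or DDP, so the optimizer is built from the wrapped model's parameters, and before the training loop starts.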
Code Reference
- Source
  - Repository: https://github.com/mlfoundations/open_flamingo
  - File: open_flamingo/train/train.py (Lines L383-454)
- Signature (composite pattern)
```python
# Optimizer
optimizer = torch.optim.AdamW(
    params,  # filtered to requires_grad=True, excluding optimizer-excluded params
    lr=args.learning_rate,  # default 1e-4
    weight_decay=args.weight_decay,  # default 0.1
)

# Scheduler (one of three variants)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=args.warmup_steps,  # default 5000
    num_training_steps=total_training_steps,
)
# OR get_linear_schedule_with_warmup (same arguments)
# OR get_constant_schedule_with_warmup (takes no num_training_steps)
```
- Import
```python
import torch.optim
from transformers import (
    get_cosine_schedule_with_warmup,
    get_constant_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)
```
I/O Contract
| Direction | Name | Type | Default | Description |
|---|---|---|---|---|
| Input | model parameters | filtered params | — | Parameters with requires_grad=True, excluding optimizer-excluded params |
| Input | learning_rate | float | 1e-4 | Peak learning rate after warmup |
| Input | weight_decay | float | 0.1 | Weight decay coefficient applied to gated cross-attention params |
| Input | warmup_steps | int | 5000 | Number of linear warmup steps |
| Input | lr_scheduler type | str | "constant" | One of "constant", "linear", or "cosine" |
| Input | total_training_steps | int | — | Total number of training steps for decay schedule computation |
| Input | gradient_accumulation_steps | int | 1 | Number of micro-batches before an optimizer step |
| Output | optimizer | AdamW | — | Configured AdamW optimizer instance with parameter groups |
| Output | lr_scheduler | LRScheduler | — | HuggingFace learning rate scheduler instance |
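Because lr_scheduler.step() is called once per optimizer step (as in the training loop under Usage Examples), warmup_steps and total_training_steps are counted in optimizer steps, not micro-batches. The sketch below, with hypothetical helper names, replays that loop's step condition to show the relation. If total_training_steps were given in micro-batches while gradient_accumulation_steps > 1, the decay would stop partway through the schedule.

```python
def count_scheduler_steps(num_micro_batches: int, gradient_accumulation_steps: int) -> int:
    """Replay the training loop's step condition and count scheduler steps."""
    steps = 0
    for step in range(num_micro_batches):
        # Same condition as the loop: step once every
        # `gradient_accumulation_steps` micro-batches.
        if (step + 1) % gradient_accumulation_steps == 0:
            steps += 1
    return steps


def micro_batches_needed(total_training_steps: int, gradient_accumulation_steps: int) -> int:
    # Micro-batches required for the scheduler to reach total_training_steps.
    return total_training_steps * gradient_accumulation_steps


# With 4-way accumulation, 100,000 scheduler steps need 400,000 micro-batches.
n = micro_batches_needed(100_000, 4)
print(n, count_scheduler_steps(n, 4))  # prints: 400000 100000
```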
Usage Examples
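```python
import torch
from transformers import get_cosine_schedule_with_warmup

# Assumes `model` (the wrapped OpenFlamingo model) and `dataloader` already exist.
gradient_accumulation_steps = 1  # micro-batches per optimizer step

# Separate parameter groups for weight decay
params_with_decay = [
    p for n, p in model.named_parameters()
    if p.requires_grad and "gated_cross_attn" in n
]
params_without_decay = [
    p for n, p in model.named_parameters()
    if p.requires_grad and "gated_cross_attn" not in n
]

optimizer = torch.optim.AdamW(
    [
        {"params": params_with_decay, "weight_decay": 0.1},
        {"params": params_without_decay, "weight_decay": 0.0},
    ],
    lr=1e-4,
)

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=5000,
    num_training_steps=100000,  # counted in optimizer steps
)

# Training loop
optimizer.zero_grad()
for step, batch in enumerate(dataloader):
    loss = model(batch)  # placeholder: however the scalar training loss is computed
    # Scale so the accumulated gradient averages over micro-batches.
    (loss / gradient_accumulation_steps).backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        lr_scheduler.step()  # advance the LR schedule once per optimizer step
        optimizer.zero_grad()
```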
Related Pages
Principle:Mlfoundations_Open_flamingo_Optimizer_And_Scheduler_Configuration