
Implementation:Mlfoundations Open flamingo AdamW cosine schedule

From Leeroopedia



Overview

Wrapper pattern that combines the PyTorch AdamW optimizer with a HuggingFace transformers learning-rate scheduler for OpenFlamingo training.

Description

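This is a Wrapper Doc. When fsdp_use_orig_params is True, the training script creates AdamW with two parameter groups:

  1. With weight decay (weight_decay=args.weight_decay): trainable params whose names contain gated_cross_attn
  2. Without weight decay (weight_decay=0.0): all other trainable params

Otherwise, all trainable params are passed as a single group with weight_decay=args.weight_decay, as in the signature below.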
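The grouping rule can be sketched as a small helper that works on any iterable of (name, param) pairs, such as model.named_parameters(). This is a minimal sketch, not the code in train.py: the function name build_param_groups and its use_groups flag are hypothetical, with use_groups standing in for the condition on fsdp_use_orig_params, and the single-group fallback assumes weight_decay is applied uniformly in that case.

```python
from typing import Any, Dict, Iterable, List, Tuple


def build_param_groups(
    named_params: Iterable[Tuple[str, Any]],
    weight_decay: float,
    use_groups: bool = True,
    decay_key: str = "gated_cross_attn",
) -> List[Dict[str, Any]]:
    """Split (name, param) pairs into AdamW parameter groups.

    Hypothetical helper: params whose name contains `decay_key` get
    `weight_decay`, all others get 0.0. With use_groups=False, one group
    applies `weight_decay` to every trainable param (assumed fallback).
    """
    # Keep only trainable params; objects without a requires_grad
    # attribute are treated as trainable (convenient for testing).
    trainable = [
        (n, p) for n, p in named_params if getattr(p, "requires_grad", True)
    ]
    if not use_groups:
        return [{"params": [p for _, p in trainable], "weight_decay": weight_decay}]
    with_wd = [p for n, p in trainable if decay_key in n]
    without_wd = [p for n, p in trainable if decay_key not in n]
    return [
        {"params": with_wd, "weight_decay": weight_decay},
        {"params": without_wd, "weight_decay": 0.0},
    ]
```

With a real model, the result can be passed straight to the optimizer, e.g. torch.optim.AdamW(build_param_groups(model.named_parameters(), 0.1), lr=1e-4); per-group weight_decay values override the optimizer-level default.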

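The scheduler is chosen by args.lr_scheduler from the constant, linear, and cosine warmup schedules in HuggingFace transformers. All three ramp the learning rate linearly from 0 to its peak over warmup_steps; the linear and cosine variants then decay it toward 0 by total_training_steps, while the constant variant holds it at the peak.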
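To make the three shapes concrete, the sketch below computes the factor that multiplies the peak learning rate at a given optimizer step. It is a re-derivation of the multipliers used by HuggingFace's get_constant_schedule_with_warmup, get_linear_schedule_with_warmup, and get_cosine_schedule_with_warmup (cosine with its default num_cycles=0.5), not the library code itself; the function name lr_multiplier is hypothetical.

```python
import math


def lr_multiplier(step: int, schedule: str, warmup_steps: int, total_steps: int) -> float:
    """Factor applied to the peak learning rate at a given optimizer step.

    Hypothetical re-derivation of the HuggingFace warmup schedules;
    the cosine branch assumes the default num_cycles=0.5.
    """
    # Linear warmup from 0 to 1, shared by all three variants.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    if schedule == "constant":
        return 1.0
    # Fraction of the post-warmup phase completed (0 at end of warmup).
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    if schedule == "linear":
        return max(0.0, 1.0 - progress)
    if schedule == "cosine":
        # Half a cosine period: 1 at end of warmup, 0 at total_steps.
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    raise ValueError(f"unknown schedule: {schedule!r}")
```

Multiplying the factor by the peak learning rate (args.learning_rate, default 1e-4) gives the learning rate at that step. For example, with 5000 warmup steps and 100000 total steps, step 52500 is halfway through the decay, so both the linear and cosine factors are 0.5 and the learning rate is 5e-5.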

Usage

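Create the optimizer and scheduler after the model has been wrapped with FSDP or DDP and before the training loop starts, so that the optimizer holds the wrapped model's parameters.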

Code Reference

Source
Repository: https://github.com/mlfoundations/open_flamingo
File: open_flamingo/train/train.py, lines 383-454
Signature (composite pattern)
# Optimizer
optimizer = torch.optim.AdamW(
    params,  # Filtered to requires_grad, excludes optimizer-excluded params
    lr=args.learning_rate,  # default 1e-4
    weight_decay=args.weight_decay,  # default 0.1
)

# Scheduler (one of three variants)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=args.warmup_steps,  # default 5000
    num_training_steps=total_training_steps,
)
# OR get_linear_schedule_with_warmup (same arguments)
# OR get_constant_schedule_with_warmup (takes only optimizer and num_warmup_steps)
Import
import torch.optim
from transformers import (
    get_cosine_schedule_with_warmup,
    get_constant_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

I/O Contract

Direction | Name | Type | Default | Description
--- | --- | --- | --- | ---
Input | model parameters | iterable of Parameter (or param-group dicts) | n/a | Parameters with requires_grad=True, excluding optimizer-excluded params
Input | learning_rate | float | 1e-4 | Peak learning rate, reached at the end of warmup
Input | weight_decay | float | 0.1 | Weight decay coefficient applied to the gated cross-attention params
Input | warmup_steps | int | 5000 | Number of linear warmup steps
Input | lr_scheduler | str | "constant" | One of "constant", "linear", or "cosine"
Input | total_training_steps | int | n/a | Total number of training steps; used by the linear and cosine decay
Input | gradient_accumulation_steps | int | 1 | Number of micro-batches per optimizer step
Output | optimizer | torch.optim.AdamW | n/a | AdamW instance configured with the parameter groups
Output | lr_scheduler | torch.optim.lr_scheduler.LambdaLR | n/a | Learning-rate scheduler returned by the HuggingFace helper
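In the usage-example loop, lr_scheduler.step() runs once per optimizer step, so num_warmup_steps and num_training_steps count optimizer steps rather than micro-batches. The helper below is a hypothetical sketch of that arithmetic; its name and inputs (num_samples, batch_size, world_size, num_epochs) are illustrative, and it assumes any trailing partial accumulation window is dropped, as in that loop.

```python
def scheduler_steps(
    num_samples: int,
    batch_size: int,
    world_size: int,
    num_epochs: int,
    gradient_accumulation_steps: int = 1,
) -> int:
    """Number of lr_scheduler.step() calls in a run (hypothetical helper).

    Assumes each rank sees num_samples // (batch_size * world_size)
    micro-batches per epoch and steps the optimizer and scheduler once
    per full accumulation window.
    """
    micro_batches_per_epoch = num_samples // (batch_size * world_size)
    optimizer_steps_per_epoch = micro_batches_per_epoch // gradient_accumulation_steps
    return optimizer_steps_per_epoch * num_epochs
```

For example, 1,000,000 samples with batch size 50 on 8 ranks for 2 epochs gives 5000 scheduler steps without accumulation, but only 1250 with gradient_accumulation_steps=4. Passing the micro-batch count as num_training_steps in that case would stretch the schedule by the accumulation factor, so a linear or cosine decay would not finish before training ends.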

Usage Examples

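import torch
from transformers import get_cosine_schedule_with_warmup

# Assumes `model` (already FSDP/DDP-wrapped) and `dataloader` exist, and that
# model(batch) returns a scalar loss; adapt these to your training setup.
gradient_accumulation_steps = 1
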
# Separate parameter groups for weight decay
params_with_decay = [
    p for n, p in model.named_parameters()
    if p.requires_grad and "gated_cross_attn" in n
]
params_without_decay = [
    p for n, p in model.named_parameters()
    if p.requires_grad and "gated_cross_attn" not in n
]

optimizer = torch.optim.AdamW(
    [
        {"params": params_with_decay, "weight_decay": 0.1},
        {"params": params_without_decay, "weight_decay": 0.0},
    ],
    lr=1e-4,
)

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=5000,
    num_training_steps=100000,
)

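# Training loop
optimizer.zero_grad()
for step, batch in enumerate(dataloader):
    loss = model(batch)
    # Scale so the accumulated gradient averages over the micro-batches
    (loss / gradient_accumulation_steps).backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        lr_scheduler.step()  # advance the schedule once per optimizer step
        optimizer.zero_grad()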

Related Pages

Principle:Mlfoundations_Open_flamingo_Optimizer_And_Scheduler_Configuration
