

Implementation:Microsoft DeepSpeedExamples BertAdam Optimizer

From Leeroopedia


Knowledge Sources
Domains: Optimization, Learning Rate Scheduling
Last Updated: 2026-02-07 12:00 GMT

Overview

A custom PyTorch optimizer implementing the BERT variant of Adam with decoupled weight decay and six configurable learning rate warmup/decay schedules.

Description

BertAdam extends torch.optim.Optimizer to implement the Adam algorithm with the weight decay fix used for BERT training. Standard Adam with L2 regularization adds the decay term to the gradient, where it passes through the moment estimates and is rescaled by the adaptive per-parameter step size, so the effective decay depends on each parameter's gradient history. BertAdam instead applies weight decay directly to the parameters, outside the moment estimates. This is the decoupled weight decay of AdamW, although BertAdam, unlike the usual AdamW formulation, also omits bias correction.
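To make the difference concrete, here is a minimal pure-Python sketch (not the BertAdam source) that applies one step of each variant to a single scalar weight that receives no gradient at all, so any change comes from weight decay alone. The function names adam_l2_step and decoupled_step are hypothetical; the betas and epsilon mirror BertAdam's defaults (b1=0.9, b2=0.999, e=1e-6).

```python
import math


def adam_l2_step(p, g, m, v, lr, wd, b1=0.9, b2=0.999, e=1e-6):
    """Adam with L2 regularization: decay is folded into the gradient, so it
    flows through the moment estimates and the 1/sqrt(v) normalization."""
    g = g + wd * p
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    return p - lr * m / (math.sqrt(v) + e), m, v


def decoupled_step(p, g, m, v, lr, wd, b1=0.9, b2=0.999, e=1e-6):
    """Decoupled decay (BertAdam/AdamW style): the moments only see the raw
    gradient, and wd * p is added after the adaptive normalization."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    update = m / (math.sqrt(v) + e) + wd * p
    return p - lr * update, m, v


# A weight of 1.0 with zero gradient and fresh (zero) moment estimates.
p_l2, _, _ = adam_l2_step(p=1.0, g=0.0, m=0.0, v=0.0, lr=0.01, wd=0.01)
p_dec, _, _ = decoupled_step(p=1.0, g=0.0, m=0.0, v=0.0, lr=0.01, wd=0.01)
print(f"L2-in-gradient: {p_l2:.4f}")   # shrinks by roughly 3%
print(f"decoupled:      {p_dec:.4f}")  # shrinks by lr * wd = 0.01%
```

Under L2 regularization the decay "gradient" is divided by the square root of its own tiny second-moment estimate, so the weight shrinks roughly 300 times more than the nominal lr * wd; with decoupled decay the shrinkage per step is lr * wd regardless of gradient history.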

The module provides six learning rate schedule functions: warmup_cosine (cosine annealing after warmup), warmup_constant (constant LR after warmup), warmup_linear (linear decay after warmup), warmup_linear_decay_exp (exponential decay after linear warmup), warmup_exp_decay_exp (exponential warmup followed by exponential decay), and warmup_exp_decay_poly (polynomial decay after exponential warmup). The first three take the training progress x (the fraction of t_total completed) and the warmup proportion; the other three instead take the raw global_step and total_steps, plus schedule-specific parameters (decay_rate, decay_steps, degree, warm_degree). Each returns a multiplier that scales the base learning rate, and the selected schedule is evaluated at every optimization step to compute the effective learning rate.
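A minimal sketch of what four of these schedules can look like, assuming a linear ramp from 0 to 1 over the warmup fraction followed by the named decay. The exact formulas in the DeepSpeedExamples source may differ (for example in how they behave at the boundary x = warmup), so treat these as illustrations rather than the library's code; the helper _post_warmup is hypothetical.

```python
import math


def _post_warmup(x, warmup):
    """Fraction of the post-warmup phase completed, clamped to [0, 1]."""
    w = max(warmup, 0.0)  # warmup <= 0 means "no warmup phase"
    return min(max((x - w) / (1.0 - w), 0.0), 1.0)


def warmup_cosine(x, warmup=0.002):
    if x < warmup:
        return x / warmup  # linear ramp 0 -> 1
    return 0.5 * (1.0 + math.cos(math.pi * _post_warmup(x, warmup)))


def warmup_constant(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return 1.0


def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return 1.0 - _post_warmup(x, warmup)  # 1 -> 0 by the end of training


def warmup_linear_decay_exp(global_step, decay_rate, decay_steps,
                            total_steps, warmup=0.002):
    # Step-based variant: progress is derived from global_step / total_steps.
    x = global_step / total_steps
    if x < warmup:
        return x / warmup
    warmup_end = warmup * total_steps
    # Shrink by a factor of decay_rate per decay_steps steps after warmup.
    return decay_rate ** ((global_step - warmup_end) / decay_steps)


# Effective LR halfway through the decay phase with 10% warmup.
base_lr = 2e-5
print(base_lr * warmup_linear(0.55, warmup=0.1))
```

In every case the returned value is a multiplier, so the learning rate actually applied is the base lr times the schedule's output.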

The optimizer supports gradient clipping via the max_grad_norm parameter, maintains exponential moving averages of the gradient and of its square (the first and second moment estimates, next_m and next_v), and intentionally omits Adam's bias correction. The get_lr method returns the current scheduled learning rate for all parameter groups, which is useful for logging and monitoring.
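The per-parameter update described above can be sketched in pure Python as follows, with a parameter stored as a list of floats. This is an illustration under assumptions, not the BertAdam source: bert_adam_update is a hypothetical name, lr_scheduled stands for the base lr already multiplied by the schedule, and the clipping is a simple rescale of this one parameter's gradient.

```python
import math


def bert_adam_update(param, grad, state, lr_scheduled,
                     b1=0.9, b2=0.999, e=1e-6,
                     weight_decay=0.01, max_grad_norm=1.0):
    """One BertAdam-style update for a single parameter (list of floats).

    `state` keeps the running moments under 'next_m' and 'next_v' and is
    updated in place. Returns the updated parameter as a new list.
    """
    # 1. Rescale this parameter's gradient so its norm is at most
    #    max_grad_norm (clipping disabled when max_grad_norm <= 0).
    if max_grad_norm > 0:
        norm = math.sqrt(sum(g * g for g in grad))
        if norm > max_grad_norm:
            grad = [g * max_grad_norm / norm for g in grad]

    next_m = state.setdefault('next_m', [0.0] * len(param))
    next_v = state.setdefault('next_v', [0.0] * len(param))
    new_param = []
    for i, (p, g) in enumerate(zip(param, grad)):
        # 2. Exponential moving averages of the gradient and its square.
        next_m[i] = b1 * next_m[i] + (1 - b1) * g
        next_v[i] = b2 * next_v[i] + (1 - b2) * g * g
        # 3. No bias correction: the raw moment estimates are used directly.
        update = next_m[i] / (math.sqrt(next_v[i]) + e)
        # 4. Decoupled weight decay, added after the adaptive normalization.
        update += weight_decay * p
        new_param.append(p - lr_scheduled * update)
    return new_param


state = {}
w = bert_adam_update([0.0], [0.5], state, lr_scheduled=1e-3)
print(w)  # about -0.00316, vs about -0.001 for bias-corrected Adam
```

Without bias correction the very first update is larger than in bias-corrected Adam by about (1 - b1) / sqrt(1 - b2), roughly 3.16 with the default betas; a warmup schedule keeps lr_scheduled small over those early steps. Whether the real implementation clips each parameter separately, as this sketch does, or over all parameters jointly is worth confirming in the source.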

Usage

Use BertAdam as the optimizer for BERT pretraining and fine-tuning. To enable learning rate scheduling, set t_total to the total number of training steps and warmup to the fraction of those steps spent warming up; with the default t_total=-1 the learning rate stays constant at lr. Choose the schedule according to the learning rate trajectory you want after warmup.
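Because the schedule advances once per optimization step, t_total should count calls to optimizer.step(), which is fewer than the number of batches when gradients are accumulated. A small sketch of that arithmetic, with hypothetical dataset and batch sizes and a hypothetical helper optimizer_steps:

```python
import math


def optimizer_steps(num_examples, batch_size, grad_accum_steps, num_epochs):
    """Number of optimizer.step() calls over a training run."""
    batches_per_epoch = math.ceil(num_examples / batch_size)
    # With gradient accumulation, step() runs once every grad_accum_steps
    # batches; a trailing partial window is assumed not to trigger a step.
    return (batches_per_epoch // grad_accum_steps) * num_epochs


t_total = optimizer_steps(num_examples=100_000, batch_size=32,
                          grad_accum_steps=4, num_epochs=3)
warmup = 0.1
warmup_steps = int(warmup * t_total)  # steps spent ramping the LR up
print(t_total, warmup_steps)
```

Passing the number of batches instead of the number of optimizer steps would stretch the schedule, so the learning rate would never reach the end of its decay.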

Code Reference

Source Location

Signature

from torch.optim import Optimizer
from torch.optim.optimizer import required

def warmup_cosine(x, warmup=0.002): ...
def warmup_constant(x, warmup=0.002): ...
def warmup_linear(x, warmup=0.002): ...
def warmup_linear_decay_exp(global_step, decay_rate, decay_steps, total_steps, warmup=0.002): ...
def warmup_exp_decay_exp(global_step, decay_rate, decay_steps, total_steps, warmup=0.002, degree=2.0): ...
def warmup_exp_decay_poly(global_step, total_steps, warmup=0.002, warm_degree=1.5, degree=2.0): ...

SCHEDULES = {
    'warmup_cosine': warmup_cosine,
    'warmup_constant': warmup_constant,
    'warmup_linear': warmup_linear,
    'warmup_linear_decay_exp': warmup_linear_decay_exp,
    'warmup_exp_decay_poly': warmup_exp_decay_poly,
    'warmup_exp_decay_exp': warmup_exp_decay_exp
}

class BertAdam(Optimizer):
    def __init__(self, params, lr=required, warmup=-1, t_total=-1,
                 schedule='warmup_linear', b1=0.9, b2=0.999, e=1e-6,
                 weight_decay=0.01, max_grad_norm=1.0): ...
    def get_lr(self): ...
    def step(self, closure=None): ...
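SCHEDULES maps the schedule string passed to BertAdam onto a function. The sketch below shows, under assumptions, how a scheduled learning rate like the one get_lr reports can be derived from such a registry: the progress fraction step / t_total and the warmup proportion go into the named schedule, the result scales the base lr, and t_total = -1 skips scheduling. SCHEDULE_REGISTRY and scheduled_lr are hypothetical names, only two schedules are included, and their formulas are assumed. Only the three fraction-based schedules fit this calling convention; the global_step-based ones have different signatures.

```python
def warmup_constant(x, warmup=0.002):
    # Assumed form: linear ramp during warmup, then flat.
    return x / warmup if x < warmup else 1.0


def warmup_linear(x, warmup=0.002):
    # Assumed form: linear ramp, then linear decay reaching 0 at x = 1.
    if x < warmup:
        return x / warmup
    w = max(warmup, 0.0)
    return max(0.0, 1.0 - (x - w) / (1.0 - w))


# Hypothetical stand-in for SCHEDULES (subset of the fraction-based schedules).
SCHEDULE_REGISTRY = {
    'warmup_constant': warmup_constant,
    'warmup_linear': warmup_linear,
}


def scheduled_lr(base_lr, step, t_total, warmup=-1, schedule='warmup_linear'):
    """Scheduled learning rate for one parameter group after `step` steps."""
    if t_total == -1:
        return base_lr  # no schedule: learning rate stays constant
    if schedule not in SCHEDULE_REGISTRY:
        raise ValueError(f"Unknown schedule: {schedule!r}")
    return base_lr * SCHEDULE_REGISTRY[schedule](step / t_total, warmup)


for step in (0, 500, 1000, 5500, 10000):
    print(step, scheduled_lr(2e-5, step, t_total=10000, warmup=0.1))
```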

Import

from pytorch_pretrained_bert.optimization import BertAdam, SCHEDULES

I/O Contract

Inputs

Name | Type | Required | Description
params | iterable | Yes | Iterable of parameters to optimize, or dicts defining parameter groups
lr | float | Yes | Learning rate (must be >= 0.0)
warmup | float | No | Fraction of t_total spent in warmup; -1 means no warmup. Default: -1
t_total | int | No | Total number of training steps for LR scheduling; -1 means constant LR. Default: -1
schedule | str | No | Name of the schedule function (a key of SCHEDULES). Default: 'warmup_linear'
b1 | float | No | Adam beta1 (decay rate of the first-moment estimate). Default: 0.9
b2 | float | No | Adam beta2 (decay rate of the second-moment estimate). Default: 0.999
e | float | No | Adam epsilon. Default: 1e-6
weight_decay | float | No | Decoupled weight decay coefficient. Default: 0.01
max_grad_norm | float | No | Maximum gradient norm for clipping; -1 disables clipping. Default: 1.0
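Since params also accepts dicts defining parameter groups, a common convention in BERT fine-tuning scripts is to exempt biases and LayerNorm weights from weight decay by placing them in a separate group. A minimal sketch, assuming BertAdam reads weight_decay per group as torch.optim optimizers conventionally allow; build_param_groups and the no-decay name patterns are hypothetical, and with a real model the input would come from model.named_parameters().

```python
def build_param_groups(named_params, weight_decay=0.01,
                       no_decay=('bias', 'LayerNorm.weight')):
    """Split (name, param) pairs into a decayed and a non-decayed group."""
    decay, skip = [], []
    for name, param in named_params:
        if any(pattern in name for pattern in no_decay):
            skip.append(param)
        else:
            decay.append(param)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': skip, 'weight_decay': 0.0},  # per-group override
    ]


# Placeholder strings stand in for tensors so the sketch runs without PyTorch.
fake_named_params = [
    ('encoder.layer.0.attention.self.query.weight', 'W_q'),
    ('encoder.layer.0.attention.self.query.bias', 'b_q'),
    ('encoder.layer.0.output.LayerNorm.weight', 'ln_w'),
]
groups = build_param_groups(fake_named_params)
print([g['params'] for g in groups])  # [['W_q'], ['b_q', 'ln_w']]
```

With real tensors, the returned list would be passed as the first argument in place of model.parameters().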

Outputs

Name | Type | Description
loss | Tensor or None | Loss value if a closure is provided to step(), otherwise None
lr | list[float] | Current learning rates for all parameter groups (via get_lr())

Usage Examples

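from pytorch_pretrained_bert.optimization import BertAdam

# `model` (returning a scalar loss) and `dataloader` are assumed to be defined elsewhere.
num_epochs = 3
# t_total counts optimizer.step() calls over the whole run.
t_total = len(dataloader) * num_epochs

# Configure optimizer with the warmup_linear schedule (10% warmup)
optimizer = BertAdam(
    model.parameters(),
    lr=2e-5,
    warmup=0.1,
    t_total=t_total,
    schedule='warmup_linear',
    weight_decay=0.01,
    max_grad_norm=1.0
)

# Training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        loss = model(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Monitor learning rate (get_lr() returns a list)
        current_lr = optimizer.get_lr()[0]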
