Implementation:Microsoft DeepSpeedExamples BertAdam Optimizer
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Learning Rate Scheduling |
| Last Updated | 2026-02-07 12:00 GMT |
Overview
A custom PyTorch optimizer implementing the BERT variant of Adam with decoupled weight decay and six configurable learning rate warmup/decay schedules.
Description
BertAdam extends torch.optim.Optimizer to implement the Adam algorithm with the weight decay fix used for BERT training. Standard Adam with L2 regularization adds the decay term to the gradient, so the decay passes through the moment estimates and is rescaled by the adaptive learning rate. BertAdam instead applies weight decay directly to the parameter values, which avoids that interaction. This is the same decoupled weight decay idea as AdamW, although BertAdam is not identical to AdamW because it omits bias correction.
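The difference between the two decay styles can be sketched on a single scalar weight. This is a minimal illustration in plain Python, not the repository's code: the function names are hypothetical, and the update formulas assume the usual Adam form without bias correction.

```python
def l2_in_gradient_step(p, grad, m, v, lr, wd, b1=0.9, b2=0.999, eps=1e-6):
    """Adam with L2 regularization: the decay term is folded into the gradient."""
    g = grad + wd * p                       # decay enters the moment estimates
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    p = p - lr * m / (v ** 0.5 + eps)       # decay is rescaled by 1 / sqrt(v)
    return p, m, v


def decoupled_decay_step(p, grad, m, v, lr, wd, b1=0.9, b2=0.999, eps=1e-6):
    """Decoupled decay: the moments only ever see the raw gradient."""
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad * grad
    update = m / (v ** 0.5 + eps) + wd * p  # decay added after normalization
    p = p - lr * update
    return p, m, v


# With a zero gradient, the decoupled form shrinks the weight by lr * wd * p,
# while the L2 form takes a much larger, normalized step.
p_decoupled, _, _ = decoupled_decay_step(p=1.0, grad=0.0, m=0.0, v=0.0, lr=0.1, wd=0.01)
p_l2, _, _ = l2_in_gradient_step(p=1.0, grad=0.0, m=0.0, v=0.0, lr=0.1, wd=0.01)
```

In the decoupled form the shrinkage depends only on the learning rate and the decay coefficient, which is the interaction-free behavior the description refers to.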
The module provides six learning rate schedule functions: warmup_cosine (cosine annealing after warmup), warmup_constant (constant LR after warmup), warmup_linear (linear decay after warmup), warmup_linear_decay_exp (exponential decay after linear warmup), warmup_exp_decay_exp (exponential warmup followed by exponential decay), and warmup_exp_decay_poly (polynomial decay after exponential warmup). The first three take a progress fraction x and a warmup proportion as input. The last three take global_step and total_steps instead of a precomputed fraction, plus decay or degree parameters, as listed under Signature. The selected schedule is applied at every optimization step to compute the effective learning rate.
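To make the shape of the three fraction-based schedules concrete, here is a hedged sketch in plain Python. The `_sketch` functions and `scheduled_lr` are hypothetical names, and the formulas assume the conventional definitions (linear ramp during warmup, then cosine, constant, or linear behavior); the repository's versions may differ in detail.

```python
import math


def warmup_cosine_sketch(x, warmup=0.002):
    """Linear ramp to 1.0 during warmup, then half-cosine decay toward 0."""
    if x < warmup:
        return x / warmup
    return 0.5 * (1.0 + math.cos(math.pi * x))


def warmup_constant_sketch(x, warmup=0.002):
    """Linear ramp to 1.0 during warmup, then a constant multiplier of 1.0."""
    if x < warmup:
        return x / warmup
    return 1.0


def warmup_linear_sketch(x, warmup=0.002):
    """Linear ramp to 1.0 during warmup, then linear decay 1 - x."""
    if x < warmup:
        return x / warmup
    return 1.0 - x


def scheduled_lr(base_lr, step, t_total, warmup, schedule=warmup_linear_sketch):
    """Assumed combination: base LR times the schedule multiplier at step / t_total."""
    return base_lr * schedule(step / t_total, warmup)


# Halfway through a 10% warmup the multiplier is 0.5, so the LR is about 1e-5.
lr_mid_warmup = scheduled_lr(2e-5, step=5000, t_total=100000, warmup=0.1)
```

Under this form of the linear schedule the multiplier drops from 1.0 to 1 - warmup right at the end of warmup, which is worth keeping in mind when reading logged learning rates.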
The optimizer supports gradient clipping via the max_grad_norm parameter, maintains exponential moving averages for first and second moments (next_m and next_v), and intentionally omits bias correction (unlike standard Adam). The get_lr method returns the current scheduled learning rate for all parameter groups, useful for logging and monitoring.
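The effect of omitting bias correction is easiest to see on the very first step, when both moment estimates start at zero. The following is a minimal sketch under that assumption; `first_step_update` is a hypothetical helper and the formulas are the textbook Adam ones, not code taken from the repository.

```python
def first_step_update(grad, b1=0.9, b2=0.999, eps=1e-6, bias_correction=False):
    """Normalized update after one step, starting from zero moment estimates."""
    m = (1 - b1) * grad           # first-moment EMA after one step
    v = (1 - b2) * grad * grad    # second-moment EMA after one step
    if bias_correction:
        m = m / (1 - b1)          # standard Adam divides by 1 - beta^t, here t = 1
        v = v / (1 - b2)
    return m / (v ** 0.5 + eps)


uncorrected = first_step_update(1.0)                      # about 3.16
corrected = first_step_update(1.0, bias_correction=True)  # about 1.0
```

With the default betas, the uncorrected first update is roughly (1 - b1) / sqrt(1 - b2) times the sign of the gradient, so early steps are larger than in standard Adam. A warmup schedule that starts the learning rate near zero offsets this.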
Usage
Use BertAdam as the optimizer for BERT pretraining and fine-tuning tasks. To enable learning rate scheduling, set the total number of training steps (t_total) and the warmup fraction (warmup); with the default t_total=-1 the learning rate stays constant. Select the schedule that matches the desired learning rate trajectory during training.
Code Reference
Source Location
- Repository: Microsoft_DeepSpeedExamples
- File: training/bing_bert/pytorch_pretrained_bert/optimization.py
- Lines: 1-237
Signature
def warmup_cosine(x, warmup=0.002):
def warmup_constant(x, warmup=0.002):
def warmup_linear(x, warmup=0.002):
def warmup_linear_decay_exp(global_step, decay_rate, decay_steps, total_steps, warmup=0.002):
def warmup_exp_decay_exp(global_step, decay_rate, decay_steps, total_steps, warmup=0.002, degree=2.0):
def warmup_exp_decay_poly(global_step, total_steps, warmup=0.002, warm_degree=1.5, degree=2.0):
SCHEDULES = {
    'warmup_cosine': warmup_cosine,
    'warmup_constant': warmup_constant,
    'warmup_linear': warmup_linear,
    'warmup_linear_decay_exp': warmup_linear_decay_exp,
    'warmup_exp_decay_poly': warmup_exp_decay_poly,
    'warmup_exp_decay_exp': warmup_exp_decay_exp
}
class BertAdam(Optimizer):
    def __init__(self, params, lr=required, warmup=-1, t_total=-1,
                 schedule='warmup_linear', b1=0.9, b2=0.999, e=1e-6,
                 weight_decay=0.01, max_grad_norm=1.0):
    def get_lr(self):
    def step(self, closure=None):
Import
from pytorch_pretrained_bert.optimization import BertAdam, SCHEDULES
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| params | iterable | Yes | Iterable of parameters to optimize or dicts defining parameter groups |
| lr | float | Yes | Learning rate (must be >= 0.0) |
| warmup | float | No | Proportion of t_total for warmup; -1 means no warmup. Default: -1 |
| t_total | int | No | Total training steps for LR scheduling; -1 means constant LR. Default: -1 |
| schedule | str | No | Name of warmup schedule function. Default: 'warmup_linear' |
| b1 | float | No | Adam beta1. Default: 0.9 |
| b2 | float | No | Adam beta2. Default: 0.999 |
| e | float | No | Adam epsilon. Default: 1e-6 |
| weight_decay | float | No | Decoupled weight decay coefficient. Default: 0.01 |
| max_grad_norm | float | No | Maximum gradient norm for clipping; -1 disables. Default: 1.0 |
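Because params may be a list of dicts, different parameters can receive different weight_decay values. The sketch below shows one way to build such groups; `build_param_groups` and the name patterns in `no_decay` are illustrative assumptions (excluding biases and LayerNorm parameters from decay is a common convention, not something the source states).

```python
def build_param_groups(named_params, weight_decay=0.01,
                       no_decay=("bias", "LayerNorm.bias", "LayerNorm.weight")):
    """Split (name, parameter) pairs into a decayed and a non-decayed group."""
    decay, skip = [], []
    for name, param in named_params:
        if any(tag in name for tag in no_decay):
            skip.append(param)      # matched a no-decay pattern
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": skip, "weight_decay": 0.0},
    ]


# Stand-in names and values; with a real model, pass model.named_parameters().
demo_named_params = [
    ("encoder.dense.weight", "W"),
    ("encoder.dense.bias", "b"),
    ("encoder.LayerNorm.weight", "gamma"),
]
demo_groups = build_param_groups(demo_named_params)
```

The resulting list can be passed as the params argument, for example `BertAdam(build_param_groups(model.named_parameters()), lr=2e-5, warmup=0.1, t_total=100000)`. Keys set in a group override the optimizer-level defaults for that group.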
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Tensor or None | Returned by step(): the loss value if a closure is provided, otherwise None |
| lr | list[float] | Returned by get_lr(): current scheduled learning rates for all parameter groups |
Usage Examples
from pytorch_pretrained_bert.optimization import BertAdam

# `model` and `dataloader` are assumed to be defined elsewhere.
# Configure optimizer with warmup linear schedule
optimizer = BertAdam(
    model.parameters(),
    lr=2e-5,
    warmup=0.1,          # first 10% of t_total is warmup
    t_total=100000,
    schedule='warmup_linear',
    weight_decay=0.01,
    max_grad_norm=1.0
)

# Training loop
for step, batch in enumerate(dataloader):
    loss = model(batch)  # assumes the model returns the loss directly
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Monitor learning rate (get_lr() returns a list)
    current_lr = optimizer.get_lr()