Implementation: Pyro (pyro-ppl) ClippedAdam Optimizer
| Field | Value |
|---|---|
| Sources | Pyro |
| Domains | Optimization, Deep_Learning |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
ClippedAdam is a custom Adam optimizer in Pyro with built-in element-wise gradient clipping and per-step learning rate decay, designed for stable training of deep generative models and complex probabilistic programs.
Description
The ClippedAdam class is a custom PyTorch optimizer that extends torch.optim.optimizer.Optimizer and implements the Adam algorithm with two additional features integrated directly into the optimization step:
- Element-wise gradient clipping: Before computing moment estimates, each gradient element is clamped to the range [-clip_norm, clip_norm]. This prevents any individual gradient component from causing disproportionately large parameter updates.
- Per-step learning rate decay: At each optimization step, the learning rate is multiplied by the decay factor lrd, implementing the schedule lr_t = lr_0 * lrd^t. This provides a simple, monotonically decreasing learning rate without requiring an external scheduler. (Both features are illustrated in the sketch below.)
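The following is a minimal, self-contained sketch of how these two features slot into an otherwise standard Adam update (weight decay omitted). It is illustrative only; the helper name clipped_adam_step and the explicit state dict are conveniences for the example, not the actual implementation in pyro/optim/clipped_adam.py.
```python
import torch

def clipped_adam_step(param, grad, state, lr=1e-3, betas=(0.9, 0.999),
                      eps=1e-8, clip_norm=10.0, lrd=1.0):
    """One simplified update combining element-wise clipping and lr decay."""
    beta1, beta2 = betas
    if not state:  # lazily initialize the step counter and moment buffers
        state["step"] = 0
        state["exp_avg"] = torch.zeros_like(param)
        state["exp_avg_sq"] = torch.zeros_like(param)

    # Element-wise gradient clipping: clamp every component to [-clip_norm, clip_norm].
    grad = grad.clamp(-clip_norm, clip_norm)

    # Per-step learning rate decay: lr_t = lr_0 * lrd ** t.
    state["step"] += 1
    step_lr = lr * lrd ** state["step"]

    # Standard Adam moment estimates with bias correction.
    state["exp_avg"].mul_(beta1).add_(grad, alpha=1 - beta1)
    state["exp_avg_sq"].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    m_hat = state["exp_avg"] / (1 - beta1 ** state["step"])
    v_hat = state["exp_avg_sq"] / (1 - beta2 ** state["step"])
    param.data.add_(-step_lr * m_hat / (v_hat.sqrt() + eps))

# One step on a toy parameter: the huge gradient component is clipped to clip_norm.
p = torch.zeros(2)
state = {}
clipped_adam_step(p, torch.tensor([1e6, 0.5]), state, lr=0.01, clip_norm=10.0, lrd=0.9999)
```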
The raw ClippedAdam optimizer class is defined at pyro/optim/clipped_adam.py (L11--110). A factory function at pyro/optim/optim.py (L277--281) wraps it with PyroOptim to provide Pyro's dynamic parameter management. The wrapped version is the one exposed to users via from pyro.optim import ClippedAdam.
The optimizer also supports centered variance as an optional feature, where the second moment estimate uses the centered gradient (g - m) instead of the raw gradient g, which can improve stability in some settings.
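As a rough sketch of what this means (following the description above, not the literal Pyro code), the centered option replaces the accumulated g^2 term with (g - m)^2, where m is the current first-moment estimate:
```python
import torch

beta2 = 0.999
grad = torch.tensor([0.5, -2.0, 3.0])       # current gradient g
exp_avg = torch.tensor([0.4, -1.5, 2.5])    # first-moment estimate m
exp_avg_sq = torch.zeros_like(grad)         # second-moment buffer

# Standard Adam: accumulate the raw squared gradient g^2.
standard = exp_avg_sq.clone().mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

# centered_variance=True: accumulate the squared centered gradient (g - m)^2,
# i.e. an exponential estimate of the gradient's variance around its running mean.
centered_grad = grad - exp_avg
centered = exp_avg_sq.clone().mul_(beta2).addcmul_(
    centered_grad, centered_grad, value=1 - beta2
)
```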
Usage
Import ClippedAdam from pyro.optim and pass it to pyro.infer.SVI. Use it as a drop-in replacement for Adam when training is unstable or produces NaN losses.
Code Reference
Source Location
- Repository: pyro-ppl/pyro
- File: pyro/optim/clipped_adam.py L11--110
- Factory: pyro/optim/optim.py L277--281
Signature
```python
# Raw optimizer class
class ClippedAdam(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: Tuple = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay=0,
        clip_norm: float = 10.0,
        lrd: float = 1.0,
        centered_variance: bool = False,
    ):
        ...

# PyroOptim factory (user-facing)
def ClippedAdam(optim_args: dict) -> PyroOptim:
    return PyroOptim(pt_ClippedAdam, optim_args)
```
Import
```python
from pyro.optim import ClippedAdam
```
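Note that the name exported by pyro.optim is the PyroOptim factory shown in the signature above. If you need the raw torch.optim-style class without Pyro's parameter management, it can be imported from its defining module; the import paths below follow the source locations listed above, and the alias name is only for illustration:
```python
# User-facing factory: returns a PyroOptim wrapper for use with pyro.infer.SVI.
from pyro.optim import ClippedAdam

# Raw optimizer class from pyro/optim/clipped_adam.py, usable like any
# torch.optim optimizer (the alias is arbitrary).
from pyro.optim.clipped_adam import ClippedAdam as TorchClippedAdam
```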
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| optim_args | dict | Yes | Dictionary of optimizer arguments passed to the underlying ClippedAdam constructor |
optim_args keys:
| Key | Type | Default | Description |
|---|---|---|---|
| lr | float | 1e-3 | Initial learning rate |
| betas | tuple(float, float) | (0.9, 0.999) | Coefficients for computing running averages of the gradient and its square |
| eps | float | 1e-8 | Term added to the denominator for numerical stability |
| weight_decay | float | 0 | Weight decay (L2 penalty) |
| clip_norm | float | 10.0 | Maximum absolute value for element-wise gradient clipping |
| lrd | float | 1.0 | Learning rate decay multiplier applied per step (set < 1.0 to enable; see the decay example below) |
| centered_variance | bool | False | Use centered variance for the second moment estimate |
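Because the decay compounds every step, lrd values very close to 1.0 are typical. A quick bit of arithmetic (plain Python, not a Pyro API) shows what a given lrd implies for the effective learning rate lr_t = lr_0 * lrd^t over a run:
```python
# Effective learning rate under the schedule lr_t = lr_0 * lrd ** t.
lr0, lrd = 0.005, 0.9999

for t in (0, 1000, 2000, 10000):
    print(f"step {t:>6}: lr ~= {lr0 * lrd ** t:.6f}")

# Roughly: 0.005000 at step 0, 0.004524 at step 1000,
# 0.004094 at step 2000, and 0.001839 at step 10000.
```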
Outputs
| Name | Type | Description |
|---|---|---|
| return value | PyroOptim | A PyroOptim instance wrapping ClippedAdam with dynamic parameter management, gradient clipping, and learning rate decay |
Usage Examples
Basic Usage
```python
from pyro.optim import ClippedAdam
from pyro.infer import SVI, Trace_ELBO

optimizer = ClippedAdam({"lr": 0.005, "clip_norm": 10.0, "lrd": 0.9999})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

for step in range(2000):
    loss = svi.step(data)
```
Deep Generative Model Training
```python
from pyro.optim import ClippedAdam
from pyro.infer import SVI, Trace_ELBO

# More aggressive clipping and decay for a VAE
optimizer = ClippedAdam({
    "lr": 0.001,
    "betas": (0.9, 0.999),
    "clip_norm": 5.0,
    "lrd": 0.99999,
    "weight_decay": 1e-5,
})
svi = SVI(vae_model, vae_guide, optimizer, loss=Trace_ELBO(num_particles=5))

for epoch in range(100):
    for batch in dataloader:
        loss = svi.step(batch)
```