
Implementation:Pyro ppl Pyro ClippedAdam Optimizer

From Leeroopedia


Sources: Pyro
Domains: Optimization, Deep_Learning
Last Updated: 2026-02-09 12:00 GMT

Overview

ClippedAdam is a custom Adam optimizer in Pyro with built-in element-wise gradient clipping and per-step learning rate decay, designed for stable training of deep generative models and complex probabilistic programs.

Description

The ClippedAdam class is a custom PyTorch optimizer that extends torch.optim.Optimizer and implements the Adam algorithm with two additional features integrated directly into the optimization step:

  • Element-wise gradient clipping: Before computing moment estimates, each gradient element is clamped to the range [-clip_norm, clip_norm]. This prevents any individual gradient component from causing disproportionately large parameter updates.
  • Per-step learning rate decay: At each optimization step, the learning rate is multiplied by the decay factor lrd, implementing the schedule lr_t = lr_0 * lrd^t. This provides a simple, monotonically decreasing learning rate without requiring an external scheduler. Both features are illustrated in the sketch after this list.
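
The sketch below shows how both features fold into a single Adam-style update. It is a simplified illustration under stated assumptions, not Pyro's exact code: the function name clipped_adam_step is hypothetical, and weight decay, the centered-variance option, and edge cases are omitted.

import torch

def clipped_adam_step(param, exp_avg, exp_avg_sq, group, step):
    # Simplified sketch of one ClippedAdam update (not Pyro's exact code).
    beta1, beta2 = group["betas"]

    # 1. Element-wise gradient clipping: clamp every component
    #    to [-clip_norm, clip_norm] before the moment updates.
    grad = param.grad.clamp(-group["clip_norm"], group["clip_norm"])

    # 2. Standard Adam moment estimates on the clipped gradient.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(group["eps"])

    # 3. Per-step learning rate decay: lr <- lr * lrd each step,
    #    i.e. lr_t = lr_0 * lrd**t overall.
    group["lr"] *= group["lrd"]

    param.data.addcdiv_(exp_avg, denom, value=-group["lr"] / bias_correction1)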

The raw ClippedAdam optimizer class is defined at pyro/optim/clipped_adam.py (L11--110). A factory function at pyro/optim/optim.py (L277--281) wraps it with PyroOptim to provide Pyro's dynamic parameter management. The wrapped version is the one exposed to users via from pyro.optim import ClippedAdam.

The optimizer also supports centered variance as an optional feature, where the second moment estimate uses the centered gradient (g - m) instead of the raw gradient g, which can improve stability in some settings.
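
A minimal sketch of the centered update follows; the helper name centered_second_moment is hypothetical, and Pyro's actual implementation may differ in bias-correction details:

import torch

def centered_second_moment(exp_avg_sq: torch.Tensor,
                           exp_avg: torch.Tensor,
                           grad: torch.Tensor,
                           beta2: float) -> torch.Tensor:
    # Standard Adam tracks E[g^2]; with centered_variance=True the
    # accumulator instead tracks E[(g - m)^2], the variance of the
    # gradient around its running mean m (the first-moment estimate).
    # Illustrative sketch, not Pyro's exact code.
    centered = grad - exp_avg
    exp_avg_sq.mul_(beta2).addcmul_(centered, centered, value=1 - beta2)
    return exp_avg_sq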

Usage

Import ClippedAdam from pyro.optim and pass it to pyro.infer.SVI. Use it as a drop-in replacement for Adam when training is unstable or produces NaN losses.

Code Reference

Source Location

Repository: pyro-ppl/pyro
File: pyro/optim/clipped_adam.py L11--110
Factory: pyro/optim/optim.py L277--281

Signature

# Raw optimizer class
class ClippedAdam(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: Tuple = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay=0,
        clip_norm: float = 10.0,
        lrd: float = 1.0,
        centered_variance: bool = False,
    ):

# PyroOptim factory (user-facing)
def ClippedAdam(optim_args: dict) -> PyroOptim:
    return PyroOptim(pt_ClippedAdam, optim_args)

Import

from pyro.optim import ClippedAdam

I/O Contract

Inputs

Name Type Required Description
optim_args dict Yes Dictionary of optimizer arguments passed to the underlying ClippedAdam constructor

optim_args keys:

Key Type Default Description
lr float 1e-3 Initial learning rate
betas tuple(float, float) (0.9, 0.999) Coefficients for computing running averages of gradient and its square
eps float 1e-8 Term added to denominator for numerical stability
weight_decay float 0 Weight decay (L2 penalty)
clip_norm float 10.0 Maximum absolute value for element-wise gradient clipping
lrd float 1.0 Learning rate decay multiplier applied per step (set < 1.0 to enable)
centered_variance bool False Use centered variance for second moment estimate
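
To make the effect of lrd concrete, a quick back-of-the-envelope check of the schedule lr_t = lr_0 * lrd^t (plain Python, no Pyro required):

# Effective learning rate after t steps under lr_t = lr_0 * lrd**t.
lr0, lrd = 1e-3, 0.9999

for t in (1000, 10000, 100000):
    print(f"step {t:>6}: lr = {lr0 * lrd**t:.2e}")

# step   1000: lr = 9.05e-04
# step  10000: lr = 3.68e-04
# step 100000: lr = 4.54e-08

Because the decay compounds every step, values of lrd very close to 1.0 (such as 0.9999 or 0.99999) are typical for long runs.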

Outputs

Name Type Description
return value PyroOptim A PyroOptim instance wrapping ClippedAdam with dynamic parameter management, gradient clipping, and learning rate decay

Usage Examples

Basic Usage

from pyro.optim import ClippedAdam
from pyro.infer import SVI, Trace_ELBO

# model, guide, and data are assumed to be defined elsewhere.
optimizer = ClippedAdam({"lr": 0.005, "clip_norm": 10.0, "lrd": 0.9999})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

for step in range(2000):
    loss = svi.step(data)
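
With lrd=0.9999, the learning rate at the end of this 2000-step run is 0.005 * 0.9999^2000 ≈ 0.0041, a gentle reduction of roughly 18%.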

Deep Generative Model Training

from pyro.optim import ClippedAdam
from pyro.infer import SVI, Trace_ELBO

# Tighter gradient clipping and a slower learning rate decay for a VAE.
# vae_model, vae_guide, and dataloader are assumed to be defined elsewhere.
optimizer = ClippedAdam({
    "lr": 0.001,
    "betas": (0.9, 0.999),
    "clip_norm": 5.0,
    "lrd": 0.99999,
    "weight_decay": 1e-5,
})

svi = SVI(vae_model, vae_guide, optimizer, loss=Trace_ELBO(num_particles=5))

for epoch in range(100):
    for batch in dataloader:
        loss = svi.step(batch)
