
Implementation: PyroOptim Adam (pyro-ppl/pyro)

From Leeroopedia


Sources: Pyro
Domains: Optimization, Variational_Inference
Last Updated: 2026-02-09 12:00 GMT
Type: Wrapper Doc (wraps torch.optim.Adam through PyroOptim)

Overview

Adam in Pyro is a PyroOptim wrapper around torch.optim.Adam that provides dynamic parameter management for stochastic variational inference, automatically creating per-parameter optimizer instances as new parameters are discovered during model execution.

Description

The Pyro Adam optimizer is generated programmatically by wrapping torch.optim.Adam with the PyroOptim base class. This wrapper resolves a fundamental mismatch between PyTorch's optimizer design, which expects all parameters to be known upfront, and Pyro's dynamic parameter model, where new parameters may be created by pyro.param() calls at any point during execution.

When Adam(optim_args) is called, it returns a PyroOptim instance that:

  • Creates a new torch.optim.Adam optimizer for each parameter the first time it is encountered
  • Applies the optimizer arguments (lr, betas, eps, weight_decay) from the provided optim_args dictionary
  • Supports per-parameter optimizer arguments via a callable that takes the parameter name and returns a dictionary
  • Optionally applies gradient clipping via clip_args (supporting both norm-based and value-based clipping)
  • Manages optimizer state for save/load serialization
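The lazy, per-parameter creation described above can be sketched in plain Python. ToyAdam and LazyOptim below are illustrative stand-ins for this doc, not Pyro's actual classes:

```python
class ToyAdam:
    """Stand-in for a single-parameter torch.optim.Adam instance."""
    def __init__(self, params, lr=0.001, betas=(0.9, 0.999)):
        self.params = list(params)
        self.lr = lr
        self.betas = betas

class LazyOptim:
    """Toy version of PyroOptim's lazy per-parameter optimizer creation."""
    def __init__(self, optim_constructor, optim_args):
        self.optim_constructor = optim_constructor
        self.optim_args = optim_args      # dict, or callable taking a name
        self.optim_objs = {}              # one optimizer per parameter name

    def _args_for(self, name):
        return self.optim_args(name) if callable(self.optim_args) else self.optim_args

    def __call__(self, named_params):
        for name, param in named_params.items():
            if name not in self.optim_objs:
                # First time this parameter is seen: create its own optimizer
                self.optim_objs[name] = self.optim_constructor(
                    [param], **self._args_for(name)
                )
            # A real implementation would call .step() on each optimizer here

optim = LazyOptim(ToyAdam, {"lr": 0.01})
optim({"loc": 0.0})
optim({"loc": 0.0, "scale": 1.0})  # "scale" is new, so it gets its own ToyAdam
print(sorted(optim.optim_objs))    # ['loc', 'scale']
```

Each parameter keeps an independent optimizer object, which is what lets Adam's per-parameter moment estimates survive even though parameters appear at different times.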

The Adam optimizer is generated at import time in pyro/optim/pytorch_optimizers.py (L11--34) through a loop that wraps all torch.optim.Optimizer subclasses with PyroOptim. The PyroOptim base class is defined at pyro/optim/optim.py (L72--268).
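That import-time loop can be sketched as follows; the Optimizer stand-ins, the generated dictionary, and the _wrap helper are hypothetical names for illustration, not the exact code in pytorch_optimizers.py:

```python
import inspect

# Toy stand-ins for torch.optim classes (assumptions, not the real library).
class Optimizer: pass
class Adam(Optimizer): pass
class SGD(Optimizer): pass

class PyroOptim:
    """Minimal stand-in for pyro.optim.PyroOptim."""
    def __init__(self, optim_constructor, optim_args, clip_args=None):
        self.pt_optim_constructor = optim_constructor
        self.pt_optim_args = optim_args
        self.pt_clip_args = clip_args

def _wrap(cls):
    # Each generated factory closes over one optimizer class.
    def wrapper(optim_args, clip_args=None):
        return PyroOptim(cls, optim_args, clip_args)
    wrapper.__name__ = cls.__name__
    return wrapper

# Discover every Optimizer subclass in scope and wrap it with a factory.
generated = {
    name: _wrap(cls)
    for name, cls in list(globals().items())
    if inspect.isclass(cls) and issubclass(cls, Optimizer) and cls is not Optimizer
}

opt = generated["Adam"]({"lr": 0.001})
print(type(opt).__name__)  # PyroOptim
```

The same pattern is why every torch.optim optimizer, not just Adam, is available under pyro.optim with identical wrapper semantics.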

Usage

Import Adam from pyro.optim and pass it to pyro.infer.SVI as the optimizer. This is the most commonly used optimizer for variational inference in Pyro.

Code Reference

Source Location

Repository: pyro-ppl/pyro
PyroOptim base: pyro/optim/optim.py L72--268
Adam generation: pyro/optim/pytorch_optimizers.py L11--34
External reference: torch.optim.Adam

Signature

def Adam(optim_args: dict, clip_args: Optional[dict] = None) -> PyroOptim:
    """
    Wraps torch.optim.Adam with PyroOptim.
    """
    return PyroOptim(torch.optim.Adam, optim_args, clip_args)

Import

from pyro.optim import Adam

I/O Contract

Inputs

  • optim_args (dict or callable, required): Dictionary of optimizer arguments, or a callable that takes a parameter name and returns such a dictionary. Supported keys: lr (default 0.001), betas (default (0.9, 0.999)), eps (default 1e-8), weight_decay (default 0), amsgrad (default False)
  • clip_args (dict, callable, or None, optional): Gradient clipping arguments. Supported keys: clip_norm (max gradient norm), clip_value (max gradient value)
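To make the clip_norm / clip_value semantics concrete, here is a toy sketch on plain float gradients; the real implementation clips torch tensors, and the order in which the two modes are applied here is an assumption:

```python
import math

def clip_grads(grads, clip_norm=None, clip_value=None):
    """Toy illustration of the two clip_args modes on a list of floats."""
    if clip_value is not None:
        # Clamp each gradient component into [-clip_value, clip_value].
        grads = [max(-clip_value, min(clip_value, g)) for g in grads]
    if clip_norm is not None:
        # Rescale the whole gradient vector if its L2 norm exceeds clip_norm.
        norm = math.sqrt(sum(g * g for g in grads))
        if norm > clip_norm:
            grads = [g * clip_norm / norm for g in grads]
    return grads

print(clip_grads([3.0, 4.0], clip_norm=1.0))    # [0.6, 0.8] (rescaled to norm 1)
print(clip_grads([-5.0, 2.0], clip_value=1.0))  # [-1.0, 1.0] (componentwise clamp)
```

Norm clipping preserves the gradient's direction while bounding its magnitude; value clipping bounds each component independently and can change the direction.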

Outputs

  • return value (PyroOptim): A PyroOptim instance wrapping torch.optim.Adam with dynamic parameter management; callable on parameter sets to perform optimization steps

Usage Examples

Basic Usage with SVI

from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

optimizer = Adam({"lr": 0.001})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

for step in range(1000):
    loss = svi.step(data)

Per-Parameter Learning Rates

from pyro.optim import Adam

def per_param_args(param_name):
    if "scale" in param_name:
        return {"lr": 0.01}
    return {"lr": 0.001}

optimizer = Adam(per_param_args)

With Gradient Clipping

from pyro.optim import Adam

optimizer = Adam({"lr": 0.001}, clip_args={"clip_norm": 10.0})
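State Serialization

The save/load state management noted in the overview can be sketched with a get_state / set_state pattern; ToyPyroOptim below is a hypothetical stand-in that uses plain dicts in place of torch optimizer state dicts:

```python
class ToyPyroOptim:
    """Toy sketch of PyroOptim-style optimizer-state checkpointing."""
    def __init__(self):
        self.optim_objs = {}  # parameter name -> optimizer state

    def get_state(self):
        # Real code would collect a state_dict() from each per-parameter optimizer.
        return {name: dict(state) for name, state in self.optim_objs.items()}

    def set_state(self, state):
        # Real code would call load_state_dict() on freshly created optimizers.
        self.optim_objs = {name: dict(s) for name, s in state.items()}

opt = ToyPyroOptim()
opt.optim_objs["loc"] = {"step": 10}
checkpoint = opt.get_state()   # snapshot mid-training

resumed = ToyPyroOptim()
resumed.set_state(checkpoint)  # restore before resuming training
print(resumed.optim_objs["loc"]["step"])  # 10
```

Checkpointing optimizer state alongside pyro.get_param_store() matters for Adam in particular, since its per-parameter moment estimates would otherwise restart cold on resume.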
