Implementation: pyro-ppl/pyro PyroOptim Adam
| Field | Value |
|---|---|
| Sources | Pyro |
| Domains | Optimization, Variational_Inference |
| Last Updated | 2026-02-09 12:00 GMT |
| Type | Wrapper Doc (wraps torch.optim.Adam through PyroOptim) |
Overview
Adam in Pyro is a PyroOptim wrapper around torch.optim.Adam that provides dynamic parameter management for stochastic variational inference, automatically creating per-parameter optimizer instances as new parameters are discovered during model execution.
Description
The Pyro Adam optimizer is generated programmatically by wrapping torch.optim.Adam with the PyroOptim base class. This wrapper solves a fundamental mismatch between PyTorch's optimizer design (which expects all parameters to be known upfront) and Pyro's dynamic parameter model (where new parameters may be created during pyro.param() calls at any point during execution).
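To make the mismatch concrete, consider a guide in which a parameter only comes into existence on certain executions. This is a minimal, hypothetical sketch (the names guide, loc, scale, and z are illustrative, not taken from the Pyro source):

import torch
import pyro
import pyro.distributions as dist

def guide(data):
    # Registered in the param store the first time the guide runs.
    loc = pyro.param("loc", torch.zeros(1))
    if len(data) > 10:
        # Created only when this branch first executes, possibly long after
        # optimization has started, so it cannot be listed up front for a
        # plain torch.optim.Adam.
        scale = pyro.param("scale", torch.ones(1),
                           constraint=dist.constraints.positive)
    else:
        scale = torch.tensor(1.0)
    pyro.sample("z", dist.Normal(loc, scale))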
When Adam(optim_args) is called, it returns a PyroOptim instance that:
- Creates a new torch.optim.Adam optimizer for each parameter the first time it is encountered
- Applies the optimizer arguments (lr, betas, eps, weight_decay) from the provided optim_args dictionary
- Supports per-parameter optimizer arguments via a callable that takes the parameter name and returns a dictionary
- Optionally applies gradient clipping via clip_args (supporting both norm-based and value-based clipping)
- Manages optimizer state for save/load serialization
The Adam optimizer is generated at import time in pyro/optim/pytorch_optimizers.py (L11--34) through a loop that wraps all torch.optim.Optimizer subclasses with PyroOptim. The PyroOptim base class is defined at pyro/optim/optim.py (L72--268).
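A simplified sketch of how such a generation loop can work, assuming only the public PyroOptim constructor (the actual pytorch_optimizers.py differs in details such as docstring handling):

import torch
from pyro.optim import PyroOptim

def _wrap(pt_optim_cls):
    # Factory that defers to PyroOptim, which instantiates pt_optim_cls
    # lazily, once per parameter, when the wrapper is later called.
    def wrapper(optim_args, clip_args=None):
        return PyroOptim(pt_optim_cls, optim_args, clip_args)
    wrapper.__name__ = pt_optim_cls.__name__
    return wrapper

for _name in dir(torch.optim):
    _obj = getattr(torch.optim, _name)
    if (isinstance(_obj, type) and issubclass(_obj, torch.optim.Optimizer)
            and _obj is not torch.optim.Optimizer):
        globals()[_name] = _wrap(_obj)  # exposes Adam, SGD, RMSprop, ...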
Usage
Import Adam from pyro.optim and pass it to pyro.infer.SVI as the optimizer. This is the most commonly used optimizer for variational inference in Pyro.
Code Reference
Source Location
- Repository: pyro-ppl/pyro
- PyroOptim base: pyro/optim/optim.py L72--268
- Adam generation: pyro/optim/pytorch_optimizers.py L11--34
- External reference: torch.optim.Adam
Signature
def Adam(optim_args: dict, clip_args: Optional[dict] = None) -> PyroOptim:
    """
    Wraps torch.optim.Adam with PyroOptim.
    """
    return PyroOptim(torch.optim.Adam, optim_args, clip_args)
Import
from pyro.optim import Adam
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| optim_args | dict or callable | Yes | Dictionary of optimizer arguments, or a callable that takes a parameter name and returns such a dictionary. Supported keys: lr (default 0.001), betas (default (0.9, 0.999)), eps (default 1e-8), weight_decay (default 0), amsgrad (default False) |
| clip_args | dict or callable or None | No | Optional gradient clipping arguments. Supported keys: clip_norm (max gradient norm), clip_value (max gradient value) |
Outputs
| Name | Type | Description |
|---|---|---|
| return value | PyroOptim | A PyroOptim instance wrapping torch.optim.Adam with dynamic parameter management; callable on parameter sets to perform optimization steps |
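As a rough illustration of this contract, the returned object can be called directly on an iterable of parameters registered in the param store. This is a sketch only; in normal use SVI collects the parameters from the execution trace and drives the optimizer for you:

import torch
import pyro
from pyro.optim import Adam

pyro.clear_param_store()
optimizer = Adam({"lr": 0.1})

# Toy parameter and objective; "loc" is registered in the Pyro param store.
loc = pyro.param("loc", torch.zeros(2))
for _ in range(100):
    loss = (loc - 1.0).pow(2).sum()
    loss.backward()
    optimizer([loc])    # lazily builds a torch.optim.Adam for loc, then steps it
    loc.grad.zero_()    # PyroOptim does not zero gradients; SVI normally does this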
Usage Examples
Basic Usage with SVI
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
optimizer = Adam({"lr": 0.001})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
for step in range(1000):
    loss = svi.step(data)
Per-Parameter Learning Rates
from pyro.optim import Adam
def per_param_args(param_name):
    if "scale" in param_name:
        return {"lr": 0.01}
    return {"lr": 0.001}
optimizer = Adam(per_param_args)
With Gradient Clipping
from pyro.optim import Adam
optimizer = Adam({"lr": 0.001}, clip_args={"clip_norm": 10.0})
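Saving and Loading Optimizer State
The description above notes that the wrapper manages optimizer state for serialization. A minimal sketch of that round trip, assuming an arbitrary filename:

from pyro.optim import Adam

optimizer = Adam({"lr": 0.001})
# ... run SVI steps here so per-parameter optimizer state accumulates ...
optimizer.save("adam_state.pt")     # serialize optimizer state to disk

restored = Adam({"lr": 0.001})
restored.load("adam_state.pt")      # restore state before resuming training

When checkpointing, the param store itself is typically saved alongside the optimizer state via pyro.get_param_store().save(...).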