Implementation:Pyro ppl Pyro MultiOptimizer
| Property | Value |
|---|---|
| Module | pyro.optim.multi |
| Source | pyro/optim/multi.py |
| Lines | 169 |
| Classes | MultiOptimizer, PyroMultiOptimizer, TorchMultiOptimizer, MixedMultiOptimizer, Newton |
| Dependencies | torch, pyro.ops.newton, pyro.optim.optim |
Overview
This module provides a framework for higher-order optimizers in Pyro. Unlike standard first-order optimizers that call loss.backward() externally and then update parameters, higher-order optimizers use torch.autograd.grad internally and may require multiple backward passes through the loss.
The MultiOptimizer base class defines an interface where the step method receives both the loss tensor and a dictionary of parameters, triggering backpropagation internally. This enables second-order methods like Newton's method and supports mixing different optimization strategies for different parameter groups.
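The idea of differentiating through a gradient can be illustrated with plain PyTorch. The following is a minimal sketch, not code from the module; the tensor values and variable names are chosen only for illustration.

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (x ** 3).sum()

# First backward pass: the gradient 3 * x**2, kept differentiable
# because create_graph=True records the graph of the gradient itself.
(grad,) = torch.autograd.grad(loss, [x], create_graph=True)

# Second backward pass through the gradient. The Hessian of this loss is
# diagonal, so differentiating grad.sum() yields its diagonal, 6 * x.
(hess_diag,) = torch.autograd.grad(grad.sum(), [x])
```

A first-order optimizer only needs the first call; a second-order method such as Newton's method needs both, which is why the optimizer itself has to control backpropagation.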
Code Reference
Class: MultiOptimizer (Base)
Methods:
- step(loss, params): In-place optimization step. Computes updated values via get_step and copies them (detached) into the original parameters.
- get_step(loss, params): Abstract. Returns a dict of updated parameter values (preserving differentiability).
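The relationship between the two methods can be sketched with a small stand-in class. SketchGradientDescent is a hypothetical name introduced here; it mirrors the contract described above using plain PyTorch and is not part of Pyro.

```python
import torch


class SketchGradientDescent:
    """Hypothetical stand-in that follows the step/get_step contract."""

    def __init__(self, lr):
        self.lr = lr

    def get_step(self, loss, params):
        # Return updated values without touching the parameters.
        # create_graph=True keeps the result differentiable.
        names = list(params)
        grads = torch.autograd.grad(
            loss, [params[n] for n in names], create_graph=True
        )
        return {n: params[n] - self.lr * g for n, g in zip(names, grads)}

    def step(self, loss, params):
        # Compute the update, then copy detached values in place.
        updated = self.get_step(loss, params)
        with torch.no_grad():
            for name, value in params.items():
                value.copy_(updated[name].detach())


params = {"x": torch.tensor([3.0], requires_grad=True)}
loss = (params["x"] ** 2).sum()
opt = SketchGradientDescent(lr=0.25)

preview = opt.get_step(loss, params)  # differentiable, params unchanged
opt.step(loss, params)                # params["x"] is overwritten in place
```

Detaching before the copy matters because the updated value usually depends on the parameter it replaces.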
Class: PyroMultiOptimizer
Wraps a PyroOptim in the MultiOptimizer interface. Uses torch.autograd.grad(loss, values, create_graph=True) to compute gradients, then assigns them as .grad attributes and calls the underlying optimizer.
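The gradient hand-off described above can be sketched without Pyro. In this minimal illustration torch.optim.SGD stands in for the wrapped optimizer, and the parameter name w is an assumption made for the example.

```python
import torch

w = torch.tensor([1.0, -2.0], requires_grad=True)
params = {"w": w}
loss = (w ** 2).sum()

# Compute gradients explicitly instead of calling loss.backward().
values = list(params.values())
grads = torch.autograd.grad(loss, values, create_graph=True)

# Assign the gradients where a first-order optimizer expects to find them.
for value, grad in zip(values, grads):
    value.grad = grad

# Stand-in for the wrapped optimizer: one plain SGD update.
sgd = torch.optim.SGD(values, lr=0.1)
sgd.step()
```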
Class: TorchMultiOptimizer
Wraps a PyTorch Optimizer class in the MultiOptimizer interface. First wraps the optimizer in PyroOptim, then delegates to PyroMultiOptimizer.
Class: MixedMultiOptimizer
Allows different parameter groups to use different optimizers. Takes a list of (names, optim) pairs and validates that no parameter is optimized by multiple optimizers.
- step(loss, params): Calls each optimizer's step with its assigned parameter subset.
- get_step(loss, params): Aggregates updated values from all optimizers.
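The ownership check can be sketched in pure Python. The helper assign_optimizers and the string placeholders used in place of optimizer objects are hypothetical; the sketch only illustrates the rule that each parameter name belongs to exactly one optimizer.

```python
def assign_optimizers(parts):
    """Hypothetical helper: map each parameter name to one optimizer."""
    owner = {}
    for names, optim in parts:
        for name in names:
            if name in owner:
                raise ValueError(
                    f"Parameter {name!r} is assigned to two optimizers: "
                    f"{owner[name]} vs {optim}"
                )
            owner[name] = optim
    return owner


# Disjoint groups are accepted.
owner = assign_optimizers([(["x", "y"], "newton"), (["z"], "adam")])

# Overlapping groups are rejected.
try:
    assign_optimizers([(["x"], "newton"), (["x"], "adam")])
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
```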
Class: Newton
Second-order optimizer using Newton's method via pyro.ops.newton.newton_step. Supports batched low-dimensional (1D, 2D, or 3D) variables with optional trust region regularization.
- get_step(loss, params): For each parameter, calls newton_step(loss, value, trust_radius).
- The result is differentiable (useful for Laplace approximation).
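To make the idea of a differentiable Newton update concrete, here is a simplified sketch for batched 1D variables. The function sketch_newton_step_1d is hypothetical and is not pyro.ops.newton.newton_step; in particular, clamping the step to the trust radius is a simplification chosen for illustration and is not claimed to match how the library regularizes the step.

```python
import torch


def sketch_newton_step_1d(loss, x, trust_radius=None):
    """Simplified 1D Newton step for x of shape (batch, 1).

    Assumes each batch element's loss term depends only on its own x,
    so differentiating g.sum() gives the per-element second derivative.
    """
    (g,) = torch.autograd.grad(loss, [x], create_graph=True)
    (h,) = torch.autograd.grad(g.sum(), [x], create_graph=True)
    delta = -g / h  # Newton direction: minus gradient over curvature
    if trust_radius is not None:
        # Illustrative cap on the step size (an assumption of this sketch).
        delta = delta.clamp(min=-trust_radius, max=trust_radius)
    return x + delta


x = torch.tensor([[0.0], [4.0]], requires_grad=True)
target = torch.tensor([[1.0], [2.0]])
loss = 0.5 * ((x - target) ** 2).sum()

x_full = sketch_newton_step_1d(loss, x)                      # lands on target
x_capped = sketch_newton_step_1d(loss, x, trust_radius=0.5)  # moves 0.5 at most
```

For a quadratic loss a single uncapped Newton step reaches the minimum exactly, which makes the effect of the trust radius easy to see.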
I/O Contract
| Class | Method | Input | Output |
|---|---|---|---|
| MultiOptimizer | step | loss: Tensor, params: Dict[str, Tensor] | None (in-place update) |
| MultiOptimizer | get_step | loss: Tensor, params: Dict[str, Tensor] | Dict[str, Tensor] (updated values) |
| MixedMultiOptimizer | __init__ | parts: List[Tuple[List[str], MultiOptimizer]] | Instance |
| Newton | __init__ | trust_radii: Dict[str, float] | Instance |
Usage Examples
import torch
import pyro
from pyro.optim import Adam
from pyro.optim.multi import MixedMultiOptimizer, Newton, PyroMultiOptimizer

# Placeholders: model, args, and kwargs stand for your own model and its inputs.

# Newton optimizer with trust region
newton = Newton(trust_radii={"x": 1.0, "y": 0.5})

# Use in inference
tr = pyro.poutine.trace(model).get_trace(*args, **kwargs)
loss = -tr.log_prob_sum()
params = {name: site["value"].unconstrained()
          for name, site in tr.nodes.items()
          if site["type"] == "param"}
newton.step(loss, params)

# Mixed optimizer: Newton for some params, Adam for others.
# The snippets are independent; recompute loss and params before each step.
adam_optim = PyroMultiOptimizer(Adam({"lr": 0.01}))
mixed = MixedMultiOptimizer([
    (["x", "y"], newton),      # Newton for x and y
    (["z", "w"], adam_optim),  # Adam for z and w
])
mixed.step(loss, all_params)  # all_params: placeholder dict holding x, y, z, and w

# PyroMultiOptimizer wrapping a standard optimizer
multi_adam = PyroMultiOptimizer(Adam({"lr": 0.01}))
multi_adam.step(loss, params)
Related Pages
- Pyro_ppl_Pyro_NewtonStep -- Newton step implementation used by Newton
- Pyro_ppl_Pyro_AdagradRMSProp -- First-order optimizer for ADVI
- Pyro_ppl_Pyro_PyroLRScheduler -- Learning rate scheduling