
Implementation:Pyro ppl Pyro MultiOptimizer

From Leeroopedia


Property | Value
Module | pyro.optim.multi
Source | pyro/optim/multi.py
Lines | 169
Classes | MultiOptimizer, PyroMultiOptimizer, TorchMultiOptimizer, MixedMultiOptimizer, Newton
Dependencies | torch, pyro.ops.newton, pyro.optim.optim

Overview

This module provides a framework for higher-order optimizers in Pyro. Standard first-order optimizers expect loss.backward() to have been called beforehand, and then update parameters from the populated .grad fields. Higher-order optimizers instead call torch.autograd.grad internally and may need to differentiate the loss several times, for example once for the gradient and again for the Hessian.

The MultiOptimizer base class defines an interface in which the step method receives both the loss tensor and a dictionary mapping parameter names to their unconstrained values, and performs backpropagation internally. This makes second-order methods such as Newton's method possible and allows different optimization strategies to be mixed across parameter groups.
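A small plain-PyTorch sketch (no Pyro involved; the quadratic loss and the tensors A and b are made up for illustration) shows what "multiple backward passes" means in practice: passing create_graph=True to torch.autograd.grad keeps the gradient differentiable, so it can be differentiated again to build the Hessian that a Newton update needs.

```python
import torch

# Toy quadratic loss (illustrative): f(x) = 0.5 * x^T A x - b^T x
A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
b = torch.tensor([1.0, -1.0])
x = torch.zeros(2, requires_grad=True)
loss = 0.5 * x @ A @ x - b @ x

# First backward pass: the gradient, kept differentiable via create_graph=True.
(g,) = torch.autograd.grad(loss, [x], create_graph=True)

# Further backward passes: differentiate each gradient entry to get Hessian rows.
H = torch.stack([torch.autograd.grad(g[i], [x], retain_graph=True)[0]
                 for i in range(x.numel())])

# One Newton step; for a quadratic loss it lands on the exact minimizer A^{-1} b.
x_new = x.detach() - torch.linalg.solve(H, g.detach())
# x_new is close to [0.6, -0.8]
```

A single loss.backward() call would only accumulate first derivatives into x.grad; the extra differentiation of g is what the higher-order interface is designed to allow.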

Code Reference

Class: MultiOptimizer (Base)

Methods:

  • step(loss, params): In-place optimization step. Computes updated values via get_step and copies them (detached) into the original parameters.
  • get_step(loss, params): Abstract. Returns a dict of updated parameter values (preserving differentiability).
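The division of labor between the two methods can be sketched without Pyro. ToyMultiOptimizer and GradientDescent below are hypothetical classes written only to mirror the contract described above, not Pyro's source: get_step returns new values that stay attached to the autograd graph, and step copies their detached values into the original tensors under torch.no_grad().

```python
import torch


class ToyMultiOptimizer:
    """Hypothetical stand-in mirroring the MultiOptimizer step/get_step contract."""

    def step(self, loss, params):
        # In-place update: compute differentiable new values, then copy them
        # (detached, since they may depend on the old values) into params.
        updated = self.get_step(loss, params)
        with torch.no_grad():
            for name, value in params.items():
                value.copy_(updated[name].detach())

    def get_step(self, loss, params):
        raise NotImplementedError


class GradientDescent(ToyMultiOptimizer):
    """Hypothetical first-order example: new value = value - lr * d(loss)/d(value)."""

    def __init__(self, lr):
        self.lr = lr

    def get_step(self, loss, params):
        names = list(params)
        grads = torch.autograd.grad(loss, [params[n] for n in names], create_graph=True)
        # The returned tensors remain attached to the autograd graph.
        return {n: params[n] - self.lr * g for n, g in zip(names, grads)}


x = torch.tensor([1.0, 2.0], requires_grad=True)
opt = GradientDescent(lr=0.25)

new = opt.get_step((x ** 2).sum(), {"x": x})  # differentiable result; x unchanged
opt.step((x ** 2).sum(), {"x": x})            # x overwritten in place with [0.5, 1.0]
```

Keeping get_step free of side effects is what lets a caller differentiate through an update, while step is the convenient in-place form for an ordinary training loop.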

Class: PyroMultiOptimizer

Wraps a PyroOptim object in the MultiOptimizer interface. Its step computes gradients with torch.autograd.grad(loss, values, create_graph=True), assigns them to each parameter's .grad attribute, and then calls the underlying optimizer.
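A rough plain-PyTorch analogue of this wrapping is sketched below. It substitutes torch.optim.SGD for a PyroOptim so that it runs without Pyro, and the helper name first_order_multi_step is hypothetical; the point is only the sequence: explicit autograd.grad, .grad assignment, then optimizer.step().

```python
import torch


def first_order_multi_step(optimizer, loss, params):
    """Hypothetical helper: one first-order step driven by torch.autograd.grad."""
    values = list(params.values())
    # Gradients are computed explicitly rather than by loss.backward().
    grads = torch.autograd.grad(loss, values, create_graph=True)
    for value, grad in zip(values, grads):
        value.grad = grad  # hand each gradient to the optimizer via .grad
    optimizer.step()


w = torch.tensor([1.0, -2.0], requires_grad=True)
sgd = torch.optim.SGD([w], lr=0.1)  # stands in for a PyroOptim in this sketch
loss = ((w - torch.tensor([3.0, 0.0])) ** 2).sum()
first_order_multi_step(sgd, loss, {"w": w})
# The gradient was 2 * (w - target) = [-4, -4], so w is now close to [1.4, -1.6].
```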

Class: TorchMultiOptimizer

Wraps a PyTorch torch.optim.Optimizer class in the MultiOptimizer interface. It first wraps the optimizer class in a PyroOptim, then delegates to PyroMultiOptimizer.

Class: MixedMultiOptimizer

Allows different parameter groups to use different optimizers. Takes a list of (names, optim) pairs and raises a ValueError if any parameter name is assigned to more than one optimizer.

  • step(loss, params): Calls each optimizer's step with its assigned parameter subset.
  • get_step(loss, params): Aggregates updated values from all optimizers.
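The partition-and-dispatch logic can be illustrated with a small pure-Python sketch. ToyMixedMultiOptimizer and the Shift sub-optimizer are hypothetical and operate on plain floats rather than tensors; they mirror only the behavior listed above: duplicate names are rejected, step hands each sub-optimizer its own subset, and get_step merges the per-part results.

```python
class ToyMixedMultiOptimizer:
    """Hypothetical sketch of (names, optim) partitioning; not Pyro's source."""

    def __init__(self, parts):
        owner = {}
        for names, optim in parts:
            for name in names:
                # Each parameter may belong to exactly one sub-optimizer.
                if name in owner:
                    raise ValueError(f"parameter {name!r} is assigned to two optimizers")
                owner[name] = optim
        self.parts = [(list(names), optim) for names, optim in parts]

    def step(self, loss, params):
        # Each sub-optimizer sees only its own slice of params.
        for names, optim in self.parts:
            optim.step(loss, {n: params[n] for n in names})

    def get_step(self, loss, params):
        updated = {}
        for names, optim in self.parts:
            updated.update(optim.get_step(loss, {n: params[n] for n in names}))
        return updated


class Shift:
    """Hypothetical sub-optimizer on plain floats: value -> value + delta."""

    def __init__(self, delta):
        self.delta = delta
        self.seen = []  # names received by each step() call

    def step(self, loss, params):
        # Plain floats cannot be updated in place, so just record the names.
        self.seen.append(sorted(params))

    def get_step(self, loss, params):
        return {name: value + self.delta for name, value in params.items()}


a, b = Shift(1.0), Shift(-1.0)
mixed = ToyMixedMultiOptimizer([(["x", "y"], a), (["z"], b)])
values = {"x": 0.0, "y": 2.0, "z": 5.0}

# loss is unused by these toy parts, so None is passed.
print(mixed.get_step(None, values))  # {'x': 1.0, 'y': 3.0, 'z': 4.0}
mixed.step(None, values)
print(a.seen, b.seen)  # [['x', 'y']] [['z']]

try:
    ToyMixedMultiOptimizer([(["x"], a), (["x", "z"], b)])
except ValueError as err:
    print(err)  # parameter 'x' is assigned to two optimizers
```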

Class: Newton

Second-order optimizer that applies Newton's method via pyro.ops.newton.newton_step. It supports batched low-dimensional variables, i.e. tensors whose rightmost dimension has size 1, 2, or 3, with optional per-parameter trust-region regularization.

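  • get_step(loss, params): For each parameter, calls newton_step(loss, value, trust_radius) with that parameter's entry in trust_radii (names without an entry get an unregularized Newton update) and keeps the returned updated value, discarding the covariance estimate.
  • The values returned by get_step are differentiable (useful for a Laplace approximation); step detaches them when copying them into the parameters.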
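The per-parameter update can be illustrated for a single 2-D variable in pure Python. Here the gradient g and the symmetric Hessian H are supplied by hand, and damped_newton_step_2d is a hypothetical helper. The trust-region rule used, adding λI with λ = max(‖g‖/r − λ_min(H), 0) plus a tiny constant, is one standard damping scheme and is an assumption, not a transcription of pyro.ops.newton. It keeps the step inside radius r because every eigenvalue of H + λI is then at least ‖g‖/r, so ‖(H + λI)⁻¹g‖ ≤ ‖g‖ / (‖g‖/r) = r.

```python
import math


def damped_newton_step_2d(x, g, H, trust_radius=None):
    """Hypothetical helper: one Newton step for a 2-D variable.

    x and g are length-2 lists (point and gradient at that point); H is the
    symmetric 2x2 Hessian as nested lists. The damping rule below is an
    assumption (a standard scheme), not a copy of pyro.ops.newton.
    """
    (a, b), (_, d) = H
    if trust_radius is not None:
        # Smallest eigenvalue of the symmetric matrix [[a, b], [b, d]].
        lam_min = (a + d) / 2 - math.sqrt(((a - d) / 2) ** 2 + b * b)
        g_norm = math.hypot(g[0], g[1])
        # Shift H by lam * I so every eigenvalue is >= |g| / trust_radius,
        # which bounds the step length by trust_radius.
        lam = max(g_norm / trust_radius - lam_min, 0.0) + 1e-12
        a, d = a + lam, d + lam
    # Solve (H + lam * I) dx = g using the closed-form 2x2 inverse.
    det = a * d - b * b
    dx0 = (d * g[0] - b * g[1]) / det
    dx1 = (a * g[1] - b * g[0]) / det
    return [x[0] - dx0, x[1] - dx1]


# Toy quadratic f(x) = 0.5 x^T A x - c^T x with A = [[3, 1], [1, 2]], c = (1, -1):
# at x = (0, 0) the gradient is A x - c = (-1, 1) and the Hessian is A.
A = [[3.0, 1.0], [1.0, 2.0]]
g0 = [-1.0, 1.0]

print(damped_newton_step_2d([0.0, 0.0], g0, A))  # [0.6, -0.8], the exact minimizer
x_tr = damped_newton_step_2d([0.0, 0.0], g0, A, trust_radius=0.1)
print(math.hypot(*x_tr) <= 0.1)  # True: the damped step stays inside the radius
```

Without a trust radius the step is the plain Newton update; with one, the damping also keeps the step well defined when H is indefinite, since H + λI is then positive definite.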

I/O Contract

Class | Method | Input | Output
MultiOptimizer | step | loss: Tensor, params: Dict[str, Tensor] | None (in-place update)
MultiOptimizer | get_step | loss: Tensor, params: Dict[str, Tensor] | Dict[str, Tensor] (updated values)
MixedMultiOptimizer | __init__ | parts: List[Tuple[List[str], MultiOptimizer]] | Instance
Newton | __init__ | trust_radii: Dict[str, float] (optional; missing names get an unregularized update) | Instance

Usage Examples

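import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.optim import Adam
from pyro.optim.multi import Newton, MixedMultiOptimizer, PyroMultiOptimizer

pyro.clear_param_store()
data = torch.randn(10, 2)  # toy observations (illustrative)

# Illustrative model: each param is the mean of its own Gaussian likelihood.
# Its trailing dimension of size 2 is within Newton's 1/2/3-D limit.
def model():
    locs = {name: pyro.param(name, torch.zeros(2)) for name in ["x", "y", "z", "w"]}
    with pyro.plate("data", len(data)):
        for name, loc in locs.items():
            pyro.sample("obs_" + name, dist.Normal(loc, 1.0).to_event(1), obs=data)

def loss_and_params():
    # Retrace before every step: step() updates params in place, so a loss
    # computed from the old values should not be reused.
    tr = poutine.trace(model).get_trace()
    loss = -tr.log_prob_sum()
    params = {name: site["value"].unconstrained()
              for name, site in tr.nodes.items()
              if site["type"] == "param"}
    return loss, params

# Newton optimizer with trust regions; z and w have no radius and
# therefore get an unregularized Newton update
newton = Newton(trust_radii={"x": 1.0, "y": 0.5})
loss, params = loss_and_params()
newton.step(loss, params)

# Mixed optimizer: Newton for some params, Adam for others
adam_optim = PyroMultiOptimizer(Adam({"lr": 0.01}))
mixed = MixedMultiOptimizer([
    (["x", "y"], newton),      # Newton for x and y
    (["z", "w"], adam_optim),  # Adam for z and w
])
loss, params = loss_and_params()
mixed.step(loss, params)

# PyroMultiOptimizer wrapping a standard Pyro optimizer
multi_adam = PyroMultiOptimizer(Adam({"lr": 0.01}))
loss, params = loss_and_params()
multi_adam.step(loss, params)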
