Implementation:Pyro ppl Pyro NewtonStep
| Property | Value |
|---|---|
| Module | pyro.ops.newton |
| Source | pyro/ops/newton.py |
| Lines | 247 |
| Functions | newton_step, newton_step_1d, newton_step_2d, newton_step_3d |
| Dependencies | torch.autograd, pyro.ops.linalg, pyro.util |
Overview
This module provides batched Newton optimization steps for minimizing loss functions over low-dimensional variables (1D, 2D, or 3D). The implementation supports an optional trust region constraint and produces differentiable outputs, making it suitable for use in Laplace approximation and higher-order optimization within Pyro.
A key property is that the final solution of Newton iteration is differentiable with respect to the inputs even when intermediate steps are detached, due to Newton's quadratic convergence. When the loss is interpreted as a negative log probability density, the returned (mode, cov) pair can be used to construct a Laplace approximation via MultivariateNormal(mode, cov).
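The differentiability claim can be checked with a short derivation. This is a sketch under stated assumptions, not taken from the source: θ stands for whatever inputs the loss depends on, g and H are the gradient and Hessian of the loss with respect to x, x̄ is the detached current iterate, and the trust region regularizer is assumed to be inactive.

```latex
\begin{aligned}
x_{\text{new}} &= \bar{x} - H(\bar{x},\theta)^{-1}\, g(\bar{x},\theta), \\
\frac{\partial x_{\text{new}}}{\partial \theta}
  &= -H^{-1}\,\frac{\partial g}{\partial \theta}
     \;-\; \frac{\partial H^{-1}}{\partial \theta}\, g, \\
g(\bar{x},\theta) = 0 \;\Longrightarrow\;
\frac{\partial x_{\text{new}}}{\partial \theta}
  &= -H^{-1}\,\frac{\partial g}{\partial \theta}
   = \frac{d x^{*}}{d \theta}, \\
p(x) \propto e^{-\mathrm{loss}(x)}
  &\approx \mathcal{N}\!\left(x^{*},\; H(x^{*})^{-1}\right).
\end{aligned}
```

The third line matches what the implicit function theorem gives when differentiating g(x*(θ), θ) = 0, so a single attached step at a converged iterate carries the full derivative of the optimum. If the iterate is only approximately converged, g is small rather than zero and the dropped term is correspondingly small. The last line is the Laplace approximation referred to above, with cov playing the role of the inverse Hessian.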
Code Reference
Function: newton_step(loss, x, trust_radius)
Dispatches to dimension-specific implementations based on x.shape[-1]:
- dim=1: newton_step_1d
- dim=2: newton_step_2d
- dim=3: newton_step_3d
- dim>3: Raises NotImplementedError
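The dispatch rule can be illustrated with a small stand-alone sketch. The helper `select_newton_impl` is hypothetical and only returns the name of the function that would handle a given shape; it is not part of pyro.ops.newton.

```python
def select_newton_impl(x_shape):
    """Return the name of the dimension-specific step for a tensor shape.

    Hypothetical illustration of dispatching on the trailing dimension.
    """
    impls = {
        1: "newton_step_1d",
        2: "newton_step_2d",
        3: "newton_step_3d",
    }
    dim = x_shape[-1]  # only the last axis matters; leading axes are batch
    if dim not in impls:
        raise NotImplementedError(f"no Newton step sketched for dim={dim}")
    return impls[dim]


print(select_newton_impl((1000, 2)))
```

Leading batch dimensions play no role in the choice, which is why a batch of shape (1000, 2) and a single point of shape (2,) would be routed the same way.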
Function: newton_step_1d
For 1D variables. Computes gradient and Hessian via torch.autograd.grad (with create_graph=True). Clamps the Hessian to be positive, applies the Newton update dx = -g / H, and optionally clamps to trust radius.
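The steps described above can be sketched in plain PyTorch. This is a minimal illustration, not the library's code: the function name `newton_step_1d_sketch` and the clamp threshold `1e-8` are assumptions made for the example.

```python
import torch
from torch.autograd import grad


def newton_step_1d_sketch(loss, x, trust_radius=None):
    """One batched 1D Newton step (illustrative sketch)."""
    # First derivative, keeping the graph so it can be differentiated again.
    g = grad(loss, [x], create_graph=True)[0]
    # Second derivative per element; summing g is valid because batch
    # elements are assumed not to interact in the loss.
    H = grad(g.sum(), [x], create_graph=True)[0]
    # Clamp to a small positive value (assumed threshold) before inverting.
    Hinv = H.clamp(min=1e-8).reciprocal()
    dx = -g * Hinv
    if trust_radius is not None:
        dx = dx.clamp(min=-trust_radius, max=trust_radius)
    # Detach the old iterate; gradients flow only through dx.
    x_new = x.detach() + dx
    return x_new, Hinv.unsqueeze(-1)


x = torch.tensor([[0.0], [10.0]], requires_grad=True)
loss = (0.5 * (x - 3.0) ** 2).sum()
mode, cov = newton_step_1d_sketch(loss, x)
print(mode, cov.shape)
```

For this quadratic loss a single unconstrained step lands on the minimizer x = 3 from either starting point, and the trailing dimensions of cov are (1, 1).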
Function: newton_step_2d
For 2D variables. Computes the full 2x2 Hessian. If a trust radius is specified, adds a regularizer based on the minimum eigenvalue to keep updates within the trust region. Uses rinverse for symmetric matrix inversion.
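One way such an eigenvalue-based regularizer can be formed is sketched below. The helper name `regularized_newton_2d`, the exact regularizer formula, and the use of `torch.linalg.solve` in place of rinverse are assumptions for illustration. The idea is to shift the Hessian so that its smallest eigenvalue is at least |g| / trust_radius, which bounds the step length by trust_radius.

```python
import torch


def regularized_newton_2d(g, H, trust_radius):
    """Trust-region-regularized Newton update for batched 2x2 Hessians."""
    # Closed-form smallest eigenvalue of a symmetric 2x2 matrix.
    det = H[..., 0, 0] * H[..., 1, 1] - H[..., 0, 1] * H[..., 1, 0]
    mean_eig = (H[..., 0, 0] + H[..., 1, 1]) / 2
    min_eig = mean_eig - (mean_eig ** 2 - det).clamp(min=0).sqrt()
    # Shift so the smallest eigenvalue becomes at least |g| / trust_radius.
    reg = (g.norm(dim=-1) / trust_radius - min_eig).clamp(min=1e-8)
    H_reg = H + reg[..., None, None] * torch.eye(2, dtype=H.dtype)
    # Solve H_reg dx = -g instead of forming an explicit inverse.
    dx = -torch.linalg.solve(H_reg, g.unsqueeze(-1)).squeeze(-1)
    return dx, min_eig


torch.manual_seed(0)
A = torch.randn(5, 2, 2, dtype=torch.float64)
H = A + A.transpose(-1, -2)  # symmetric, possibly indefinite
g = 10 * torch.randn(5, 2, dtype=torch.float64)
dx, min_eig = regularized_newton_2d(g, H, trust_radius=1.0)
print(dx.norm(dim=-1))
```

Because the shifted Hessian is positive definite with smallest eigenvalue no less than |g| / trust_radius, every step in the batch has norm at most trust_radius, even where the original Hessian is indefinite.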
Function: newton_step_3d
For 3D variables. Follows the same approach as the 2D case, but uses eig_3d from pyro.ops.linalg to compute eigenvalues and inverts a 3x3 symmetric matrix.
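The 3D case can be sketched end to end, including how a batched Hessian might be assembled with autograd. This is an illustration under assumptions: `newton_step_3d_sketch` is a hypothetical name, `torch.linalg.eigvalsh` and `torch.linalg.inv` stand in for eig_3d and the library's symmetric inverse, and the regularizer formula is one plausible choice rather than a confirmed detail of the source.

```python
import torch
from torch.autograd import grad


def newton_step_3d_sketch(loss, x, trust_radius=None):
    """One batched 3D Newton step (illustrative sketch)."""
    g = grad(loss, [x], create_graph=True)[0]  # shape (..., 3)
    # Differentiate each gradient component to fill the Hessian; summing
    # over the batch assumes batch elements do not interact in the loss.
    H = torch.stack(
        [grad(g[..., i].sum(), [x], create_graph=True)[0] for i in range(3)],
        dim=-1,
    )  # shape (..., 3, 3)
    if trust_radius is not None:
        # Stand-in for eig_3d: smallest eigenvalue of the symmetric Hessian.
        min_eig = torch.linalg.eigvalsh(H)[..., 0]
        reg = (g.norm(dim=-1) / trust_radius - min_eig).clamp(min=1e-8)
        H = H + reg[..., None, None] * torch.eye(3, dtype=H.dtype)
    Hinv = torch.linalg.inv(H)
    dx = -(Hinv @ g.unsqueeze(-1)).squeeze(-1)
    return x.detach() + dx, Hinv


# Quadratic test problem: loss = 0.5 x^T A x - b^T x, minimized at A^{-1} b.
A = torch.tensor([[2.0, 0.5, 0.0],
                  [0.5, 1.0, 0.2],
                  [0.0, 0.2, 3.0]], dtype=torch.float64)
b = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)
x = torch.zeros(4, 3, dtype=torch.float64, requires_grad=True)
loss = (0.5 * ((x @ A) * x).sum(-1) - (x * b).sum(-1)).sum()
mode, cov = newton_step_3d_sketch(loss, x)
print(mode[0], cov.shape)
```

On this quadratic problem the unconstrained step reaches the minimizer in one iteration and cov equals the inverse of A for every batch element. Passing a trust_radius shortens the step instead of letting it jump the full distance.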
I/O Contract
| Function | Input | Output |
|---|---|---|
| newton_step(loss, x, trust_radius) | loss: Tensor() (scalar), x: Tensor(..., D) with D in {1,2,3}, trust_radius: float or None | Tuple (mode: Tensor(..., D), cov: Tensor(..., D, D)) |
| newton_step_1d | loss: Tensor(), x: Tensor(..., 1) | Tuple (mode, cov) with cov.shape = (..., 1, 1) |
| newton_step_2d | loss: Tensor(), x: Tensor(..., 2) | Tuple (mode, cov) with cov.shape = (..., 2, 2) |
| newton_step_3d | loss: Tensor(), x: Tensor(..., 3) | Tuple (mode, cov) with cov.shape = (..., 3, 3) |
Usage Examples
```python
import torch
from pyro.ops.newton import newton_step

# Optimize a batch of 2D quadratic functions
x = torch.zeros(1000, 2)  # initial value
for step in range(100):
    x = x.detach()          # block gradients through previous steps
    x.requires_grad = True  # ensure loss is differentiable
    loss = (x ** 2).sum()   # simple quadratic loss
    x, cov = newton_step(loss, x, trust_radius=1.0)

# The final x is still differentiable
print(x.requires_grad)  # True

# Use cov for Laplace approximation
mvn = torch.distributions.MultivariateNormal(x[0], cov[0])
```
Related Pages
- Pyro_ppl_Pyro_MultiOptimizer -- Newton optimizer wraps newton_step
- Pyro_ppl_Pyro_Util -- Provides warn_if_nan used for gradient checking