
Implementation:Pyro ppl Pyro NewtonStep

From Leeroopedia


Property      Value
Module        pyro.ops.newton
Source        pyro/ops/newton.py
Length        247 lines
Functions     newton_step, newton_step_1d, newton_step_2d, newton_step_3d
Dependencies  torch.autograd, pyro.ops.linalg, pyro.util

Overview

This module provides batched Newton optimization steps for minimizing a loss over variables whose last dimension has size 1, 2, or 3. The implementation supports an optional trust-region constraint and produces differentiable outputs, which makes it suitable for Laplace approximations and higher-order optimization within Pyro.
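In equations, one step can be summarized as follows. This assumes, as the batched interface implies, that the scalar loss is a sum of independent per-element terms, so each batch element b has its own gradient and Hessian:

```latex
\begin{aligned}
\ell(x) &= \sum_b \ell_b(x_b), \qquad
g_b = \nabla \ell_b(x_b), \qquad
H_b = \nabla^2 \ell_b(x_b), \\
x_b^{\mathrm{new}} &= x_b - H_b^{-1} g_b, \qquad
\lVert x_b^{\mathrm{new}} - x_b \rVert \le r \ \text{(if a trust radius } r \text{ is given)}, \\
p_b(x_b) &\approx \mathcal{N}\!\left(x_b^{\mathrm{new}},\ H_b^{-1}\right) \quad \text{when } \ell_b = -\log p_b .
\end{aligned}
```

With a trust radius, H_b is modified (clamped in 1D, regularized in 2D and 3D) so that the bound holds. Because the inverse Hessian is evaluated at the input point rather than at the updated one, it matches the Hessian at the mode only once the iteration has converged.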

A key property is that the result of a Newton iteration remains differentiable with respect to the inputs of the loss (for example, its parameters) even when all intermediate iterates are detached. Because Newton's method converges quadratically, backpropagating through the final step alone already approximates the derivative of the exact solution. When the loss is interpreted as a negative log probability density, the returned (mode, cov) pair can be used to construct a Laplace approximation via MultivariateNormal(mode, cov).
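The differentiability claim can be checked with a small PyTorch sketch that does not call pyro.ops.newton. It assumes a hypothetical 1D loss exp(x) - theta * x, whose minimizer is x* = log(theta), so dx*/dtheta = 1/theta. Every iterate except the last is detached, yet backpropagating through the final step recovers that derivative.

```python
import torch
from torch.autograd import grad

# Hypothetical loss parameters; we want d(argmin)/d(theta).
theta = torch.tensor([2.0, 5.0], requires_grad=True)

x = torch.zeros(2)  # one independent 1D problem per element of theta
for _ in range(20):
    x = x.detach().requires_grad_(True)            # cut the graph to earlier steps
    loss = (torch.exp(x) - theta * x).sum()        # minimized at x = log(theta)
    g = grad(loss, [x], create_graph=True)[0]      # first derivative, graph kept
    H = grad(g.sum(), [x], create_graph=True)[0]   # elementwise second derivative
    x = x.detach() - g / H                         # Newton update

# Backpropagate only through the final step.
x.sum().backward()
print(theta.grad)  # close to 1 / theta = [0.5, 0.2]
```

Only the last update carries a graph back to theta, which is why the loop stays cheap no matter how many steps it runs.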

Code Reference

Function: newton_step(loss, x, trust_radius=None)

Dispatches to dimension-specific implementations based on x.shape[-1]:

  • dim=1: newton_step_1d
  • dim=2: newton_step_2d
  • dim=3: newton_step_3d
  • dim>3: Raises NotImplementedError
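A minimal sketch of this dispatch rule follows. The helper dispatch_newton_step and its impls dict are hypothetical illustrations, not part of pyro.ops.newton; the stubs stand in for the real per-dimension functions.

```python
import torch


def dispatch_newton_step(impls, loss, x, trust_radius=None):
    """Hypothetical dispatcher: pick an implementation by the size of x's last dimension."""
    dim = x.shape[-1]
    if dim not in impls:
        raise NotImplementedError(f"Newton step not implemented for dim={dim}")
    return impls[dim](loss, x, trust_radius)


# Illustrative stubs standing in for newton_step_1d / _2d / _3d
impls = {
    1: lambda loss, x, r: "1d",
    2: lambda loss, x, r: "2d",
    3: lambda loss, x, r: "3d",
}

print(dispatch_newton_step(impls, None, torch.zeros(4, 2)))  # "2d"
```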

Function: newton_step_1d

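For 1D variables. Computes the gradient g and second derivative H via torch.autograd.grad with create_graph=True, so both remain differentiable. Clamps H to a small positive minimum so the update always moves downhill, applies the Newton update dx = -g / H, and, if trust_radius is given, clamps dx elementwise to [-trust_radius, trust_radius]. The returned cov is 1 / H with shape (..., 1, 1).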
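A minimal PyTorch sketch of the 1D step, written for illustration rather than copied from pyro/ops/newton.py. The name newton_step_1d_sketch and the 1e-8 clamp floor are assumptions.

```python
import torch
from torch.autograd import grad


def newton_step_1d_sketch(loss, x, trust_radius=None):
    """Illustrative batched 1D Newton step; x has shape (..., 1)."""
    g = grad(loss, [x], create_graph=True)[0]     # first derivative, shape (..., 1)
    H = grad(g.sum(), [x], create_graph=True)[0]  # second derivative per batch element
    Hinv = H.clamp(min=1e-8).reciprocal()         # force positive curvature (assumed floor)
    dx = -g * Hinv                                # Newton update dx = -g / H
    if trust_radius is not None:
        dx = dx.clamp(min=-trust_radius, max=trust_radius)  # stay in the trust region
    mode = x.detach() + dx
    cov = Hinv.unsqueeze(-1)                      # shape (..., 1, 1)
    return mode, cov


c = torch.tensor([[0.5], [3.0], [-4.0]])  # hypothetical per-element minimizers
x = torch.zeros(3, 1, requires_grad=True)
loss = ((x - c) ** 2).sum()               # H = 2 everywhere, so cov = 0.5
mode, cov = newton_step_1d_sketch(loss, x, trust_radius=1.0)
print(mode.squeeze(-1))  # steps of 0.5, then 1.0 and -1.0 after clamping
```

The first element reaches its minimizer in one step; the other two are limited by the trust radius.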

Function: newton_step_2d

For 2D variables. Computes the full 2x2 Hessian. If trust_radius is given, adds a regularizer based on the Hessian's minimum eigenvalue so that the update stays within the trust region. Uses rinverse from pyro.ops.linalg to invert the symmetric matrix.
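One way to build such a regularizer, assumed here and possibly different in detail from Pyro's, is to add lam * I with lam >= |g| / r - lambda_min. Every eigenvalue of H + lam * I is then at least |g| / r, so the step (H + lam * I)^-1 g has norm at most r. For a symmetric 2x2 matrix the smallest eigenvalue has the closed form m - sqrt(m^2 - det) with m = trace / 2. In this sketch torch.linalg.inv stands in for rinverse, and newton_step_2d_sketch is a hypothetical name.

```python
import torch
from torch.autograd import grad


def newton_step_2d_sketch(loss, x, trust_radius=None):
    """Illustrative batched 2D Newton step; x has shape (..., 2)."""
    g = grad(loss, [x], create_graph=True)[0]  # gradient, shape (..., 2)
    # Full Hessian: differentiate each gradient component, shape (..., 2, 2)
    H = torch.stack(
        [grad(g[..., i].sum(), [x], create_graph=True)[0] for i in range(2)], dim=-1
    )
    if trust_radius is not None:
        # Closed-form smallest eigenvalue of a symmetric 2x2 matrix
        det = H[..., 0, 0] * H[..., 1, 1] - H[..., 0, 1] * H[..., 1, 0]
        mean_eig = (H[..., 0, 0] + H[..., 1, 1]) / 2
        min_eig = mean_eig - (mean_eig**2 - det).clamp(min=0).sqrt()
        # Raise every eigenvalue to at least |g| / trust_radius (assumed rule)
        g_norm = torch.linalg.vector_norm(g, dim=-1)
        reg = (g_norm / trust_radius - min_eig).clamp(min=1e-8)
        H = H + reg[..., None, None] * torch.eye(2, dtype=H.dtype, device=H.device)
    Hinv = torch.linalg.inv(H)  # stand-in for Pyro's rinverse
    mode = x.detach() - (Hinv @ g.unsqueeze(-1)).squeeze(-1)
    return mode, Hinv


w = torch.tensor([1.0, 10.0])  # hypothetical per-coordinate curvature weights
c = torch.tensor([[3.0, -4.0], [0.1, 0.2], [10.0, 0.0], [0.0, -0.05]])
x = torch.zeros(4, 2, requires_grad=True)

mode_free, _ = newton_step_2d_sketch((w * (x - c) ** 2).sum(), x)  # exact for a quadratic
mode_tr, cov = newton_step_2d_sketch((w * (x - c) ** 2).sum(), x, trust_radius=1.0)
print(torch.linalg.vector_norm(mode_tr - x.detach(), dim=-1))  # each step length <= 1.0
```

The regularizer is conservative: it can also shrink steps that would already have fit inside the trust region, trading some speed for a simple closed-form bound.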

Function: newton_step_3d

For 3D variables. Follows the same approach as newton_step_2d on the full 3x3 Hessian, using eig_3d from pyro.ops.linalg to compute the eigenvalues, together with a 3x3 symmetric matrix inversion.
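The same regularization carries over to 3D. In this sketch torch.linalg.eigvalsh stands in for Pyro's eig_3d and torch.linalg.inv for the symmetric inversion; newton_step_3d_sketch and the regularizer rule are assumptions. The example loss has an indefinite Hessian, which shows that the regularized step still stays inside the trust region and yields a positive definite cov.

```python
import torch
from torch.autograd import grad


def newton_step_3d_sketch(loss, x, trust_radius=None):
    """Illustrative batched 3D Newton step; x has shape (..., 3)."""
    g = grad(loss, [x], create_graph=True)[0]  # gradient, shape (..., 3)
    H = torch.stack(
        [grad(g[..., i].sum(), [x], create_graph=True)[0] for i in range(3)], dim=-1
    )  # Hessian, shape (..., 3, 3)
    if trust_radius is not None:
        # eigvalsh returns eigenvalues of a symmetric matrix in ascending order
        min_eig = torch.linalg.eigvalsh(H)[..., 0]
        g_norm = torch.linalg.vector_norm(g, dim=-1)
        reg = (g_norm / trust_radius - min_eig).clamp(min=1e-8)
        H = H + reg[..., None, None] * torch.eye(3, dtype=H.dtype, device=H.device)
    Hinv = torch.linalg.inv(H)
    mode = x.detach() - (Hinv @ g.unsqueeze(-1)).squeeze(-1)
    return mode, Hinv


# Hypothetical nonconvex loss: its Hessian diag(2, -2, 1) has a negative eigenvalue
x = torch.tensor([[0.3, -0.2, 0.5]], requires_grad=True)
loss = (x[..., 0] ** 2 - x[..., 1] ** 2 + 0.5 * x[..., 2] ** 2).sum() - x.sum()
mode, cov = newton_step_3d_sketch(loss, x, trust_radius=0.5)
print(torch.linalg.vector_norm(mode - x.detach(), dim=-1))  # at most 0.5
print(torch.linalg.eigvalsh(cov))                          # all positive
```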

I/O Contract

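newton_step(loss, x, trust_radius=None)
    Input: loss: Tensor() (scalar); x: Tensor(..., D) with D in {1, 2, 3}; trust_radius: float or None
    Output: tuple (mode: Tensor(..., D), cov: Tensor(..., D, D))
newton_step_1d(loss, x, trust_radius=None)
    Input: loss: Tensor() (scalar); x: Tensor(..., 1)
    Output: tuple (mode, cov) with cov.shape == (..., 1, 1)
newton_step_2d(loss, x, trust_radius=None)
    Input: loss: Tensor() (scalar); x: Tensor(..., 2)
    Output: tuple (mode, cov) with cov.shape == (..., 2, 2)
newton_step_3d(loss, x, trust_radius=None)
    Input: loss: Tensor() (scalar); x: Tensor(..., 3)
    Output: tuple (mode, cov) with cov.shape == (..., 3, 3)

In every case mode has the same shape as x, cov.shape == x.shape[:-1] + (D, D), and loss must be twice differentiable with respect to x.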
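The cov shape in this contract is what lets the output feed a batched Laplace approximation directly. The torch-only sketch below does not call Pyro. It assumes a hypothetical batch of 2D Gaussian negative log densities with means mu and a shared precision P. For such a quadratic loss, one Newton step recovers mode = mu and cov = P^-1 exactly, and MultivariateNormal picks up the batch shape x.shape[:-1].

```python
import torch
from torch.autograd import grad
from torch.distributions import MultivariateNormal

batch, D = 5, 2
mu = torch.linspace(-1.0, 1.0, batch * D).reshape(batch, D)  # hypothetical means
P = torch.tensor([[2.0, 0.5], [0.5, 1.0]])                    # hypothetical shared precision

x = torch.zeros(batch, D, requires_grad=True)
diff = x - mu
loss = 0.5 * torch.einsum("bi,ij,bj->", diff, P, diff)  # -log density up to a constant

g = grad(loss, [x], create_graph=True)[0]  # shape (batch, D)
H = torch.stack(
    [grad(g[..., i].sum(), [x], create_graph=True)[0] for i in range(D)], dim=-1
)                                          # shape (batch, D, D)
cov = torch.linalg.inv(H)                  # shape x.shape[:-1] + (D, D)
mode = x.detach() - (cov @ g.unsqueeze(-1)).squeeze(-1)  # one Newton step

laplace = MultivariateNormal(mode.detach(), covariance_matrix=cov.detach())
print(laplace.batch_shape, laplace.event_shape)  # torch.Size([5]) torch.Size([2])
```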

Usage Examples

import torch
from pyro.ops.newton import newton_step

# A batch of 1000 independent 2D quadratics, each minimized at its own center
center = torch.randn(1000, 2, requires_grad=True)  # parameters of the loss
x = torch.zeros(1000, 2)  # initial value
for step in range(100):
    x = x.detach()          # block gradients through previous steps
    x.requires_grad = True  # ensure loss is differentiable wrt x
    loss = ((x - center) ** 2).sum()
    x, cov = newton_step(loss, x, trust_radius=1.0)

# The final x is still differentiable with respect to center
print(x.requires_grad)  # True

# Use cov for a Laplace approximation of the first batch element
mvn = torch.distributions.MultivariateNormal(x[0], cov[0])
