

Implementation:Pyro ppl Pyro VelocityVerlet

From Leeroopedia


Module: pyro.ops.integrator
Source: pyro/ops/integrator.py
Lines: 130
Functions: velocity_verlet, potential_grad, register_exception_handler
Dependencies: torch.autograd

Overview

This module implements the velocity Verlet algorithm, a second-order symplectic integrator used in Hamiltonian Monte Carlo (HMC) and NUTS samplers to simulate Hamiltonian dynamics. The integrator numerically solves Hamilton's equations by alternating half-step momentum updates with full-step position updates, preserving the symplectic structure of the dynamics.
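In standard notation, let the Hamiltonian be H(z, r) = U(z) + K(r). U is the potential energy returned by potential_fn, K is the kinetic energy whose gradient kinetic_grad computes, and ε is step_size. These symbols are conventional and are not names from the Pyro source. Hamilton's equations and one velocity Verlet step then read:

```latex
\begin{aligned}
\frac{dz}{dt} &= \nabla_r K(r), \qquad \frac{dr}{dt} = -\nabla_z U(z), \\[4pt]
r_{n+1/2} &= r_n - \tfrac{\epsilon}{2}\,\nabla_z U(z_n), \\
z_{n+1} &= z_n + \epsilon\,\nabla_r K(r_{n+1/2}), \\
r_{n+1} &= r_{n+1/2} - \tfrac{\epsilon}{2}\,\nabla_z U(z_{n+1}).
\end{aligned}
```

For a Gaussian kinetic energy K(r) = 0.5 * r^T M^-1 r with mass matrix M, the gradient is M^-1 r, which reduces to r when M is the identity. The gradient at z_{n+1} is needed both at the end of step n and at the start of step n+1, so an implementation only has to evaluate it once per step.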

The module also provides potential_grad for computing gradients of the potential energy function with respect to parameters, along with an exception handler registry for gracefully handling numerical errors (e.g., singular matrices) during gradient computation.

Code Reference

Function: velocity_verlet

velocity_verlet(z, r, potential_fn, kinetic_grad, step_size, num_steps=1, z_grads=None)

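Performs num_steps velocity Verlet (equivalently, leapfrog) steps, starting from position z and momentum r. Each step consists of:

  1. Half-step momentum update: r = r - 0.5 * step_size * grad(potential_fn, z)
  2. Full-step position update: z = z + step_size * kinetic_grad(r)
  3. Half-step momentum update: r = r - 0.5 * step_size * grad(potential_fn, z), evaluated at the updated z

Here grad(potential_fn, z) denotes the gradient of the potential energy at z, as computed by potential_grad. Step 1 of each step reuses the gradient computed in step 3 of the previous step, so each step costs one gradient evaluation. If z_grads (the gradient at the initial z) is supplied, the evaluation at the start of the first step is skipped as well. The function returns the final position and momentum, the gradient at the final position, and the potential energy there.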
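The loop below is a minimal, framework-free sketch of these steps on dicts of Python floats. It illustrates the update order and the gradient reuse and is not Pyro's implementation. It assumes a hypothetical grad_potential callable that returns dU/dz analytically, standing in for the autograd-based potential_grad, and unlike velocity_verlet it does not return the potential energy. The usage part integrates a harmonic oscillator and checks two properties HMC relies on: the energy error stays small, and negating the momentum retraces the trajectory.

```python
from typing import Callable, Dict, Optional, Tuple

State = Dict[str, float]


def velocity_verlet_sketch(
    z: State,
    r: State,
    grad_potential: Callable[[State], State],  # hypothetical analytic dU/dz
    kinetic_grad: Callable[[State], State],
    step_size: float,
    num_steps: int = 1,
    z_grads: Optional[State] = None,
) -> Tuple[State, State, State]:
    """Illustrative velocity Verlet on dicts of floats (not Pyro's code)."""
    if z_grads is None:
        z_grads = grad_potential(z)  # gradient at the starting position
    for _ in range(num_steps):
        # 1. half-step momentum update with the gradient at the current z
        r = {k: r[k] - 0.5 * step_size * z_grads[k] for k in r}
        # 2. full-step position update driven by the kinetic-energy gradient
        r_grads = kinetic_grad(r)
        z = {k: z[k] + step_size * r_grads[k] for k in z}
        # 3. half-step momentum update with the gradient at the new z;
        #    this gradient is carried over to step 1 of the next iteration
        z_grads = grad_potential(z)
        r = {k: r[k] - 0.5 * step_size * z_grads[k] for k in r}
    return z, r, z_grads


# Usage: 1-D harmonic oscillator with U(x) = x**2 / 2 and K(p) = p**2 / 2
def potential(z: State) -> float:
    return 0.5 * z["x"] ** 2


def grad_potential(z: State) -> State:
    return {"x": z["x"]}


def kinetic_grad(r: State) -> State:
    return dict(r)  # identity mass matrix: dK/dr = r


def hamiltonian(z: State, r: State) -> float:
    return potential(z) + 0.5 * r["x"] ** 2


z0, r0 = {"x": 1.0}, {"x": 0.5}
z1, r1, _ = velocity_verlet_sketch(z0, r0, grad_potential, kinetic_grad, 0.1, 100)
print(abs(hamiltonian(z1, r1) - hamiltonian(z0, r0)))  # small energy error

# Reversibility: negate the momentum and integrate back to the start
z2, r2, _ = velocity_verlet_sketch(
    z1, {"x": -r1["x"]}, grad_potential, kinetic_grad, 0.1, 100
)
print(z2["x"], -r2["x"])  # approximately 1.0 and 0.5
```

The energy error comes from discretization and stays bounded rather than drifting, because the scheme is symplectic. Exact reversibility is what allows a Metropolis correction to make HMC exact despite that error.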

Function: potential_grad

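Computes the gradient of the potential energy function with respect to the position variables z. It temporarily enables requires_grad on each tensor in z, evaluates potential_fn(z), and differentiates the result with torch.autograd.grad. It returns the gradients (a dict keyed like z) together with the detached potential energy. Numerical errors are handled through a global registry of exception handlers.

If potential_fn raises an exception and any registered handler returns True for it, potential_grad returns zero gradients and a NaN potential energy instead of propagating the error. Otherwise the exception is re-raised. This lets a sampler reject a numerically invalid proposal rather than crash.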
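The fallback logic can be sketched without PyTorch. Everything in this block is hypothetical: _HANDLERS, register_handler and energy_and_grad are illustrative names, and grad_fn is an analytic gradient standing in for torch.autograd.grad. In this sketch, only exceptions raised while evaluating the potential are checked against the registry. A handled error yields zero gradients and a NaN energy, and any other error is re-raised.

```python
import math
from typing import Callable, Dict, Tuple

State = Dict[str, float]

# Hypothetical global registry: handler name -> predicate on the exception
_HANDLERS: Dict[str, Callable[[Exception], bool]] = {}


def register_handler(name: str, handler: Callable[[Exception], bool]) -> None:
    _HANDLERS[name] = handler  # reusing a name replaces the earlier handler


def energy_and_grad(
    potential_fn: Callable[[State], float],
    grad_fn: Callable[[State], State],  # stands in for torch.autograd.grad
    z: State,
) -> Tuple[State, float]:
    """Return (gradients, energy), or (zeros, NaN) for handled errors."""
    try:
        energy = potential_fn(z)
    except Exception as e:
        if any(handler(e) for handler in _HANDLERS.values()):
            return {k: 0.0 for k in z}, math.nan
        raise  # no handler claimed the error: propagate it unchanged
    return grad_fn(z), energy


# Usage: treat division by zero as a recoverable numerical error
register_handler("zero_division", lambda e: isinstance(e, ZeroDivisionError))


def inverse_potential(z: State) -> float:
    return 1.0 / z["x"]  # undefined at x == 0


def inverse_potential_grad(z: State) -> State:
    return {"x": -1.0 / z["x"] ** 2}


grads, energy = energy_and_grad(inverse_potential, inverse_potential_grad, {"x": 0.0})
print(grads, energy)  # {'x': 0.0} nan
```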

Function: register_exception_handler

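Registers a named exception handler for numerical errors raised while computing the potential energy. A handler is a callable that takes the exception and returns True if potential_grad should handle it. Registering a handler under a name that is already in use replaces the earlier handler. By default, a handler for torch._C._LinAlgError (singular matrices, non-positive-definite inputs) is registered.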
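A handler is just a predicate on the exception. The following is a hypothetical example: the name handle_linalg_failure and the message fragments it checks are assumptions, not Pyro's own checks. It accepts RuntimeErrors whose message points to a failed factorization and rejects everything else, so genuine bugs still surface.

```python
def handle_linalg_failure(exception: Exception) -> bool:
    """Hypothetical handler: True only for recognisable linear-algebra failures."""
    if not isinstance(exception, RuntimeError):
        return False  # anything else (e.g. a TypeError from a bug) should surface
    message = str(exception).lower()
    # message fragments below are illustrative assumptions, not Pyro's checks
    return any(
        fragment in message
        for fragment in ("cholesky", "singular", "positive-definite")
    )


print(handle_linalg_failure(RuntimeError("Cholesky decomposition failed")))  # True
print(handle_linalg_failure(ValueError("shape mismatch")))  # False
```

With Pyro installed, such a predicate would be registered with register_exception_handler("linalg_failure", handle_linalg_failure). From then on, potential_grad consults it alongside the default handler.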

I/O Contract

velocity_verlet
  Input: z: dict (site -> Tensor), r: dict (site -> Tensor), potential_fn: callable, kinetic_grad: callable, step_size: float, num_steps: int = 1, z_grads: dict or None = None (gradient at the initial z)
  Output: tuple (z_next, r_next, z_grads, potential_energy), with z_grads and potential_energy evaluated at z_next
potential_grad
  Input: potential_fn: callable, z: dict (site -> Tensor)
  Output: tuple (z_grads: dict, potential_energy: Tensor)
register_exception_handler
  Input: name: str, handler: Callable[[Exception], bool]
  Output: None

Usage Examples

```python
import torch
from pyro.ops.integrator import velocity_verlet

# Quadratic potential U(z) = 0.5 * ||x||^2 (a standard Gaussian target)
def potential_fn(z):
    return 0.5 * z["x"].pow(2).sum()

# Kinetic energy gradient for K(r) = 0.5 * ||r||^2 (identity mass matrix): dK/dr = r
def kinetic_grad(r):
    return dict(r)

# Initial position and momentum. potential_grad enables requires_grad
# internally, so the position tensor does not need requires_grad=True.
z = {"x": torch.tensor([1.0, 2.0])}
r = {"x": torch.randn(2)}

# Run 10 velocity Verlet (leapfrog) steps
z_new, r_new, z_grads, pe = velocity_verlet(
    z, r, potential_fn, kinetic_grad, step_size=0.1, num_steps=10
)
print(f"New position: {z_new['x']}")
print(f"Potential energy: {pe.item():.4f}")
```
