Implementation:Pyro ppl Pyro VelocityVerlet
| Property | Value |
|---|---|
| Module | pyro.ops.integrator |
| Source | pyro/ops/integrator.py |
| Lines | 130 |
| Functions | velocity_verlet, potential_grad, register_exception_handler |
| Dependencies | torch.autograd |
Overview
This module implements the velocity Verlet algorithm, a second-order symplectic integrator that the Hamiltonian Monte Carlo (HMC) and No-U-Turn Sampler (NUTS) kernels use to simulate Hamiltonian dynamics. The integrator numerically solves Hamilton's equations. Each step places a full-step position update between two half-step momentum updates, which preserves the symplectic structure of the dynamics and makes the update time-reversible.
The module also provides potential_grad, which computes the gradient of the potential energy function with respect to the position variables. It also maintains an exception handler registry, so that numerical errors raised while evaluating the potential energy (e.g., from singular matrices) can be handled gracefully instead of being propagated.
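For reference, the dynamics being integrated can be written in the standard form below. This assumes the Hamiltonian splits into a potential term U, computed by potential_fn, and a kinetic term K, whose gradient is supplied by kinetic_grad. The symbols H, U, K and t are notation introduced here, not names from the module. With an identity mass matrix, $K(r) = \tfrac{1}{2} r^\top r$, so $\nabla_r K(r) = r$.

$$
H(z, r) = U(z) + K(r), \qquad
\frac{dz}{dt} = \nabla_r K(r), \qquad
\frac{dr}{dt} = -\nabla_z U(z)
$$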
Code Reference
Function: velocity_verlet
velocity_verlet(z, r, potential_fn, kinetic_grad, step_size, num_steps=1, z_grads=None)
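Performs num_steps velocity Verlet (leapfrog) integration steps. Each step consists of:
- Half-step momentum update: r = r - 0.5 * step_size * grad(potential_fn, z)
- Full-step position update: z = z + step_size * kinetic_grad(r)
- Half-step momentum update, using the gradient at the new position: r = r - 0.5 * step_size * grad(potential_fn, z)
If z_grads (the potential-energy gradient at the initial z) is passed in, it is used for the first half-step instead of being recomputed. The gradient evaluated at the end of each step is reused at the start of the next, so each step costs one gradient evaluation.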
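The three updates can be sketched in dependency-free Python to make the data flow explicit. This is a hypothetical illustration, not Pyro's code. It uses dicts of floats instead of tensors, and an analytic gradient function (grad_potential, an assumed name) in place of autograd. As described for velocity_verlet, the gradient computed at the end of one step is reused at the start of the next.

```python
import math


def leapfrog(z, r, grad_potential, kinetic_grad, step_size, num_steps=1):
    """Velocity Verlet over dicts of floats (hypothetical helper, not Pyro's API)."""
    z_grads = grad_potential(z)  # gradient of U at the starting position
    for _ in range(num_steps):
        # Half-step momentum update with the gradient at the current position.
        r = {k: r[k] - 0.5 * step_size * z_grads[k] for k in r}
        # Full-step position update driven by the kinetic-energy gradient.
        r_grads = kinetic_grad(r)
        z = {k: z[k] + step_size * r_grads[k] for k in z}
        # Half-step momentum update with the gradient at the new position;
        # this gradient is reused by the next step's first half-step.
        z_grads = grad_potential(z)
        r = {k: r[k] - 0.5 * step_size * z_grads[k] for k in r}
    return z, r, z_grads


# Hypothetical toy problem: U(z) = 0.5 * x**2, K(r) = 0.5 * r**2 (identity mass).
def potential(z):
    return 0.5 * z["x"] ** 2


def grad_potential(z):
    return {"x": z["x"]}


def kinetic_grad(r):
    return dict(r)


def hamiltonian(z, r):
    return potential(z) + 0.5 * r["x"] ** 2


z0, r0 = {"x": 1.0}, {"x": 0.5}
z1, r1, _ = leapfrog(z0, r0, grad_potential, kinetic_grad, step_size=0.1, num_steps=10)
print("energy error:", abs(hamiltonian(z1, r1) - hamiltonian(z0, r0)))

# Time reversibility: flip the momentum and integrate back to the start.
z2, r2, _ = leapfrog(z1, {"x": -r1["x"]}, grad_potential, kinetic_grad, 0.1, 10)
print("returned to start:", math.isclose(z2["x"], z0["x"]), math.isclose(-r2["x"], r0["x"]))
```

The reversal check illustrates a property HMC relies on: flipping the momentum and integrating again retraces the trajectory, while the energy error stays small but is generally not zero.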
Function: potential_grad
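Computes the gradient of the potential energy function with respect to the position variables z using torch.autograd.grad. It returns the gradients as a dict keyed by site name, together with the (detached) potential energy.
Exceptions raised while evaluating potential_fn are checked against a global registry of exception handlers. If any registered handler accepts the exception, potential_grad returns zero gradients (with the same shapes as z) and a NaN potential energy instead of propagating the error. Otherwise the exception is re-raised.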
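The fallback behaviour can be sketched without PyTorch as follows. In this hypothetical version, value_and_grad_fn (an assumed name) returns the energy and an analytic gradient dict in place of torch.autograd.grad. The handlers are also passed explicitly rather than read from a global registry.

```python
import math


def potential_and_grad(value_and_grad_fn, z, handlers):
    """Return (grads, energy), falling back to zeros and NaN on handled errors.

    Hypothetical sketch: value_and_grad_fn stands in for autograd and must
    return (energy, grads_dict); handlers maps names to Exception -> bool.
    """
    try:
        energy, grads = value_and_grad_fn(z)
    except Exception as exc:
        if any(handler(exc) for handler in handlers.values()):
            # Recognized numerical failure: zero gradients and a NaN energy.
            return {k: 0.0 for k in z}, math.nan
        raise  # unrecognized errors still propagate to the caller
    return grads, energy


def log_potential(z):
    # Toy potential U(x) = -log(x) with gradient -1/x;
    # math.log raises ValueError ("math domain error") for x <= 0.
    return -math.log(z["x"]), {"x": -1.0 / z["x"]}


handlers = {"math_domain": lambda e: isinstance(e, ValueError)}
print(potential_and_grad(log_potential, {"x": 2.0}, handlers))
print(potential_and_grad(log_potential, {"x": -1.0}, handlers))
```

Returning NaN rather than raising lets the caller decide how to treat the failed evaluation, instead of aborting the whole computation.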
Function: register_exception_handler
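Registers a named exception handler that potential_grad consults when evaluating the potential energy raises an error. A handler is a callable that takes the exception and returns True if it should be treated as a recoverable numerical failure. By default, a handler for torch._C._LinAlgError (singular matrices, non-positive-definite inputs) is registered.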
I/O Contract
| Function | Input | Output |
|---|---|---|
| velocity_verlet | z: dict (site->Tensor), r: dict (site->Tensor), potential_fn: callable, kinetic_grad: callable, step_size: float, num_steps: int (default 1), z_grads: dict (optional) | Tuple (z_next, r_next, z_grads, potential_energy) |
| potential_grad | potential_fn: callable, z: dict (site->Tensor) | Tuple (z_grads: dict, potential_energy: Tensor) |
| register_exception_handler | name: str, handler: Callable (Exception -> bool) | None |
Usage Examples
```python
import torch
from pyro.ops.integrator import velocity_verlet

# Define a simple quadratic potential: U(z) = 0.5 * ||x||^2
def potential_fn(z):
    return 0.5 * z["x"].pow(2).sum()

# Kinetic energy gradient for an identity mass matrix: dK/dr = r
def kinetic_grad(r):
    return {k: v for k, v in r.items()}

# Initialize position and momentum
z = {"x": torch.tensor([1.0, 2.0], requires_grad=True)}
r = {"x": torch.randn(2)}

# Run 10 leapfrog steps
z_new, r_new, grads, pe = velocity_verlet(
    z, r, potential_fn, kinetic_grad, step_size=0.1, num_steps=10
)
print(f"New position: {z_new['x']}")
print(f"Potential energy: {pe.item():.4f}")
```
Related Pages
- Pyro_ppl_Pyro_DualAveraging -- Step size adaptation for HMC
- Pyro_ppl_Pyro_WelfordCovariance -- Mass matrix estimation for HMC