Implementation:Pyro ppl Pyro SVGD
Overview
The svgd module (Template:Code) implements Stein Variational Gradient Descent (SVGD), a non-parametric variational inference algorithm that maintains a set of particles to approximate the posterior distribution. Unlike traditional variational inference that optimizes a parametric guide, SVGD iteratively transports particles to match the target posterior using a combination of an attractive gradient (driving particles toward high-probability regions) and a repulsive gradient (maintaining diversity among particles).
The module provides:
- SVGD -- The main inference class that orchestrates particle-based variational inference.
- SteinKernel -- An abstract base class for kernels used to compute particle interactions.
- RBFSteinKernel -- A Radial Basis Function kernel with median-heuristic bandwidth selection.
- IMQSteinKernel -- An Inverse Multi-Quadratic kernel, which has heavier tails than RBF.
SVGD supports two modes of operation: `univariate` (using separate per-dimension kernels, following the Kernelized Complete Conditional Stein Discrepancy) and `multivariate` (using a single joint kernel, as in the original SVGD paper).
Internally, SVGD uses an internal `_SVGDGuide` (a modification of `AutoContinuous`) that stores the particles as a Pyro parameter and returns a Delta distribution over them.
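The attractive and repulsive terms described in this overview can be illustrated with a minimal pure-Python sketch for one-dimensional particles. This is not Pyro's implementation: the helper names (`rbf`, `svgd_direction`), the fixed bandwidth `h`, the step size, and the standard-normal target are all assumptions chosen for illustration.

```python
import math

def rbf(x, y, h):
    """RBF kernel k(x, y) = exp(-(x - y)^2 / h) for scalars (illustrative)."""
    return math.exp(-((x - y) ** 2) / h)

def svgd_direction(particles, score, h=1.0):
    """Return the SVGD update direction for each 1-D particle.

    phi(x_i) = (1/N) * sum_j [ k(x_j, x_i) * score(x_j) + d/dx_j k(x_j, x_i) ]
    The first term is attractive, the second is repulsive.
    """
    n = len(particles)
    phi = []
    for x_i in particles:
        attract = 0.0
        repel = 0.0
        for x_j in particles:
            k = rbf(x_j, x_i, h)
            attract += k * score(x_j)
            # d/dx_j exp(-(x_j - x_i)^2 / h) = -2 (x_j - x_i) / h * k
            repel += -2.0 * (x_j - x_i) / h * k
        phi.append((attract + repel) / n)
    return phi

# Assumed target: a standard normal, whose score is d/dx log p(x) = -x.
score = lambda x: -x

particles = [-1.0, 1.0]
phi = svgd_direction(particles, score)

# One plain gradient step with a hypothetical step size.
step_size = 0.1
particles = [x + step_size * p for x, p in zip(particles, phi)]
```

With a single particle the repulsive term vanishes and the direction reduces to the score, which is why more than one particle is needed for the set to spread over the posterior rather than collapse onto a mode.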
Code Reference
File: Template:Code
Key Classes
| Class | Parent | Description |
|---|---|---|
| `SVGD` | -- | Main SVGD inference class. Manages particles, kernel computation, and gradient steps. |
| `SteinKernel` | `object` (ABCMeta) | Abstract base class for Stein kernels. Subclasses must implement `log_kernel_and_grad`. |
| `RBFSteinKernel` | `SteinKernel` | RBF (Gaussian) kernel with median-heuristic bandwidth. |
| `IMQSteinKernel` | `SteinKernel` | Inverse Multi-Quadratic kernel: K(x, y) = (alpha + \|x - y\|^2 / h)^beta. |
| `_SVGDGuide` | `AutoContinuous` | Internal guide that represents the particle set as a Delta distribution. |
SVGD Methods
| Method | Description |
|---|---|
| `__init__(model, kernel, optim, num_particles, max_plate_nesting, mode)` | Initialize SVGD with model, kernel, optimizer, particle count, and mode (`"univariate"` or `"multivariate"`). |
| `step(*args, **kwargs)` | Compute the SVGD gradient and take a single optimization step. Returns a dict of mean squared gradients per parameter. |
| `get_named_particles()` | Returns a dict mapping latent variable names to constrained particle values (shape: num_particles x event_shape). |
SteinKernel Interface
| Method | Description |
|---|---|
| `log_kernel_and_grad(particles)` | Takes a particles tensor of shape (N, D). Returns a pair `(log_kernel, kernel_grad)`, both of shape (N, N, D). Entry (n, m, d) of `kernel_grad` is the derivative of `log_kernel` w.r.t. x_{m,d}. |
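The shape contract above can be made concrete with a small stand-in that works on nested lists instead of tensors. It is only a sketch: the function name `log_kernel_and_grad_sketch` and the fixed bandwidth `h` are assumptions, and it uses a per-dimension RBF form purely to show how an (N, D) input maps to two (N, N, D) outputs.

```python
def log_kernel_and_grad_sketch(particles, h=1.0):
    """Per-dimension RBF log kernel and its gradient for an N x D list of lists.

    log_kernel[n][m][d]  = -(x[m][d] - x[n][d])**2 / h
    kernel_grad[n][m][d] = d log_kernel[n][m][d] / d x[m][d]
                         = -2 * (x[m][d] - x[n][d]) / h
    """
    n = len(particles)
    d = len(particles[0])
    log_kernel = [[[0.0] * d for _ in range(n)] for _ in range(n)]
    kernel_grad = [[[0.0] * d for _ in range(n)] for _ in range(n)]
    for a in range(n):
        for b in range(n):
            for k in range(d):
                delta = particles[b][k] - particles[a][k]
                log_kernel[a][b][k] = -(delta ** 2) / h
                kernel_grad[a][b][k] = -2.0 * delta / h
    return log_kernel, kernel_grad

particles = [[0.0, 1.0], [2.0, 1.0], [0.5, -1.0]]  # N = 3, D = 2
log_kernel, kernel_grad = log_kernel_and_grad_sketch(particles)
```

Returning the logarithm of the kernel rather than the kernel itself lets a caller sum over the last axis to obtain a joint kernel, or keep the axes separate for per-dimension kernels.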
I/O Contract
SVGD Constructor
Inputs:
- `model` -- A fully vectorized Pyro model with only continuous latent variables.
- `kernel` -- An instance of a Stein kernel (e.g., `RBFSteinKernel`).
- `optim` -- A Pyro optimizer wrapper.
- `num_particles` -- Number of particles (must be > 1).
- `max_plate_nesting` -- Maximum number of nested plate contexts in the model (must be >= 0).
- `mode` -- Either `"univariate"` (default) or `"multivariate"`.
SVGD.step
Inputs:
- `*args, **kwargs` -- Passed to the model.
Output:
- `dict` -- Maps latent variable names to mean squared gradient values (float). Useful for monitoring convergence.
SVGD.get_named_particles
Output:
- `dict` -- Maps latent variable names to tensors of shape `(num_particles, *event_shape)` in the constrained space.
RBFSteinKernel
Constructor Input:
- `bandwidth_factor` -- Scaling factor for bandwidth (default: None, meaning no extra scaling).
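As a rough illustration of median-heuristic bandwidth selection, the sketch below uses one common form from the SVGD literature: the median of the squared pairwise distances divided by log(N + 1), optionally multiplied by a scaling factor. The exact variant used by `RBFSteinKernel` is not spelled out on this page, so treat the formula and the helper name `median_bandwidth` as assumptions.

```python
import math
from statistics import median

def median_bandwidth(values, bandwidth_factor=None):
    """Median-heuristic bandwidth for scalar particle coordinates.

    Assumed form: h = median(squared pairwise distances) / log(N + 1),
    optionally multiplied by bandwidth_factor.
    """
    n = len(values)
    sq_dists = [
        (values[i] - values[j]) ** 2
        for i in range(n)
        for j in range(i + 1, n)
    ]
    h = median(sq_dists) / math.log(n + 1)
    if bandwidth_factor is not None:
        h *= bandwidth_factor
    return h

h = median_bandwidth([0.0, 1.0, 2.0])
h_scaled = median_bandwidth([0.0, 1.0, 2.0], bandwidth_factor=2.0)
```

Because the bandwidth is recomputed from the current particles, the kernel widens when particles are spread out and narrows as they concentrate.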
IMQSteinKernel
Constructor Inputs:
- `alpha` -- Kernel hyperparameter (default 0.5, must be > 0).
- `beta` -- Kernel hyperparameter (default -0.5, must be < 0).
- `bandwidth_factor` -- Bandwidth scaling factor.
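The kernel formula K(x, y) = (alpha + |x - y|^2 / h)^beta and the constraints on its hyperparameters can be checked with a short scalar sketch. The function names and the fixed bandwidth `h` are illustrative assumptions, not part of the Pyro API; the RBF function is included only to compare tail behaviour.

```python
import math

def imq_kernel(x, y, alpha=0.5, beta=-0.5, h=1.0):
    """Inverse multi-quadratic kernel K(x, y) = (alpha + (x - y)^2 / h)^beta."""
    if alpha <= 0 or beta >= 0:
        raise ValueError("alpha must be > 0 and beta must be < 0")
    return (alpha + (x - y) ** 2 / h) ** beta

def rbf_kernel(x, y, h=1.0):
    """RBF kernel exp(-(x - y)^2 / h), for comparison."""
    return math.exp(-((x - y) ** 2) / h)

# Print both kernels at increasing distances to compare how fast they decay.
for distance in (0.0, 1.0, 5.0):
    print(distance, imq_kernel(0.0, distance), rbf_kernel(0.0, distance))
```

The IMQ kernel decays polynomially with distance while the RBF kernel decays exponentially, so distant particles still interact under IMQ when the RBF interaction is effectively zero.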
Usage Examples
Basic SVGD Inference
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVGD, RBFSteinKernel
from pyro.optim import Adam

# Synthetic observations (illustrative only) so the example is self-contained.
data = 3.0 + 2.0 * torch.randn(100)

def model():
    mu = pyro.sample("mu", dist.Normal(0.0, 10.0))
    sigma = pyro.sample("sigma", dist.LogNormal(0.0, 2.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

kernel = RBFSteinKernel()
optim = Adam({"lr": 0.1})
svgd = SVGD(model, kernel, optim, num_particles=50, max_plate_nesting=1)

for step in range(500):
    squared_grads = svgd.step()

# Retrieve posterior particles
particles = svgd.get_named_particles()
print("mu particles:", particles["mu"])
print("sigma particles:", particles["sigma"])
```
Using IMQ Kernel
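```python
from pyro.infer import SVGD, IMQSteinKernel

# Reuses `model` and `optim` from the basic example. The particles are stored
# as a Pyro parameter, so clear the store before changing num_particles.
pyro.clear_param_store()
kernel = IMQSteinKernel(alpha=0.5, beta=-0.5, bandwidth_factor=1.0)
svgd = SVGD(model, kernel, optim, num_particles=100,
            max_plate_nesting=1, mode="multivariate")

for step in range(1000):
    svgd.step()
```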
Monitoring Convergence
```python
# Assumes `svgd` was constructed as in the examples above.
for step in range(500):
    squared_grads = svgd.step()
    if step % 100 == 0:
        for name, grad_val in squared_grads.items():
            print(f"Step {step}, {name}: mean sq grad = {grad_val:.6f}")
```
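The returned dict can also drive a simple stopping rule. The helper below is a hypothetical convenience, not part of Pyro, and the tolerance and example values are invented for illustration; small mean squared gradients are only a heuristic sign that the particles have stopped moving.

```python
def has_converged(squared_grads, tol=1e-4):
    """Return True when every mean squared gradient is below tol."""
    return all(value < tol for value in squared_grads.values())

# Hypothetical values shaped like the dict returned by svgd.step().
example = {"mu": 3.2e-5, "sigma": 7.5e-4}
strict = has_converged(example)            # uses the default tolerance
loose = has_converged(example, tol=1e-3)   # uses a looser tolerance
```

In a training loop the check would typically run every few steps, with a `break` once it returns True.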
Related Pages
- Pyro_ppl_Pyro_Infer_Utilities -- Utility functions including Template:Code used by SVGD
- Pyro_ppl_Pyro_TraceGraph_ELBO -- Alternative ELBO-based inference for non-reparameterizable models
- Pyro_ppl_Pyro_EnergyDistance -- A likelihood-free alternative inference method