
Implementation:Pyro ppl Pyro SVGD



Overview

The svgd module implements Stein Variational Gradient Descent (SVGD), a non-parametric variational inference algorithm that maintains a set of particles to approximate the posterior distribution. Unlike traditional variational inference, which optimizes a parametric guide, SVGD iteratively transports particles toward the target posterior by combining two terms: an attractive gradient that drives particles toward high-probability regions, and a repulsive gradient that keeps the particles spread out.
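To make the attractive and repulsive terms concrete, here is a minimal sketch of one SVGD update for 1-D particles in plain Python. It is not Pyro's implementation: the function name `svgd_step`, the fixed bandwidth, the step size, and the standard-normal target are all assumptions chosen for illustration.

```python
import math

def svgd_step(particles, score, bandwidth=1.0, step_size=0.1):
    """One SVGD update for 1-D particles with a fixed-bandwidth RBF kernel."""
    n = len(particles)
    updated = []
    for x_i in particles:
        attract = 0.0  # kernel-weighted scores: pull toward high density
        repel = 0.0    # kernel gradients: push particles apart
        for x_j in particles:
            k = math.exp(-((x_j - x_i) ** 2) / bandwidth)
            attract += k * score(x_j)
            # Derivative of k(x_j, x_i) with respect to x_j.
            repel += k * 2.0 * (x_i - x_j) / bandwidth
        updated.append(x_i + step_size * (attract + repel) / n)
    return updated

# Hypothetical target: a standard normal, whose score (gradient of the
# log density) is -x. The particles start clustered far from the mode.
particles = [3.0, 3.1, 3.2, 3.3]
for _ in range(500):
    particles = svgd_step(particles, score=lambda x: -x)
```

Without the repulsive term, all particles would collapse onto the mode at 0; with it, they settle into a spread-out configuration around the mode.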

The module provides:

  • SVGD -- The main inference class that orchestrates particle-based variational inference.
  • SteinKernel -- An abstract base class for kernels used to compute particle interactions.
  • RBFSteinKernel -- A Radial Basis Function kernel with median-heuristic bandwidth selection.
  • IMQSteinKernel -- An Inverse Multi-Quadratic kernel, which has heavier tails than RBF.
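The median heuristic mentioned for the RBF kernel can be sketched as follows. This is an illustrative version in plain Python, not the library code: the helper names are hypothetical, and the `log(N + 1)` normalizer is one common convention that is assumed here rather than taken from the text above.

```python
import math
from statistics import median

def median_heuristic_bandwidth(particles, bandwidth_factor=None):
    """Bandwidth from the median pairwise squared distance (assumed form)."""
    n = len(particles)
    # Squared distances over distinct pairs only (i < j).
    sq_dists = [
        sum((a - b) ** 2 for a, b in zip(particles[i], particles[j]))
        for i in range(n)
        for j in range(i + 1, n)
    ]
    med = median(sq_dists)
    if bandwidth_factor is not None:
        med *= bandwidth_factor  # optional extra scaling
    return med / math.log(n + 1)

def rbf_kernel(x, y, h):
    """RBF kernel exp(-||x - y||^2 / h)."""
    sq = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq / h)

particles = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
h = median_heuristic_bandwidth(particles)
```

Tying the bandwidth to the current particle spread keeps the kernel informative as the particles move: it neither vanishes between all pairs nor saturates at 1.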

SVGD supports two modes of operation: `univariate` (using separate per-dimension kernels, following Kernelized Complete Conditional Stein Discrepancy) and `multivariate` (using a single joint kernel, as in the original SVGD paper).

Internally, SVGD uses `_SVGDGuide` (a modification of `AutoContinuous`), which stores the particles as a Pyro parameter and returns a Delta distribution over those particles.

Code Reference

File: Template:Code

Key Classes

| Class | Parent | Description |
|---|---|---|
| `SVGD` | -- | Main SVGD inference class. Manages particles, kernel computation, and gradient steps. |
| `SteinKernel` | `object` (ABCMeta) | Abstract base class for Stein kernels. Subclasses must implement `log_kernel_and_grad`. |
| `RBFSteinKernel` | `SteinKernel` | RBF (Gaussian) kernel with median-heuristic bandwidth. |
| `IMQSteinKernel` | `SteinKernel` | Inverse Multi-Quadratic kernel: K(x, y) = (alpha + \|\|x - y\|\|^2 / h)^beta. |
| `_SVGDGuide` | `AutoContinuous` | Internal guide that represents the particle set as a Delta distribution. |
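The claim that the inverse multi-quadratic kernel K(x, y) = (alpha + ||x - y||^2 / h)^beta has heavier tails than the RBF kernel can be checked numerically. The sketch below uses hypothetical helper functions that take a squared distance directly; the values alpha = 0.5, beta = -0.5, and h = 1 are illustrative choices.

```python
import math

def rbf(sq_dist, h=1.0):
    """RBF kernel value for a given squared distance."""
    return math.exp(-sq_dist / h)

def imq(sq_dist, h=1.0, alpha=0.5, beta=-0.5):
    """Inverse multi-quadratic kernel value for a given squared distance."""
    return (alpha + sq_dist / h) ** beta

# Compare how quickly each kernel decays as particles move apart.
for distance in (1.0, 3.0, 10.0):
    sq = distance ** 2
    print(f"distance {distance:>4}: rbf={rbf(sq):.3e}  imq={imq(sq):.3e}")
```

The RBF value decays exponentially in the squared distance, while the IMQ value decays only polynomially, so distant particles still interact under IMQ.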

SVGD Methods

| Method | Description |
|---|---|
| `__init__(model, kernel, optim, num_particles, max_plate_nesting, mode)` | Initialize SVGD with model, kernel, optimizer, particle count, and mode (`univariate` or `multivariate`). |
| `step(*args, **kwargs)` | Compute the SVGD gradient and take a single optimization step. Returns a dict of mean squared gradients per parameter. |
| `get_named_particles()` | Returns a dict mapping latent variable names to constrained particle values (shape: num_particles x event_shape). |

SteinKernel Interface

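| Method | Description |
|---|---|
| `log_kernel_and_grad(particles)` | Takes a particles tensor of shape (N, D). Returns a pair (`log_kernel`, `kernel_grad`), both of shape (N, N, D). Entry (n, m, d) of `kernel_grad` is the derivative of `log_kernel` w.r.t. x_{m,d}, the d-th dimension of particle m. |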
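The (N, N, D) shape contract is easier to see on a small example. The sketch below computes a per-dimension RBF log-kernel and its derivative with nested Python lists instead of tensors; the fixed bandwidth `h` and the use of lists are simplifying assumptions, and this is not the library's implementation.

```python
def log_kernel_and_grad(particles, h=1.0):
    """Per-dimension RBF log-kernel and its gradient.

    particles: N lists of D floats. Both returned values are nested lists
    indexed as [n][m][d], mirroring the (N, N, D) shape described above.
    """
    n, d = len(particles), len(particles[0])
    log_kernel = [[[0.0] * d for _ in range(n)] for _ in range(n)]
    kernel_grad = [[[0.0] * d for _ in range(n)] for _ in range(n)]
    for a in range(n):
        for b in range(n):
            for k in range(d):
                diff = particles[a][k] - particles[b][k]
                log_kernel[a][b][k] = -(diff ** 2) / h
                # Derivative of log_kernel[a][b][k] w.r.t. particles[b][k].
                kernel_grad[a][b][k] = 2.0 * diff / h
    return log_kernel, kernel_grad

particles = [[0.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
log_kernel, kernel_grad = log_kernel_and_grad(particles)
```

A custom kernel that follows the same contract could be validated the same way, by comparing its gradient output against a finite-difference estimate of its log-kernel.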

I/O Contract

SVGD Constructor

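Inputs:

  • `model` -- The Pyro model (a callable containing Pyro primitives).
  • `kernel` -- A Stein kernel instance, such as RBFSteinKernel or IMQSteinKernel.
  • `optim` -- A Pyro optimizer, such as Adam.
  • `num_particles` -- Number of particles used to approximate the posterior.
  • `max_plate_nesting` -- Maximum number of nested plates in the model.
  • `mode` -- Either `univariate` or `multivariate`.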

SVGD.step

Inputs:

  • `*args`, `**kwargs` -- Arguments passed through to the model.

Output:

  • `dict` -- Maps latent variable names to mean squared gradient values (float). Useful for monitoring convergence.

SVGD.get_named_particles

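Output:

  • `dict` -- Maps latent variable names to constrained particle values (shape: num_particles x event_shape).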

RBFSteinKernel

Constructor Input:

  • `bandwidth_factor` -- Scaling factor for bandwidth (default: None, meaning no extra scaling).

IMQSteinKernel

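Constructor Inputs:

  • `alpha` -- Kernel hyperparameter (the additive constant in the kernel formula).
  • `beta` -- Kernel hyperparameter (the exponent in the kernel formula).
  • `bandwidth_factor` -- Scaling factor for bandwidth.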

Usage Examples

Basic SVGD Inference

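import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVGD, RBFSteinKernel
from pyro.optim import Adam

# Synthetic observations for illustration only; substitute real data here.
data = 3.0 + 2.0 * torch.randn(100)

def model():
    # Priors over the latent mean and scale.
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.LogNormal(0, 2))
    # Observations are conditionally independent given mu and sigma.
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)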

kernel = RBFSteinKernel()
optim = Adam({"lr": 0.1})
svgd = SVGD(model, kernel, optim, num_particles=50, max_plate_nesting=1)

for step in range(500):
    squared_grads = svgd.step()

# Retrieve posterior particles
particles = svgd.get_named_particles()
print("mu particles:", particles["mu"])
print("sigma particles:", particles["sigma"])
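
Once particles are retrieved, a posterior summary is often more useful than the raw values. The helper below is a hypothetical sketch in plain Python that assumes each latent is a scalar, so its particles can be flattened into a list of floats; the numbers in `example` are stand-ins, not real SVGD output.

```python
from statistics import mean, stdev

def summarize_particles(named_particles):
    """Return {name: (mean, std)} for a dict of name -> list of floats."""
    return {
        name: (mean(values), stdev(values))
        for name, values in named_particles.items()
    }

# Stand-in values. With real output, one might first convert tensors via
# {name: p.detach().flatten().tolist() for name, p in particles.items()}
example = {"mu": [2.9, 3.1, 3.0, 3.2], "sigma": [1.8, 2.0, 2.2, 2.0]}
summary = summarize_particles(example)
```

The particle standard deviation gives a rough sense of posterior uncertainty, though with few particles it should be read with caution.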

Using IMQ Kernel

from pyro.infer import IMQSteinKernel

kernel = IMQSteinKernel(alpha=0.5, beta=-0.5, bandwidth_factor=1.0)
svgd = SVGD(model, kernel, optim, num_particles=100,
             max_plate_nesting=1, mode="multivariate")

for step in range(1000):
    svgd.step()

Monitoring Convergence

for step in range(500):
    squared_grads = svgd.step()
    if step % 100 == 0:
        for name, grad_val in squared_grads.items():
            print(f"Step {step}, {name}: mean sq grad = {grad_val:.6f}")
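
The same per-step dict can drive a simple stopping rule instead of a fixed number of iterations. The function name `has_converged` and the tolerance are assumptions for illustration; a suitable threshold depends on the model and the optimizer's learning rate.

```python
def has_converged(squared_grads, tol=1e-6):
    """True when every parameter's mean squared gradient is below tol."""
    return all(value < tol for value in squared_grads.values())

# Stand-in dicts shaped like the per-step output described above.
still_moving = {"mu": 1e-8, "sigma": 1e-3}
settled = {"mu": 1e-8, "sigma": 5e-7}
```

In a training loop this might be used as `if has_converged(svgd.step()): break`, ideally after a minimum number of warm-up steps so that a single small gradient does not end the run early.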
