
Implementation:Pyro ppl Pyro EnergyDistance

From Leeroopedia


Overview

The energy_distance module (pyro.infer.energy_distance) implements posterior predictive energy distance as an inference objective, following the statistical framework of Székely and Rizzo (2003) and Gneiting and Raftery (2007). It is a likelihood-free inference algorithm: it does not require a tractable density function for the likelihood, only the ability to sample from the model.

The energy distance loss is defined as:

loss = E[||X - x||^beta] - 0.5 * E[||X - X'||^beta] - lambda * E[log p(Z)]

where X and X' are independent posterior predictive samples (drawn by sampling Z from the guide, then X from the model conditioned on Z), x is the observed data, ||·|| is the Euclidean norm, beta is the robustness exponent, and lambda (the prior scale) controls Bayesian prior regularization.
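A minimal pure-Python sketch of the Monte Carlo estimate of the first two terms, given N posterior predictive particles. It assumes every particle has already been flattened into one vector spanning all observed sites, and it averages the X - X' term over all ordered pairs of distinct particles; the function name `energy_distance_estimate` is hypothetical, and Pyro's own pairing and tensor batching may differ. The prior term is left out.

```python
import math
from typing import Sequence

def _norm(v: Sequence[float]) -> float:
    """Euclidean (L2) norm of a flat vector."""
    return math.sqrt(sum(t * t for t in v))

def energy_distance_estimate(samples: Sequence[Sequence[float]],
                             x: Sequence[float],
                             beta: float = 1.0) -> float:
    """Hypothetical estimate of E||X - x||^beta - 0.5 * E||X - X'||^beta."""
    n = len(samples)
    if n < 2:
        raise ValueError("need at least 2 particles for the X - X' term")
    if not 0.0 < beta < 2.0:
        raise ValueError("beta must lie in the open interval (0, 2)")
    # Error term: average distance from each particle to the observed data.
    error = sum(_norm([a - b for a, b in zip(s, x)]) ** beta for s in samples) / n
    # Entropy term: average distance over all ordered pairs (i, j) with i != j.
    pair_total = 0.0
    for i in range(n):
        for j in range(n):
            if i != j:
                diff = [a - b for a, b in zip(samples[i], samples[j])]
                pair_total += _norm(diff) ** beta
    entropy = pair_total / (n * (n - 1))
    return error - 0.5 * entropy

print(energy_distance_estimate([[0.0], [2.0]], [1.0]))  # 0.0
```

With particles at 0 and 2 around an observation at 1 (beta = 1), the error term is 1 and the pairwise term is 2, so the estimate is 1 - 0.5 * 2 = 0: the spread of the predictive samples is credited against their distance from the data.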

Key properties:

  • Likelihood-free: Works for models where the likelihood has no tractable density (e.g., simulator-based models).
  • Robust: The beta-energy distance is well-defined for any distribution with finite fractional moment E[||X||^beta]. For heavy-tailed distributions (e.g., Cauchy), use beta < 1.
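  • Strictly proper: Without the prior term, the loss is a strictly proper scoring rule for 0 < beta < 2 over distributions with finite E[||X||^beta], so its expectation is uniquely minimized when the posterior predictive distribution matches the data-generating distribution.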
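Strict propriety can be checked exactly in one dimension. The sketch below (function names hypothetical, standard library only) computes the expected energy score S(Q, P) = E_P[E_Q|X - x|^beta - 0.5 * E_Q|X - X'|^beta] for discrete distributions written as value-to-probability dictionaries. When the data are drawn from P, the forecast Q = P should score lowest for every beta in (0, 2).

```python
from itertools import product
from typing import Dict

Dist = Dict[float, float]  # support point -> probability mass

def expected_abs_power(p: Dist, q: Dist, beta: float) -> float:
    """E|X - Y|^beta for independent X ~ p and Y ~ q (both discrete)."""
    return sum(pa * qb * abs(a - b) ** beta
               for (a, pa), (b, qb) in product(p.items(), q.items()))

def expected_energy_score(q: Dist, p: Dist, beta: float = 1.0) -> float:
    """Expected energy score of forecast q when the data are drawn from p."""
    return expected_abs_power(q, p, beta) - 0.5 * expected_abs_power(q, q, beta)

truth = {0.0: 0.5, 1.0: 0.5}
candidates = {
    "truth": truth,
    "shifted": {0.5: 0.5, 1.5: 0.5},
    "point mass": {0.5: 1.0},
    "overdispersed": {-1.0: 0.5, 2.0: 0.5},
}
for beta in (0.5, 1.0, 1.5):
    scores = {name: expected_energy_score(q, truth, beta)
              for name, q in candidates.items()}
    # Strict propriety: the forecast equal to the data distribution scores lowest.
    print(beta, min(scores, key=scores.get))
```

The gap S(Q, P) - S(P, P) equals half the energy distance between P and Q, which is positive whenever Q differs from P.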

Requirements:

  • Static model structure.
  • Fully reparameterized guide.
  • Reparameterized likelihood distributions in the model (latent priors may be non-reparameterized).
  • At least 2 particles (num_particles >= 2).

Code Reference

File: pyro/infer/energy_distance.py

Key Classes

Class | Parent | Description
EnergyDistance | (none) | Posterior predictive energy distance loss with optional Bayesian prior regularization.

EnergyDistance Methods

Method | Description
__init__ | Initialize with the robustness exponent, prior regularization scale, particle count, and plate nesting bound.
__call__ | Compute the surrogate loss. Returns a differentiable tensor.
Template:Code | Trace the guide and the model (with observations unconditioned) to obtain posterior predictive samples.
Template:Code | Not implemented. Raises NotImplementedError.

Internal Helper

Function | Description
Template:Code | Computes the squared L2 error between samples x and y, respecting the scale and mask of enclosing plate contexts.
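A plain-Python sketch of a scale- and mask-aware squared error, assuming per-element inputs: `mask` drops excluded elements and `scale` multiplies the remaining sum, as a subsampling plate would. The name `masked_squared_error` and its list-based signature are hypothetical; the actual helper works on batched tensors.

```python
from typing import Optional, Sequence

def masked_squared_error(x: Sequence[float], y: Sequence[float],
                         scale: float = 1.0,
                         mask: Optional[Sequence[bool]] = None) -> float:
    """Hypothetical sketch: scale * sum of (x_i - y_i)^2 over unmasked elements."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    if mask is None:
        mask = [True] * len(x)
    total = 0.0
    for xi, yi, keep in zip(x, y, mask):
        if keep:  # masked-out elements contribute nothing
            total += (xi - yi) ** 2
    # A plate that subsamples N of M data points would use scale = M / N.
    return scale * total

x, y = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
print(masked_squared_error(x, y))                                      # 14.0
print(masked_squared_error(x, y, mask=[True, False, True]))            # 10.0
print(masked_squared_error(x, y, scale=2.0, mask=[True, False, True]))  # 20.0
```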

I/O Contract

Constructor

Inputs:

  • beta -- Exponent of the energy distance. Must lie in the open interval (0, 2). Default is 1.0. Smaller values are more robust to heavy tails.
  • prior_scale -- Nonnegative scale (lambda) for prior regularization. Default is 0.0 (no prior term). Model parameters are trained only if this is positive.
  • num_particles -- Number of particles (posterior predictive samples) per loss evaluation. Must be at least 2. Default is 2.
  • max_plate_nesting -- Optional bound on the maximum number of nested plate contexts. Default is infinity, in which case a valid value is detected automatically.
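The valid ranges above can be summarized as a small validation routine. This is a hypothetical stand-alone sketch (`check_energy_distance_args` is not a Pyro function) that only restates the documented contract; the real constructor's checks and error messages may differ.

```python
def check_energy_distance_args(beta: float = 1.0, prior_scale: float = 0.0,
                               num_particles: int = 2) -> None:
    """Hypothetical validation mirroring the documented constructor contract."""
    if not 0.0 < beta < 2.0:
        raise ValueError(f"beta must lie in the open interval (0, 2), got {beta}")
    if prior_scale < 0.0:
        raise ValueError(f"prior_scale must be nonnegative, got {prior_scale}")
    if not (isinstance(num_particles, int) and num_particles >= 2):
        raise ValueError(f"num_particles must be an integer >= 2, got {num_particles}")

check_energy_distance_args()                             # defaults are valid
check_energy_distance_args(beta=0.5, num_particles=50)   # heavy-tailed setting
try:
    check_energy_distance_args(num_particles=1)          # one particle is too few
except ValueError as err:
    print(err)
```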

__call__

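Inputs:

  • model -- The model callable.
  • guide -- The guide callable.
  • Any further arguments (such as the observed data) are passed to both the model and the guide.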

Output:

  • torch.Tensor -- A differentiable scalar loss combining the energy distance and the optional prior regularization.

Internal Flow

  1. The guide is traced to sample latent variables Z for each particle.
  2. The model is replayed with the guide's Z values but with observations unconditioned, producing posterior predictive samples X.
  3. The energy distance is computed from the error term E[||X - x||^beta] and the entropy term E[||X - X'||^beta], which is estimated from pairs of particles.
  4. If prior_scale > 0, the log prior of the latent sites is computed, multiplied by prior_scale (lambda), and subtracted.
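The four steps can be traced on a toy problem with the standard library alone. This is an illustrative sketch under stated assumptions, not Pyro's code: the guide is Normal(loc, scale) sampled by reparameterization, the simulator adds unit Gaussian noise to z for every data point, the prior on z is Normal(0, 10) (matching the `mu` prior in the usage example), and the entropy term averages all ordered particle pairs. All function names are hypothetical.

```python
import math
import random
from typing import List, Sequence

def normal_logpdf(v: float, mu: float, sigma: float) -> float:
    """Log density of Normal(mu, sigma) at v."""
    return (-0.5 * ((v - mu) / sigma) ** 2
            - math.log(sigma) - 0.5 * math.log(2.0 * math.pi))

def l2(u: Sequence[float], v: Sequence[float]) -> float:
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def toy_energy_loss(data: List[float], loc: float, scale: float,
                    beta: float = 1.0, prior_scale: float = 0.0,
                    num_particles: int = 10, seed: int = 0) -> float:
    """Hypothetical toy version of the four steps; not Pyro's implementation."""
    rng = random.Random(seed)
    # Step 1 (guide): one reparameterized latent z = loc + scale * eps per particle.
    zs = [loc + scale * rng.gauss(0.0, 1.0) for _ in range(num_particles)]
    # Step 2 (model with observations unconditioned): simulate a data set per z.
    # Only sampling is needed; the simulator's density is never evaluated.
    xs = [[z + rng.gauss(0.0, 1.0) for _ in data] for z in zs]
    # Step 3: error term E||X - x||^beta and entropy term E||X - X'||^beta.
    error = sum(l2(x, data) ** beta for x in xs) / num_particles
    pairs = [(i, j) for i in range(num_particles)
             for j in range(num_particles) if i != j]
    entropy = sum(l2(xs[i], xs[j]) ** beta for i, j in pairs) / len(pairs)
    loss = error - 0.5 * entropy
    # Step 4 (optional): subtract lambda * E[log p(z)] under the assumed prior.
    if prior_scale > 0.0:
        mean_log_prior = sum(normal_logpdf(z, 0.0, 10.0) for z in zs) / num_particles
        loss -= prior_scale * mean_log_prior
    return loss

data = [2.5, 3.0, 3.5, 3.2, 2.8]
print(toy_energy_loss(data, loc=3.0, scale=0.1))   # guide centred near the data
print(toy_energy_loss(data, loc=-3.0, scale=0.1))  # guide far from the data: larger loss
```

Because both calls share a seed, they draw identical noise, so the difference between them comes almost entirely from the error term.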

Usage Examples

Basic Likelihood-Free Inference

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, EnergyDistance
from pyro.optim import Adam

def model(data):
    mu = pyro.sample("mu", dist.Normal(0.0, 10.0))
    sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        # The likelihood must be reparameterized (Normal is).
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

def guide(data):
    # Fully reparameterized guide: Normal and LogNormal both support rsample.
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0),
                       constraint=dist.constraints.positive)
    pyro.sample("mu", dist.Normal(loc, scale))
    s_loc = pyro.param("s_loc", torch.tensor(0.0))
    pyro.sample("sigma", dist.LogNormal(s_loc, 0.5))

pyro.clear_param_store()
data = torch.randn(50) + 3.0
energy_dist = EnergyDistance(beta=1.0, num_particles=10)

# EnergyDistance is a callable loss(model, guide, *args), so SVI can use it
# directly: each svi.step() backpropagates the loss, applies Adam, and zeroes grads.
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=energy_dist)
for step in range(1000):
    loss = svi.step(data)

With Prior Regularization

energy_dist = EnergyDistance(
    beta=1.0,
    prior_scale=0.1,      # Enable Bayesian prior regularization
    num_particles=20,
    max_plate_nesting=1
)

loss = energy_dist(model, guide, data)

Heavy-Tailed Distributions

# For Cauchy-distributed data, use beta < 1 for strict properness
energy_dist = EnergyDistance(beta=0.5, num_particles=50)
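Why beta < 1 for Cauchy data: the energy distance requires E[||X||^beta] to be finite, and for a standard Cauchy variable E|X|^p = 1/cos(p * pi / 2) when 0 <= p < 1, while the moment diverges for p >= 1. The helper below (hypothetical name, standard library only) evaluates that closed form as a sanity check of the moment condition; it is not part of Pyro.

```python
import math

def cauchy_abs_moment(p: float) -> float:
    """E|X|^p for a standard Cauchy variable X (infinite when p >= 1)."""
    if p < 0.0:
        raise ValueError("p must be nonnegative")
    if p >= 1.0:
        return math.inf
    return 1.0 / math.cos(p * math.pi / 2.0)

for beta in (0.5, 0.9, 0.99, 1.0):
    print(beta, cauchy_abs_moment(beta))
```

At beta = 0.5 the moment is sqrt(2), roughly 1.414; it grows without bound as beta approaches 1, which is why beta = 1 is not a valid choice for Cauchy-distributed data.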
