Implementation:Pyro ppl Pyro EnergyDistance
Overview
The energy_distance module (Template:Code) implements posterior predictive energy distance as an inference objective, following the statistical framework of Szekely and Rizzo (2003) and Gneiting and Raftery (2007). It is a likelihood-free inference algorithm. It does not require a tractable likelihood density, only the ability to sample from the model.
The energy distance loss is defined as:
$$\text{loss} = \mathbb{E}\big[\|X - x\|^{\beta}\big] - \tfrac{1}{2}\,\mathbb{E}\big[\|X - X'\|^{\beta}\big] - \lambda\,\mathbb{E}\big[\log p(Z)\big]$$
Here X and X' are independent posterior predictive samples, each obtained by sampling Z from the guide and then X from the model given Z. x is the observed data, beta in (0, 2) is the robustness exponent, and lambda >= 0 sets the strength of the Bayesian prior regularization.
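The factor 1/2 on the entropy term can be motivated from the standard Szekely and Rizzo energy distance. This definition is recalled here and is not taken from the module. Compare the posterior predictive distribution P of X with the point mass delta_x at the observed data. The third term of the energy distance then vanishes, so the first two terms of the loss equal half that distance:

```latex
\begin{aligned}
D_\beta(P, \delta_x)
  &= 2\,\mathbb{E}\big[\|X - x\|^{\beta}\big]
   - \mathbb{E}\big[\|X - X'\|^{\beta}\big]
   - \underbrace{\|x - x\|^{\beta}}_{=\,0} \\
\mathbb{E}\big[\|X - x\|^{\beta}\big] - \tfrac{1}{2}\,\mathbb{E}\big[\|X - X'\|^{\beta}\big]
  &= \tfrac{1}{2}\,D_\beta(P, \delta_x)
\end{aligned}
```

Equivalently, these two terms are the negative of the energy score ES(P, x) = 0.5 * E[||X - X'||^beta] - E[||X - x||^beta] of Gneiting and Raftery (2007), which is positively oriented.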
Key properties:
- Likelihood-free: Works for models where the likelihood has no tractable density (e.g., simulator-based models).
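- Robust: The beta-energy distance is well-defined for any distribution with a finite fractional moment E[||X||^beta], so smaller beta admits heavier tails. For example, a Cauchy distribution has finite E[|X|^beta] only for beta < 1, so use beta < 1 for Cauchy-like data.
- Strictly proper: For 0 < beta < 2, the loss is a strictly proper scoring rule relative to distributions with finite E[||X||^beta]. Its expectation is therefore uniquely minimized when the posterior predictive distribution matches the data-generating distribution.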
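The beta < 1 requirement for Cauchy data follows from a standard integral. For a standard Cauchy variable with density 1/(pi (1 + t^2)), the fractional absolute moment has a closed form:

```latex
\mathbb{E}\big[|X|^{\beta}\big]
  = \frac{2}{\pi}\int_{0}^{\infty} \frac{t^{\beta}}{1 + t^{2}}\,dt
  = \frac{1}{\cos(\pi\beta/2)}, \qquad 0 < \beta < 1
```

The right-hand side is finite for every beta < 1. At beta = 0.5, for example, it equals sqrt(2). It diverges as beta approaches 1. For beta >= 1 the integrand decays no faster than 1/t, so the integral is infinite.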
Requirements:
- Static model structure.
- Fully reparameterized guide.
- Reparameterized likelihood distributions in the model (latent priors may be non-reparameterized).
- At least 2 particles (`num_particles >= 2`).
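The two-particle minimum comes from the entropy term E[||X - X'||^beta], which needs pairs of independent particles. Pairing a particle with itself would contribute a distance of exactly zero and bias the term. The plain-Python sketch below counts the distinct pairs available for a given particle count. `particle_pairs` is a hypothetical helper for illustration and is not part of Pyro.

```python
from itertools import combinations


def particle_pairs(num_particles):
    """Hypothetical helper: all unordered pairs (i, j), i < j, of particle indices.

    Each pair supplies one draw of ||X_i - X_j||^beta for the entropy term;
    a particle is never paired with itself.
    """
    return list(combinations(range(num_particles), 2))


for n in (1, 2, 10):
    # n * (n - 1) / 2 pairs: prints "1 0", then "2 1", then "10 45"
    print(n, len(particle_pairs(n)))
```

With a single particle there are no pairs, so the entropy term cannot be estimated at all.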
Code Reference
File: Template:Code
Key Classes
| Class | Parent | Description |
|---|---|---|
| `EnergyDistance` | -- | Posterior predictive energy distance loss with optional Bayesian prior regularization. |
EnergyDistance Methods
| Method | Description |
|---|---|
| `__init__` | Initialize with robustness exponent, prior regularization scale, particle count, and plate nesting bound. |
| `__call__` | Compute the surrogate loss. Returns a differentiable tensor. |
| Template:Code | Trace the guide and model (with unconditioned observations) to obtain posterior predictive samples. |
| Template:Code | Not implemented. Raises `NotImplementedError`. |
Internal Helper
| Function | Description |
|---|---|
| Template:Code | Computes squared L2 error between samples x and y, respecting scale and mask from plate contexts. |
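The following plain-Python sketch shows roughly what "respecting scale and mask" means, computed for a single particle. The name `squared_error_sketch` and the list-based inputs are assumptions for illustration. This is not the module's helper, which operates on tensors.

```python
def squared_error_sketch(x, y, scale=1.0, mask=None):
    """Hypothetical sketch: scaled and masked squared L2 error for one particle.

    x, y  : lists of event vectors, one per batch element (lists of floats)
    scale : a float, or a list with one weight per batch element
    mask  : None (keep all elements) or a list of booleans per batch element
    Returns the weighted sum over batch elements of ||x_b - y_b||^2.
    """
    total = 0.0
    for b, (xb, yb) in enumerate(zip(x, y)):
        if mask is not None and not mask[b]:
            continue  # masked-out batch elements contribute nothing
        weight = scale[b] if isinstance(scale, (list, tuple)) else scale
        total += weight * sum((xi - yi) ** 2 for xi, yi in zip(xb, yb))
    return total
```

Take x = [[1.0, 2.0], [0.0, 0.0]] and y = [[0.0, 0.0], [3.0, 4.0]]. The per-element squared errors are 5 and 25. The result is 30.0 unscaled, 5.0 with mask [True, False], and 22.5 with scale [2.0, 0.5].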
I/O Contract
Constructor
Inputs:
- `beta` -- Exponent for the energy distance. Must be in the open interval (0, 2). Default is 1.0. Lower values are more robust to heavy tails.
- `prior_scale` -- Nonnegative scale for prior regularization (lambda in the loss). Default is 0.0 (no prior). Model parameters are trained only if this is positive.
- `num_particles` -- Number of particles (posterior predictive samples). Must be at least 2. Default is 2.
- `max_plate_nesting` -- Upper bound on the number of nested plate contexts. Default is infinity, in which case the bound is detected automatically.
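The documented constraints can be summarized as a small validator. This is a hypothetical sketch. `check_energy_distance_args` is not a Pyro function, and raising `ValueError` is an assumption about how invalid arguments might be reported.

```python
import math


def check_energy_distance_args(beta=1.0, prior_scale=0.0, num_particles=2,
                               max_plate_nesting=math.inf):
    """Hypothetical validator mirroring the documented defaults and constraints.

    Raising ValueError is an assumption made for this sketch.
    """
    if not 0 < beta < 2:
        raise ValueError(f"beta must lie in the open interval (0, 2), got {beta}")
    if prior_scale < 0:
        raise ValueError(f"prior_scale must be nonnegative, got {prior_scale}")
    if not isinstance(num_particles, int) or num_particles < 2:
        raise ValueError(f"num_particles must be an integer >= 2, got {num_particles}")
    # max_plate_nesting is only a bound; the default math.inf leaves it to be
    # detected automatically, so it is passed through unchecked here.
    return beta, prior_scale, num_particles, max_plate_nesting
```

For example, `check_energy_distance_args(beta=0.5, num_particles=50)` passes. Both `beta=2.0` and `num_particles=1` are rejected.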
__call__
Inputs:
- `model` -- A Pyro model callable.
- `guide` -- A Pyro guide callable (fully reparameterized).
- Template:Code -- Passed to model and guide.
Output:
- `torch.Tensor` -- A differentiable scalar loss combining energy distance and optional prior regularization.
Internal Flow
- The guide is traced to sample the latent variables Z.
- The model is replayed with the guide's Z values but with its observation sites unconditioned, producing posterior predictive samples X.
- The energy distance is estimated from two terms. The error term E[||X - x||^beta] is averaged over particles. The entropy term E[||X - X'||^beta] is averaged over pairs of distinct particles.
- If `prior_scale` > 0, the log-prior of the latent sites is computed, and `prior_scale` times its expectation is subtracted from the loss.
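The four steps can be traced end to end on a toy problem without Pyro. The sketch below assumes a hypothetical model and guide:

- Model: prior Z ~ Normal(0, 10), likelihood X_i | Z ~ Normal(Z, 1) for each of the d observations.
- Guide: reparameterized as Z = loc + scale * eps.

In this sketch the norm is taken over the full vector of d observations. It only evaluates the Monte Carlo estimate. Gradients are omitted; `EnergyDistance` instead returns a tensor that is differentiated with autograd.

```python
import math
import random
from itertools import combinations


def norm(v):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(t * t for t in v))


def normal_log_prob(z, loc, scale):
    """Log density of Normal(loc, scale) at z."""
    return -0.5 * ((z - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def energy_distance_loss(data, loc, scale, beta=1.0, prior_scale=0.0,
                         num_particles=10, seed=0):
    """Monte Carlo loss estimate for a hypothetical toy model and guide.

    Model (assumed):  Z ~ Normal(0, 10);  X_i | Z ~ Normal(Z, 1), i = 1..d
    Guide (assumed):  Z = loc + scale * eps,  eps ~ Normal(0, 1)
    """
    rng = random.Random(seed)
    d = len(data)
    # Step 1: sample the latent Z from the guide, once per particle.
    zs = [loc + scale * rng.gauss(0.0, 1.0) for _ in range(num_particles)]
    # Step 2: posterior predictive X ~ p(x | Z), ignoring the observed values.
    xs = [[z + rng.gauss(0.0, 1.0) for _ in range(d)] for z in zs]
    # Step 3: error term averaged over particles, entropy term over distinct pairs.
    error = sum(norm([a - b for a, b in zip(x, data)]) ** beta for x in xs) / num_particles
    pairs = list(combinations(range(num_particles), 2))
    entropy = sum(norm([a - b for a, b in zip(xs[i], xs[j])]) ** beta
                  for i, j in pairs) / len(pairs)
    energy = error - 0.5 * entropy
    # Step 4: optional prior regularization, -prior_scale * E[log p(Z)].
    log_prior = 0.0
    if prior_scale > 0:
        log_prior = sum(normal_log_prob(z, 0.0, 10.0) for z in zs) / num_particles
    return energy - prior_scale * log_prior


data = [2.8, 3.1, 3.3, 2.9, 3.0]
print(energy_distance_loss(data, loc=3.0, scale=0.1))    # guide near the data
print(energy_distance_loss(data, loc=-20.0, scale=0.1))  # guide far from the data
```

Moving the guide's `loc` toward the data shrinks the error term while the entropy term stays roughly unchanged, which is what drives optimization. A positive `prior_scale` adds -prior_scale * E[log p(Z)], which penalizes latent draws that are improbable under the prior.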
Usage Examples
Basic Likelihood-Free Inference
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, EnergyDistance
from pyro.optim import Adam


def model(data):
    mu = pyro.sample("mu", dist.Normal(0.0, 10.0))
    sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        # Reparameterized likelihood, as EnergyDistance requires
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)


def guide(data):
    # Fully reparameterized guide: Normal and LogNormal both support rsample
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0),
                       constraint=dist.constraints.positive)
    pyro.sample("mu", dist.Normal(loc, scale))
    s_loc = pyro.param("s_loc", torch.tensor(0.0))
    pyro.sample("sigma", dist.LogNormal(s_loc, 0.5))


data = torch.randn(50) + 3.0
energy_dist = EnergyDistance(beta=1.0, num_particles=10)

# EnergyDistance is a callable loss, so it can be passed directly to SVI,
# which backpropagates, steps the optimizer, and zeroes the gradients.
pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=energy_dist)
for step in range(1000):
    loss = svi.step(data)
```
With Prior Regularization
```python
energy_dist = EnergyDistance(
    beta=1.0,
    prior_scale=0.1,      # enable Bayesian prior regularization
    num_particles=20,
    max_plate_nesting=1,  # the example model uses a single pyro.plate
)
loss = energy_dist(model, guide, data)
```
Heavy-Tailed Distributions
```python
# For Cauchy-distributed data, use beta < 1 for strict properness
energy_dist = EnergyDistance(beta=0.5, num_particles=50)
```
Related Pages
- Pyro_ppl_Pyro_Trace_MMD -- Another divergence-based objective using Maximum Mean Discrepancy
- Pyro_ppl_Pyro_SVGD -- Non-parametric particle-based inference
- Pyro_ppl_Pyro_Infer_Utilities -- Validation utilities used during trace processing
- Pyro_ppl_Pyro_Importance -- Importance sampling inference