Implementation:Pyro ppl Pyro RenyiELBO
Overview
The renyi_elbo module (Template:Code) implements RenyiELBO, a variational inference objective based on Renyi's alpha-divergence. This generalizes the standard ELBO by introducing an order parameter alpha that interpolates between different divergence measures:
- alpha = 0: Equivalent to the Importance Weighted Autoencoder (IWAE) objective from Burda et al. (2015).
- alpha < 1: Provides a tighter bound than the standard ELBO.
- alpha = 1: Equivalent to the standard ELBO (not supported -- use Template:Code instead).
- alpha >= 0: Guarantees a strict lower bound on the log marginal likelihood; this includes the IWAE case alpha = 0.
- alpha < 0: May give better results empirically on some datasets, but is not a strict lower bound.
The objective is:
- L_alpha = (1 / (1 - alpha)) * log E[ exp((1 - alpha) * (log p(x,z) - log q(z|x))) ]
where the expectation is over z ~ q(z|x) and is estimated using Template:Code samples.
RenyiELBO requires Template:Code (since a single particle reduces to the standard ELBO regardless of alpha). It supports both vectorized and non-vectorized particle computation, and handles both reparameterized and non-reparameterized (score function) gradient estimators.
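The finite-sample form of the objective can be sketched in plain Python. This is an illustration of the formula only, not the module's code: the helper names `log_mean_exp` and `renyi_bound_estimate` and the sample values are assumptions. It checks two properties stated above: a single particle returns the plain ELBO term for any alpha, and on a fixed set of particles the estimate does not increase as alpha grows toward 1.

```python
import math


def log_mean_exp(values):
    """Numerically stable log of the mean of exp(values)."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def renyi_bound_estimate(log_ratios, alpha):
    """Estimate L_alpha from per-particle values log p(x, z_k) - log q(z_k | x)."""
    if alpha == 1:
        # The formula divides by (1 - alpha), so alpha = 1 needs the standard ELBO.
        raise ValueError("alpha must not equal 1")
    scaled = [(1.0 - alpha) * r for r in log_ratios]
    return log_mean_exp(scaled) / (1.0 - alpha)


# Hypothetical per-particle log-ratios, chosen only for illustration.
log_ratios = [-3.2, -2.7, -4.1, -2.9]

iwae = renyi_bound_estimate(log_ratios, alpha=0.0)   # log of the mean importance weight
half = renyi_bound_estimate(log_ratios, alpha=0.5)
elbo = sum(log_ratios) / len(log_ratios)             # alpha -> 1 limit: plain average
single = renyi_bound_estimate(log_ratios[:1], alpha=0.5)  # one particle only
```

With one particle the log and the exponential cancel, so `single` equals the lone log-ratio; this is why alpha only has an effect when more than one particle is used.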
Code Reference
File: Template:Code
Key Classes
| Class | Parent | Description |
|---|---|---|
| Template:Code | Template:Code | Renyi alpha-divergence variational inference objective. |
RenyiELBO Methods
| Method | Description |
|---|---|
| Template:Code | Initialize with alpha order and ELBO configuration. |
| Template:Code | Compute the negative Renyi ELBO estimate (no gradients). Returns a float. |
| Template:Code | Compute the negative Renyi ELBO estimate and call backward on the surrogate loss. Returns a float. |
| Template:Code | Returns a single pair of (model_trace, guide_trace) from importance tracing. |
I/O Contract
Constructor
Inputs:
- Template:Code -- Order of the Renyi divergence. Must not equal 1. Default is 0 (IWAE).
- Template:Code -- Number of particles for the estimator. Default is 2.
- Template:Code -- Max nested plate depth. Default is infinity (auto-detected).
- Template:Code -- Whether to vectorize across particles. Default is False.
- Template:Code -- Warn about enumeration misuse. Default is True.
loss
Inputs:
- Template:Code -- A Pyro model callable.
- Template:Code -- A Pyro guide callable.
- Template:Code -- Passed to model and guide.
Output:
- Template:Code -- Negative Renyi ELBO estimate. Computed under Template:Code.
loss_and_grads
Inputs:
- Same as Template:Code.
Output:
- Template:Code -- Negative Renyi ELBO estimate. Also populates gradients on parameters using a surrogate loss with normalized importance weights.
Gradient Computation Details
For gradient computation, the surrogate loss uses self-normalized importance weights:
- Compute per-particle ELBO: elbo_k = log p(x,z_k) - log q(z_k|x)
- Compute log-weights: w_k = (1 - alpha) * elbo_k
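- Normalize by the mean weight: w_tilde_k = exp(w_k) / mean_j exp(w_j), which equals num_particles * softmax(w)_k
- Surrogate loss: -sum(w_tilde_k * surrogate_elbo_k) / num_particles, i.e. a softmax-weighted sum of the per-particle surrogate ELBOs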
For non-reparameterizable sites with a non-zero score function term, an additional REINFORCE-like correction term proportional to alpha/(1-alpha) is added. Mixed reparameterized/non-reparameterized entropy terms raise Template:Code.
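A small numeric sketch of the weighting steps above, in plain Python with made-up per-particle values. It assumes the weights are normalized by their mean, so that the final division by the number of particles leaves coefficients that sum to 1 (the softmax of the log-weights). The function name `surrogate_coefficients` is hypothetical, and the sketch only computes the coefficients; in an autograd setting they would be held fixed while the per-particle surrogate ELBOs carry the gradient.

```python
import math


def surrogate_coefficients(elbo_particles, alpha):
    """Return the coefficient each particle's surrogate ELBO receives."""
    k = len(elbo_particles)
    log_weights = [(1.0 - alpha) * e for e in elbo_particles]
    # Stable log of the mean weight.
    m = max(log_weights)
    log_mean_weight = m + math.log(sum(math.exp(w - m) for w in log_weights) / k)
    # Weights normalized by their mean: these sum to k, not to 1.
    normalized = [math.exp(w - log_mean_weight) for w in log_weights]
    # Dividing by k yields softmax weights that sum to 1.
    return [w / k for w in normalized]


elbo_particles = [-3.2, -2.7, -4.1, -2.9]  # hypothetical values

coeffs = surrogate_coefficients(elbo_particles, alpha=0.0)       # IWAE-style weighting
near_elbo = surrogate_coefficients(elbo_particles, alpha=0.999)  # close to uniform
```

At alpha = 0 the particle with the largest ELBO dominates the gradient signal, while as alpha approaches 1 the coefficients approach 1 / num_particles, which matches the equal weighting of the standard ELBO.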
Usage Examples
IWAE Objective (alpha=0)
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, RenyiELBO
from pyro.optim import Adam

# Illustrative scalar observation (an assumption); substitute your own data.
data = torch.tensor(0.5)

def model(data):
    # Standard normal prior on the latent z
    z = pyro.sample("z", dist.Normal(0, 1))
    pyro.sample("obs", dist.Normal(z, 1), obs=data)

def guide(data):
    # Learnable location and scale of the variational distribution q(z)
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0),
                       constraint=dist.constraints.positive)
    pyro.sample("z", dist.Normal(loc, scale))

# IWAE with 10 particles
elbo = RenyiELBO(alpha=0, num_particles=10)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=elbo)
for step in range(1000):
    loss = svi.step(data)
Tighter Bound with alpha < 1
# alpha=0.5 gives a bound between IWAE and standard ELBO
elbo = RenyiELBO(alpha=0.5, num_particles=5, vectorize_particles=True)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=elbo)
for step in range(1000):
    loss = svi.step(data)
Negative Alpha for Better Results
# Negative alpha may give better results on some datasets
# but is no longer a strict lower bound
elbo = RenyiELBO(alpha=-1.0, num_particles=10)
Related Pages
- Pyro_ppl_Pyro_TraceGraph_ELBO -- Graph-based ELBO with variance reduction for non-reparameterizable models
- Pyro_ppl_Pyro_TraceTailAdaptive_ELBO -- Adaptive f-divergence that uses RenyiELBO for convergence monitoring
- Pyro_ppl_Pyro_TraceTMC_ELBO -- Tensor Monte Carlo ELBO using enumeration
- Pyro_ppl_Pyro_Infer_Utilities -- Utility functions for validation and tensor operations