Implementation:Pyro ppl Pyro RenyiELBO


Overview

The renyi_elbo module (Template:Code) implements RenyiELBO, a variational inference objective based on Renyi's alpha-divergence. This generalizes the standard ELBO by introducing an order parameter alpha that interpolates between different divergence measures:

  • alpha = 0: Equivalent to the Importance Weighted Autoencoder (IWAE) objective from Burda et al. (2015).
  • alpha < 1: Provides a tighter lower bound than the standard ELBO.
  • alpha = 1: Equivalent to the standard ELBO (not supported -- use Trace_ELBO instead).
  • alpha >= 0: Guarantees a strict lower bound on the log marginal likelihood (this includes the IWAE case alpha = 0).
  • alpha < 0: May give better results empirically on some datasets, but is not a strict lower bound.

The objective is:

L_alpha = (1 / (1 - alpha)) * log E[ exp((1 - alpha) * (log p(x,z) - log q(z|x))) ]

where the expectation is over z ~ q(z|x) and is estimated by averaging over num_particles samples.
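The Monte Carlo form of this objective can be sketched in a few lines of plain Python. This is a minimal illustration, not Pyro's implementation: the helper names log_mean_exp and renyi_bound_estimate are hypothetical, and the per-particle log-ratios are made-up numbers standing in for log p(x,z_k) - log q(z_k|x).

```python
import math

def log_mean_exp(values):
    """Numerically stable log(mean(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))

def renyi_bound_estimate(log_ratios, alpha):
    """Estimate L_alpha from per-particle values log p(x,z_k) - log q(z_k|x)."""
    if alpha == 1:
        # The 1 / (1 - alpha) factor is undefined at alpha = 1 (standard ELBO).
        raise ValueError("alpha must not equal 1")
    scaled = [(1.0 - alpha) * r for r in log_ratios]
    return log_mean_exp(scaled) / (1.0 - alpha)

# Hypothetical per-particle log-ratios, for illustration only.
log_ratios = [-3.2, -1.5, -2.7, -0.9]

iwae = renyi_bound_estimate(log_ratios, alpha=0.0)  # log of the mean importance weight
half = renyi_bound_estimate(log_ratios, alpha=0.5)  # intermediate order
plain_mean = sum(log_ratios) / len(log_ratios)      # value approached as alpha -> 1
```

For a fixed set of particles the estimate does not increase as alpha grows, so here iwae >= half >= plain_mean.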

RenyiELBO requires more than one particle (num_particles > 1), since a single particle reduces to the standard ELBO regardless of alpha. It supports both vectorized and non-vectorized particle computation, and handles both reparameterized and non-reparameterized (score function) gradient estimators.
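The single-particle reduction follows directly from the objective. As a short worked derivation, writing K for the number of particles and w_k for the per-particle log-ratio (both symbols are notation introduced here for illustration):

```latex
\hat{L}_\alpha^{(K)}
  = \frac{1}{1-\alpha}\,
    \log\!\left( \frac{1}{K} \sum_{k=1}^{K} \exp\big((1-\alpha)\, w_k\big) \right),
\qquad
w_k = \log p(x, z_k) - \log q(z_k \mid x).

% With a single particle (K = 1) the log and exp cancel:
\hat{L}_\alpha^{(1)}
  = \frac{1}{1-\alpha}\,(1-\alpha)\, w_1
  = \log p(x, z_1) - \log q(z_1 \mid x).
```

The K = 1 value is the usual single-sample ELBO estimate and no longer depends on alpha.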

Code Reference

File: Template:Code

Key Classes

Class | Parent | Description
RenyiELBO | ELBO | Renyi alpha-divergence variational inference objective.

RenyiELBO Methods

Method | Description
__init__ | Initialize with alpha order and ELBO configuration.
loss | Compute the Renyi ELBO estimate (no gradients). Returns a float.
loss_and_grads | Compute the Renyi ELBO and perform backward on the surrogate loss. Returns a float.
_get_trace | Returns a single pair of (model_trace, guide_trace) from importance tracing.

I/O Contract

Constructor

Inputs:

  • alpha -- Order of the Renyi divergence. Must not equal 1. Default is 0 (IWAE).
  • num_particles -- Number of particles for the estimator. Default is 2.
  • max_plate_nesting -- Max nested plate depth. Default is infinity (auto-detected).
  • vectorize_particles -- Whether to vectorize across particles. Default is False.
  • strict_enumeration_warning -- Warn about enumeration misuse. Default is True.

loss

Inputs:

Output:

loss_and_grads

Inputs:

Output:

  • float -- Negative Renyi ELBO estimate. Also populates gradients on parameters using a surrogate loss with normalized importance weights.

Gradient Computation Details

For gradient computation, the surrogate loss uses self-normalized importance weights:

  1. Compute per-particle ELBO: elbo_k = log p(x,z_k) - log q(z_k|x)
  2. Compute log-weights: w_k = (1 - alpha) * elbo_k
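  3. Normalize by the mean weight: w_tilde_k = exp(w_k - log_mean_exp(w)), which equals num_particles * softmax(w)_k
  4. Surrogate loss: -sum(w_tilde_k * surrogate_elbo_k) / num_particles, i.e. a softmax-weighted sum of the per-particle surrogate terms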
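The weighting steps above can be sketched with plain floats. This is a hedged illustration rather than the library code: the function names are hypothetical, the inputs are made-up numbers, and no autograd is involved (a real implementation would keep gradients only on the surrogate terms). The sketch normalizes by the mean weight, so the weights average to one and each equals num_particles times the softmax weight; dividing the weighted sum by num_particles then yields a softmax-weighted sum.

```python
import math

def normalized_weights(elbo_particles, alpha):
    """Return exp(w_k - log_mean_exp(w)) with w_k = (1 - alpha) * elbo_k.

    The returned weights average to one.
    """
    log_w = [(1.0 - alpha) * e for e in elbo_particles]
    m = max(log_w)  # subtract the max for numerical stability
    log_mean_w = m + math.log(sum(math.exp(w - m) for w in log_w) / len(log_w))
    return [math.exp(w - log_mean_w) for w in log_w]

def surrogate_loss(elbo_particles, surrogate_particles, alpha):
    """Negative weighted sum of surrogate terms, divided by the particle count."""
    weights = normalized_weights(elbo_particles, alpha)
    k = len(elbo_particles)
    return -sum(w * s for w, s in zip(weights, surrogate_particles)) / k

# Hypothetical per-particle ELBO values, for illustration only.
elbo_particles = [-3.2, -1.5, -2.7, -0.9]
weights = normalized_weights(elbo_particles, alpha=0.0)
loss_value = surrogate_loss(elbo_particles, elbo_particles, alpha=0.0)
```

Particles with a larger log-ratio receive a larger weight, so they dominate the gradient signal; as alpha approaches 1 the weights flatten toward uniform.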

For non-reparameterizable sites with a non-zero score function term, an additional REINFORCE-like correction term proportional to alpha/(1-alpha) is added. Sites that mix reparameterized and non-reparameterized terms (a non-zero entropy term together with a non-zero score function term) raise NotImplementedError.

Usage Examples

IWAE Objective (alpha=0)

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, RenyiELBO
from pyro.optim import Adam

def model(data):
    z = pyro.sample("z", dist.Normal(0, 1))
    pyro.sample("obs", dist.Normal(z, 1), obs=data)

def guide(data):
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0),
                       constraint=dist.constraints.positive)
    pyro.sample("z", dist.Normal(loc, scale))

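# Hypothetical observation for illustration; substitute your own data tensor
data = torch.tensor(0.5)

# IWAE with 10 particles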
elbo = RenyiELBO(alpha=0, num_particles=10)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=elbo)

for step in range(1000):
    loss = svi.step(data)

Tighter Bound with alpha < 1

# alpha=0.5 gives a bound between IWAE and standard ELBO
elbo = RenyiELBO(alpha=0.5, num_particles=5, vectorize_particles=True)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=elbo)

for step in range(1000):
    loss = svi.step(data)

Negative Alpha (Not a Strict Lower Bound)

# Negative alpha may give better results on some datasets
# but is no longer a strict lower bound
elbo = RenyiELBO(alpha=-1.0, num_particles=10)
