Implementation:Pyro ppl Pyro TraceTailAdaptive ELBO
Overview
The trace_tail_adaptive_elbo module (Template:Code) implements TraceTailAdaptive_ELBO, a variational inference objective based on the tail-adaptive f-divergence of Wang, Liu, and Liu (NeurIPS 2018). The standard ELBO weights all particles equally. This objective instead emphasizes particles with low log(p/q), which focuses optimization on the tail of the importance-weight distribution, where the guide most needs improvement.
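The key idea is to rank the particles by their log importance ratio log(p/q) and give each a rank-based weight gamma_k = rank_k^beta. The exponent beta (Template:Code) controls how strongly poorly fitting particles are emphasized. Ranks are assigned in ascending order of log(p/q), and beta is expected to satisfy -1.0 <= beta < 0.0 (default -1.0), so the worst-fitting particle (rank 1) receives the largest weight. Values of beta closer to 0 make the weights more uniform and the objective closer to the standard ELBO (beta = 0 would weight all particles equally). Setting beta = -1 gives weights proportional to 1/rank and the strongest focus on the tail.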
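In symbols, the weighting and surrogate can be sketched as follows. Here K stands for num_particles and w_k for particle k's importance ratio p/q; both symbols are notation introduced for this illustration.

```latex
\begin{aligned}
\log w_k &= \log p(x, z_k) - \log q(z_k), \qquad k = 1, \dots, K \\
\gamma_k &= \mathrm{rank}_k^{\beta}, \qquad \mathrm{rank}_k \in \{1, \dots, K\} \text{ the ascending rank of } \log w_k \\
\mathcal{L}_{\mathrm{surr}} &= -\frac{\sum_{k=1}^{K} \gamma_k \log w_k}{\sum_{k=1}^{K} \gamma_k} \\
\nabla_\phi \mathcal{L}_{\mathrm{surr}} &= -\sum_{k=1}^{K} \bar{\gamma}_k \, \nabla_\phi \log w_k, \qquad \bar{\gamma}_k = \frac{\gamma_k}{\sum_{j=1}^{K} \gamma_j}
\end{aligned}
```

The weights are detached, so the gradient line treats each gamma_k as a constant. The result is a normalized, rank-weighted average of the per-particle reparameterized gradients. When all gamma_k are equal, it reduces to the ordinary ELBO gradient estimate.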
Important characteristics:
- This objective does not compute the loss value itself -- the Template:Code method raises Template:Code. It only computes gradients via a surrogate loss.
- For monitoring convergence, use another objective (e.g., Template:Code).
- Requires Template:Code and Template:Code.
- Only supports models with fully reparameterized latent variables.
- Does not support data subsampling.
Code Reference
File: Template:Code
Key Classes
| Class | Parent | Description |
|---|---|---|
| Template:Code | Template:Code | Tail-adaptive f-divergence objective for SVI. Only computes gradients, not the loss value. |
Methods
| Method | Description |
|---|---|
| Template:Code | Not implemented. Raises Template:Code because the tail-adaptive objective does not require loss computation for gradient estimation. |
| Template:Code | Computes particle-specific weights based on rank ordering of log(p/q), then constructs the surrogate loss. Returns Template:Code. |
I/O Contract
Constructor
Inputs (inherited from Trace_ELBO and ELBO):
- Template:Code -- Number of particles. Should be > 1 for adaptive behavior. When Template:Code, falls back to standard Trace_ELBO behavior (with a warning).
- Template:Code -- Must be True. Raises Template:Code if False. The check runs when gradients are computed, not at construction time.
- Template:Code -- Exponent beta of the rank-based weights, set through the ELBO base-class constructor. Expected range is -1.0 <= beta < 0.0 (default -1.0). Values closer to 0 give more uniform weights; values closer to -1 focus more on the tail.
- Template:Code -- Max nested plate depth.
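The constructor preconditions can be gathered into one helper, sketched below in plain Python. The helper is hypothetical and is not a Pyro function. In Pyro's source these checks run inside the gradient computation, and the exception type shown follows that source. The beta range check is an extra assumption of this sketch, not documented Pyro behavior.

```python
import warnings


def check_tail_adaptive_config(num_particles, vectorize_particles,
                               tail_adaptive_beta=-1.0):
    """Hypothetical helper collecting the documented preconditions.

    Not part of Pyro: it restates the constructor contract in plain
    Python so the conditions can be read (and tested) in one place.
    """
    if not vectorize_particles:
        # Documented behavior: vectorize_particles=False is rejected.
        raise NotImplementedError(
            "TraceTailAdaptive_ELBO requires vectorize_particles=True")
    if num_particles == 1:
        # One particle means the only weight is 1, so the surrogate is the
        # plain Trace_ELBO loss; warn instead of failing.
        warnings.warn("num_particles == 1: no adaptive weighting is applied")
    if not -1.0 <= tail_adaptive_beta < 0.0:
        # Extra check added by this sketch (not documented Pyro behavior):
        # flag exponents outside the expected range [-1.0, 0.0).
        warnings.warn(
            f"tail_adaptive_beta={tail_adaptive_beta} is outside [-1.0, 0.0)")


check_tail_adaptive_config(num_particles=20, vectorize_particles=True,
                           tail_adaptive_beta=-0.5)  # no warning, no error
```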
_differentiable_loss_particle
Inputs:
- Template:Code -- A Pyro trace from the model.
- Template:Code -- A Pyro trace from the guide.
Output:
- Template:Code -- First element is always Template:Code (loss not computed). Second element is the surrogate loss used for gradient computation.
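Before any weighting happens, both traces are reduced to one number per particle. The sketch below performs this reduction in plain Python under an assumed, hypothetical trace representation: a dict that maps each site name to a per-particle list of log-probabilities. Pyro does the same reduction on tensors, roughly by reshaping each site's log-probability to (num_particles, -1) and summing over the last axis.

```python
def per_particle_log_prob(sites, num_particles):
    """Sum log-probabilities over every sample site, one total per particle.

    `sites` is a hypothetical stand-in for a Pyro trace: it maps a site name
    to a list with one entry per particle, and each entry is a number or a
    nested list of numbers (the site's remaining batch dimensions).
    """

    def flat_sum(value):
        # Collapse every non-particle dimension into a single sum.
        if isinstance(value, (list, tuple)):
            return sum(flat_sum(v) for v in value)
        return value

    totals = [0.0] * num_particles
    for name, per_particle in sites.items():
        if len(per_particle) != num_particles:
            raise ValueError(f"site {name!r} lacks a leading particle dimension")
        for k, value in enumerate(per_particle):
            totals[k] += flat_sum(value)
    return totals


# Two particles; the observed site "obs" holds three data points per particle.
model_sites = {
    "z": [-0.9, -1.2],
    "obs": [[-1.0, -1.1, -0.9], [-2.0, -1.5, -1.8]],
}
guide_sites = {"z": [-0.7, -1.0]}

log_p = per_particle_log_prob(model_sites, num_particles=2)  # model: includes observed sites
log_q = per_particle_log_prob(guide_sites, num_particles=2)  # guide: latent sites only
log_pq = [p - q for p, q in zip(log_p, log_q)]
```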
Gradient Computation Logic
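- Compute per-particle log p by summing log-probabilities over all model sample sites (including observed ones), and log q by summing over all guide sample sites.
- Compute log(p/q) = log p - log q for each particle.
- Rank particles by log(p/q) in ascending order, so the worst-fitting particle (smallest log(p/q)) gets rank 1.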
- Compute rank-based weights: gamma_k = rank_k ^ tail_adaptive_beta (detached).
- Surrogate loss = -sum(log(p/q) * gamma) / sum(gamma).
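The five steps above can be sketched without Pyro or PyTorch. In this illustration the function names are hypothetical and lists stand in for tensors. Keeping the weights as plain numbers plays the role of detaching them. Ties in log(p/q) are broken by particle index, a choice made only for this sketch.

```python
import math


def rank_weights(log_pq, beta=-1.0):
    """gamma_k = rank_k ** beta, ranking log(p/q) in ascending order.

    The particle with the smallest log(p/q) (the worst fit) gets rank 1.
    """
    order = sorted(range(len(log_pq)), key=lambda k: log_pq[k])
    rank = [0] * len(log_pq)
    for position, k in enumerate(order):
        rank[k] = position + 1
    return [float(r) ** beta for r in rank]


def tail_adaptive_surrogate(log_pq, beta=-1.0):
    """Return (loss, surrogate) like the documented particle method.

    The loss slot is always inf because the objective itself is never
    evaluated; only the surrogate is used for gradients.
    """
    gamma = rank_weights(log_pq, beta)
    weighted = sum(g * x for g, x in zip(gamma, log_pq))
    return math.inf, -weighted / sum(gamma)


log_pq = [0.0, -2.0, 1.0]
print(rank_weights(log_pq, beta=-1.0))       # worst particle (index 1) gets weight 1.0
print(tail_adaptive_surrogate(log_pq, 0.0))  # beta = 0: minus the plain average
```

With a single particle, the only weight is 1 ** beta = 1. The surrogate is then just -log(p/q) for any beta, which matches the Trace_ELBO fallback described for one particle.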
Validation
When Pyro's validation is enabled, Template:Code is called at each guide sample site to ensure that all latent variables are fully reparameterized. A non-reparameterized site raises Template:Code.
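The sketch below shows a minimal version of the per-site check. It assumes each guide site is a dict carrying a hypothetical "has_rsample" flag; Pyro's helper inspects the site's distribution instead of reading a flag. The sketch raises NotImplementedError at the first site that cannot be reparameterized.

```python
def check_guide_fully_reparameterized(guide_sites):
    """Raise if any guide sample site cannot be reparameterized.

    `guide_sites` maps a site name to a dict with a "has_rsample" flag; this
    representation is hypothetical and only stands in for a Pyro trace.
    """
    for name, site in guide_sites.items():
        if not site.get("has_rsample", False):
            raise NotImplementedError(
                f"guide site {name!r} is not fully reparameterized")


check_guide_fully_reparameterized({"z": {"has_rsample": True}})  # passes silently
```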
Usage Examples
Basic Usage with SVI
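```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceTailAdaptive_ELBO
from pyro.optim import Adam

# Synthetic observations (any 1-D float tensor works here).
data = torch.randn(100) + 3.0


def model(data):
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    # Declaring the data plate lets the vectorized particle dimension
    # broadcast against the observations.
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(z, 1.0), obs=data)


def guide(data):
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0),
                       constraint=dist.constraints.positive)
    pyro.sample("z", dist.Normal(loc, scale))


# Must use vectorize_particles=True and num_particles > 1;
# tail_adaptive_beta is expected to lie in [-1.0, 0.0).
elbo = TraceTailAdaptive_ELBO(
    num_particles=20,
    vectorize_particles=True,
    tail_adaptive_beta=-0.5,
)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=elbo)

for step in range(1000):
    # svi.step() calls loss_and_grads internally. This objective does not
    # compute a loss value, so use a separate objective for monitoring.
    svi.step(data)
```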
Monitoring Convergence with RenyiELBO
```python
from pyro.infer import SVI, RenyiELBO, TraceTailAdaptive_ELBO
from pyro.optim import Adam

# Reuses model, guide, and data from the previous example.

# Use TraceTailAdaptive_ELBO for training.
train_loss = TraceTailAdaptive_ELBO(num_particles=20, vectorize_particles=True)
# Use RenyiELBO for monitoring, since TraceTailAdaptive_ELBO does not compute a loss.
# With alpha=0 it estimates the importance-weighted bound, not the plain ELBO.
monitor_loss = RenyiELBO(alpha=0, num_particles=20, vectorize_particles=True)

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=train_loss)
for step in range(1000):
    svi.step(data)
    if step % 100 == 0:
        bound = -monitor_loss.loss(model, guide, data)
        print(f"Step {step}: Renyi bound (alpha=0) = {bound:.4f}")
```
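RenyiELBO's loss is the negative of an estimate of the Renyi bound. With alpha=0 that estimate is log((1/K) * sum_k w_k), where w_k is particle k's importance ratio. By Jensen's inequality, this value is never below the plain ELBO estimate computed from the same particles. The dependency-free sketch below computes both quantities from a list of per-particle log(p/q) values, which shows what the monitoring loop above prints. The function names are hypothetical and the input values are made up.

```python
import math


def elbo_estimate(log_w):
    """Plain ELBO estimate: the average log importance ratio log(p/q)."""
    return sum(log_w) / len(log_w)


def renyi_bound(log_w, alpha=0.0):
    """Renyi bound estimate: log(mean(w ** (1 - alpha))) / (1 - alpha).

    alpha = 0 gives log(mean(w)), the importance-weighted bound. Uses the
    log-sum-exp trick so large negative log weights do not underflow.
    """
    if alpha == 1.0:
        raise ValueError("alpha = 1 is the limit that recovers the plain ELBO")
    scaled = [(1.0 - alpha) * x for x in log_w]
    m = max(scaled)
    log_mean = m + math.log(sum(math.exp(s - m) for s in scaled)) - math.log(len(scaled))
    return log_mean / (1.0 - alpha)


log_w = [-3.2, -5.5, -2.9, -4.1]  # made-up per-particle log(p/q) values
print(elbo_estimate(log_w), renyi_bound(log_w, alpha=0.0))
```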
Related Pages
- Pyro_ppl_Pyro_RenyiELBO -- Recommended for monitoring convergence when using TraceTailAdaptive_ELBO
- Pyro_ppl_Pyro_TraceGraph_ELBO -- Graph-based ELBO for models with non-reparameterizable variables
- Pyro_ppl_Pyro_Infer_Utilities -- Contains Template:Code validation helper
- Pyro_ppl_Pyro_TraceTMC_ELBO -- Another advanced ELBO variant