Implementation:Pyro ppl Pyro TraceTailAdaptive ELBO

From Leeroopedia


Overview

The trace_tail_adaptive_elbo module (pyro/infer/trace_tail_adaptive_elbo.py) implements TraceTailAdaptive_ELBO, a variational inference objective based on the tail-adaptive f-divergence described in Wang, Liu, and Liu (NeurIPS 2018). The objective weights each particle by the rank of its log(p/q) ratio rather than by the ratio itself. In doing so it adaptively emphasizes particles with low log(p/q), focusing optimization on the tail of the importance weight distribution, where the guide most needs improvement.

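The key idea is to rank particles by their log importance weight log(p/q) and assign rank-based weights gamma_k = rank_k^beta, where beta is a hyperparameter (tail_adaptive_beta) that controls how strongly the weighting depends on rank. Pyro documents the range -1.0 <= beta < 0.0, with a default of -1.0. With ranks assigned in ascending order of log(p/q), a negative beta gives the largest weights to the lowest-ranked particles, that is, those with the poorest fit. As beta approaches 0 the weights become uniform and the surrogate approaches the standard ELBO estimate; values closer to -1 concentrate more weight on the tail.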
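As a rough illustration of this weighting, the sketch below computes normalized rank-based weights for a few made-up log(p/q) values. The helper name rank_weights and the sample values are hypothetical and not part of Pyro; ranks are assigned in ascending order of log(p/q), as described above. The value beta = 0.0 lies outside Pyro's documented range and is included only to show the uniform limiting case.

```python
def rank_weights(log_pq, beta):
    """Normalized weights gamma_k = rank_k ** beta, where rank 1 is the smallest log(p/q).

    Hypothetical helper for illustration only; not part of Pyro.
    """
    # Indices of the particles sorted by ascending log(p/q).
    order = sorted(range(len(log_pq)), key=lambda k: log_pq[k])
    ranks = [0] * len(log_pq)
    for position, k in enumerate(order):
        ranks[k] = position + 1  # ranks start at 1
    gamma = [r ** beta for r in ranks]
    total = sum(gamma)
    return [g / total for g in gamma]


log_pq = [-3.0, -0.5, -1.2, -7.0]  # made-up values for four particles

for beta in (-1.0, -0.5, 0.0):
    weights = rank_weights(log_pq, beta)
    print(beta, [round(w, 3) for w in weights])
```

With beta = -1.0 the particle with the lowest log(p/q) (the last one here) receives the largest weight; with beta = 0.0 every particle receives the same weight, which corresponds to a plain average of log(p/q).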

Important characteristics:

  • This objective does not compute the loss value itself -- the loss method raises NotImplementedError. It only computes gradients via a surrogate loss.
  • For monitoring convergence, use another objective (e.g., RenyiELBO).
  • Requires num_particles > 1 and vectorize_particles=True.
  • Only supports models with fully reparameterized latent variables.
  • Does not support data subsampling.

Code Reference

File: pyro/infer/trace_tail_adaptive_elbo.py

Key Classes

Class | Parent | Description
TraceTailAdaptive_ELBO | Trace_ELBO | Tail-adaptive f-divergence objective for SVI. Only computes gradients, not the loss value.

Methods

Method | Description
loss(model, guide, *args, **kwargs) | Not implemented. Raises NotImplementedError because gradients of the tail-adaptive objective can be computed without estimating the objective itself.
_differentiable_loss_particle(model_trace, guide_trace) | Computes particle-specific weights based on rank ordering of log(p/q), then constructs the surrogate loss. Returns (float('inf'), surrogate_loss).

I/O Contract

Constructor

Inputs (inherited from Trace_ELBO and ELBO):

  • num_particles -- Number of particles. Should be > 1 for adaptive behavior. When num_particles=1, falls back to standard Trace_ELBO behavior (with a warning).
  • vectorize_particles -- Must be True. Raises NotImplementedError if False.
  • tail_adaptive_beta -- Exponent beta of the rank-based weights, accepted by the base ELBO constructor. Pyro documents the range -1.0 <= beta < 0.0 (default -1.0). Values near 0 give nearly uniform weights; values near -1 put more weight on the tail.
  • max_plate_nesting -- Maximum nested plate depth.

_differentiable_loss_particle

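Inputs:

  • model_trace -- Execution trace of the model, with log-probabilities computed at each sample site.
  • guide_trace -- Execution trace of the guide, with log-probabilities computed at each sample site.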

Output:

  • (loss, surrogate_loss) -- First element is always float('inf') (loss not computed). Second element is the surrogate loss used for gradient computation.

Gradient Computation Logic

  1. Compute per-particle log p and log q by summing log-probabilities across all sample sites of the model trace and guide trace, respectively.
  2. Compute log(p/q) = log p - log q for each particle.
  3. Rank particles by log(p/q) in ascending order, so the worst-fitting particle gets rank 1.
  4. Compute rank-based weights: gamma_k = rank_k ^ tail_adaptive_beta. The weights are detached, so no gradient flows through them.
  5. Surrogate loss = -sum(log(p/q) * gamma) / sum(gamma).
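The steps above can be sketched in plain PyTorch as follows. This is an illustration of the listed steps rather than Pyro's source: the function name tail_adaptive_surrogate is hypothetical, the inputs are assumed to be 1-D tensors of per-particle log p and log q, and ranks are computed with a double argsort.

```python
import torch


def tail_adaptive_surrogate(log_p, log_q, beta=-1.0):
    """Surrogate loss from per-particle log p and log q (1-D tensors).

    Hypothetical helper that follows the listed steps; not Pyro's code.
    """
    log_pq = log_p - log_q  # step 2
    # Step 3: argsort of argsort gives 0-based ranks; add 1 so the smallest gets rank 1.
    ranks = torch.argsort(torch.argsort(log_pq)) + 1
    # Step 4: detach the weights so gradients flow only through log_pq.
    gamma = ranks.to(log_pq.dtype).pow(beta).detach()
    # Step 5: negative weighted average of log(p/q).
    return -(log_pq * gamma).sum() / gamma.sum()


# Made-up per-particle values; theta stands in for a guide parameter.
theta = torch.tensor(0.5, requires_grad=True)
log_p = torch.tensor([-1.0, -2.0, -4.0, -3.0])
log_q = theta * torch.ones(4)

loss = tail_adaptive_surrogate(log_p, log_q)
loss.backward()
print(loss.item(), theta.grad.item())
```

Because gamma is detached, the gradient with respect to theta is a weighted average of the per-particle gradients of log(p/q), with the weights treated as constants.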

Validation

At each guide site, check_fully_reparametrized is called (when Pyro validation is enabled) to ensure all latent variables are fully reparameterized. Non-reparameterized sites raise NotImplementedError.
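For intuition about what "fully reparameterized" means, PyTorch distributions expose a has_rsample flag that indicates whether reparameterized sampling is available. The snippet below is only a rough proxy, assuming PyTorch is installed; it is not the check that Pyro performs.

```python
import torch
from torch.distributions import Bernoulli, Normal

continuous = Normal(torch.tensor(0.0), torch.tensor(1.0))
discrete = Bernoulli(probs=torch.tensor(0.3))

# Normal supports rsample(), so gradients can flow through its samples.
print("Normal.has_rsample:", continuous.has_rsample)
# Bernoulli has no rsample(); a guide site like this would not qualify.
print("Bernoulli.has_rsample:", discrete.has_rsample)
```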

Usage Examples

Basic Usage with SVI

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceTailAdaptive_ELBO
from pyro.optim import Adam

def model(data):
    z = pyro.sample("z", dist.Normal(0, 1))
    pyro.sample("obs", dist.Normal(z, 1), obs=data)

def guide(data):
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0),
                       constraint=dist.constraints.positive)
    pyro.sample("z", dist.Normal(loc, scale))

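# Placeholder scalar observation; the model above is not written for batched data
data = torch.tensor(0.5)

# Requires vectorize_particles=True; use num_particles > 1 for adaptive weighting.
# tail_adaptive_beta is documented for -1.0 <= beta < 0.0 (default -1.0).
elbo = TraceTailAdaptive_ELBO(
    num_particles=20,
    vectorize_particles=True,
    tail_adaptive_beta=-1.0,
)

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=elbo)
for step in range(1000):
    # svi.step() calls loss_and_grads internally, so loss() is never needed
    svi.step(data)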

Monitoring Convergence with RenyiELBO

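import torch
from pyro.infer import SVI, RenyiELBO, TraceTailAdaptive_ELBO
from pyro.optim import Adam

# Reuses model and guide from the basic example; data is a placeholder scalar observation
data = torch.tensor(0.5)

# Use TraceTailAdaptive_ELBO for training
train_loss = TraceTailAdaptive_ELBO(num_particles=20, vectorize_particles=True)

# Use RenyiELBO for monitoring (since TraceTailAdaptive_ELBO doesn't compute loss);
# alpha=0 corresponds to the importance-weighted (IWAE) bound
monitor_loss = RenyiELBO(alpha=0, num_particles=20, vectorize_particles=True)

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=train_loss)
for step in range(1000):
    svi.step(data)
    if step % 100 == 0:
        # loss() returns the negated bound estimate, so flip the sign for reporting
        neg_bound = monitor_loss.loss(model, guide, data)
        print(f"Step {step}: bound estimate = {-neg_bound:.4f}")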
