Implementation:Pyro ppl Pyro TraceEnum ELBO Loss

From Leeroopedia


Field | Value
--- | ---
Sources | Pyro
Domains | Variational_Inference, Discrete_Inference
Last Updated | 2026-02-09 12:00 GMT

Overview

TraceEnum_ELBO is an ELBO implementation that supports exhaustive enumeration over discrete sample sites and local parallel sampling. For models with discrete latent variables, an enumerated site is summed out exactly instead of being sampled, so the resulting gradient estimates typically have much lower variance than those of purely sampling-based estimators.

Description

The TraceEnum_ELBO class extends the base ELBO to handle discrete latent variables through exact marginalization rather than sampling. It supports two enumeration modes:

  • Parallel enumeration (infer={"enumerate": "parallel"}): Evaluates all discrete values simultaneously by expanding tensor dimensions, leveraging GPU parallelism for efficient computation.
  • Sequential enumeration (infer={"enumerate": "sequential"}): Iterates over discrete values one at a time, consuming less memory at the cost of speed.
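As a rough illustration of the difference between the two modes, the sketch below uses plain PyTorch rather than Pyro's internals. It assumes a 3-way Categorical with a batch of 5 and shows how enumerate_support(expand=False) lays all values out along a new leftmost dimension (the kind of layout parallel enumeration relies on), whereas a sequential pass visits one value at a time. All variable names are illustrative.

```python
import torch
from torch.distributions import Categorical

# Batch of 5 identical sites, each with 3 categories (illustrative numbers).
probs = torch.tensor([0.2, 0.3, 0.5]).expand(5, 3)
d = Categorical(probs=probs)

# Parallel-style layout: every value at once, in a new leftmost dimension
# that broadcasts against the batch dimension.
values = d.enumerate_support(expand=False)
print(values.shape)  # torch.Size([3, 1])
parallel_log_probs = d.log_prob(values)  # broadcasts to shape (3, 5)

# Sequential-style pass: one value per iteration, so less memory is live at once.
sequential_log_probs = torch.stack(
    [d.log_prob(torch.full((5,), k)) for k in range(3)]
)

# Both passes evaluate the same log-probabilities.
print(torch.allclose(parallel_log_probs, sequential_log_probs))
```

The two results agree; the modes differ only in how much is held in memory at once and how much work is batched.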

Internally, TraceEnum_ELBO uses EnumMessenger to allocate a dedicated tensor dimension for each enumerated variable, and contract_tensor_tree (tensor variable elimination) to compute the marginal log probabilities efficiently. The DiCE estimator handles the interaction between enumerated discrete variables and score-function terms from non-reparameterizable sample sites.
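The idea behind tensor variable elimination can be sketched in a few lines of plain PyTorch. This is not Pyro's contract_tensor_tree; it is a minimal example, with hypothetical log-factors on a chain a - b - c, showing that summing variables out locally gives the same result as building the full joint table.

```python
import torch

torch.manual_seed(0)
# Hypothetical log-factors for a chain a - b - c, each variable taking 4 values.
log_f = torch.randn(4, 4)  # log f(a, b)
log_g = torch.randn(4, 4)  # log g(b, c)

# Brute force: build the full (a, b, c) table, then sum everything out.
joint = log_f[:, :, None] + log_g[None, :, :]
brute = torch.logsumexp(joint.reshape(-1), dim=0)

# Variable elimination: sum out a and c locally, then combine over b.
msg_a = torch.logsumexp(log_f, dim=0)  # function of b only
msg_c = torch.logsumexp(log_g, dim=1)  # function of b only
eliminated = torch.logsumexp(msg_a + msg_c, dim=0)

print(torch.allclose(brute, eliminated))
```

The brute-force table grows exponentially with the number of variables, while the eliminated form only ever touches pairwise tables, which is why the ordering of contractions matters for larger models.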

Critical requirement: When using parallel enumeration, max_plate_nesting must be set correctly to indicate how many rightmost tensor dimensions are reserved for pyro.plate contexts. Enumeration dimensions are allocated to the left of these plate dimensions.

Beyond the standard ELBO methods, this class provides:

  • compute_marginals(): Computes marginal distributions at each model-enumerated sample site
  • sample_posterior(): Samples from the joint posterior of all model-enumerated sites using forward filtering / backward sampling

Usage

Import TraceEnum_ELBO for variational inference in models with discrete latent variables. Decorate a model or guide with @config_enumerate to enable enumeration on all of its discrete latent sample sites at once (parallel by default), or set infer={"enumerate": "parallel"} on individual sample sites.

Code Reference

Source Location

Repository: pyro-ppl/pyro
File: pyro/infer/traceenum_elbo.py
Lines: L316--521
Base class: pyro/infer/elbo.py L30--239

Signature

class TraceEnum_ELBO(ELBO):
    def __init__(
        self,
        num_particles=1,
        max_plate_nesting=float('inf'),
        vectorize_particles=False,
        strict_enumeration_warning=True,
        ignore_jit_warnings=False,
        jit_options=None,
        retain_graph=None,
        tail_adaptive_beta=-1.0,
    ):

Import

from pyro.infer import TraceEnum_ELBO

I/O Contract

Inputs

Name | Type | Required | Description
--- | --- | --- | ---
num_particles | int | No | Number of particles for the ELBO estimator (default: 1)
max_plate_nesting | int or float | Yes (for parallel enumeration) | Bound on the maximum number of nested pyro.plate contexts (default: float('inf')); must be set correctly for parallel enumeration
vectorize_particles | bool | No | Whether to vectorize ELBO computation over particles (default: False)
strict_enumeration_warning | bool | No | Whether to warn if no sites are configured for enumeration (default: True)
ignore_jit_warnings | bool | No | Flag to ignore JIT tracer warnings (default: False)
jit_options | dict or None | No | Options to pass to torch.jit.trace (default: None)
retain_graph | bool or None | No | Whether to retain the autograd graph during the backward pass (default: None)
tail_adaptive_beta | float | No | Exponent for the tail-adaptive ELBO variant (default: -1.0)

Outputs

Method | Return Type | Description
--- | --- | ---
loss(model, guide, *args, **kwargs) | float | Scalar loss (the negative of the ELBO estimate) with enumerated discrete variables marginalized out
differentiable_loss(model, guide, *args, **kwargs) | torch.Tensor | Differentiable version of the same loss, for use with autograd
loss_and_grads(model, guide, *args, **kwargs) | float | Same scalar loss as loss(); also performs a backward pass on each particle
compute_marginals(model, guide, *args, **kwargs) | OrderedDict | Dict mapping site name to marginal Distribution for model-enumerated sites
sample_posterior(model, guide, *args, **kwargs) | trace | Joint posterior sample from all model-enumerated sites via backward sampling

Usage Examples

Discrete Mixture Model

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

# Toy 1-D observations for illustration; replace with your own data tensor.
data = torch.tensor([-2.1, -1.9, 0.1, 0.0, 2.0, 2.2])

@config_enumerate  # enumerate every discrete latent site (parallel by default)
def model(data):
    weights = pyro.param("weights", torch.ones(3) / 3,
                          constraint=constraints.simplex)
    locs = pyro.param("locs", torch.randn(3))
    with pyro.plate("data", len(data)):
        assignment = pyro.sample("assignment", dist.Categorical(weights))
        pyro.sample("obs", dist.Normal(locs[assignment], 1.0), obs=data)

def guide(data):
    pass  # Discrete sites are marginalized, not guided

# max_plate_nesting=1 because we have one plate ("data")
svi = SVI(model, guide, Adam({"lr": 0.01}),
          loss=TraceEnum_ELBO(max_plate_nesting=1))

for step in range(1000):
    loss = svi.step(data)

Computing Marginals

elbo = TraceEnum_ELBO(max_plate_nesting=1)
marginals = elbo.compute_marginals(model, guide, data)
# marginals["assignment"] is a Categorical distribution over cluster assignments
