Implementation: Pyro (pyro-ppl) TraceEnum_ELBO Loss
| Field | Value |
|---|---|
| Sources | Pyro |
| Domains | Variational_Inference, Discrete_Inference |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
TraceEnum_ELBO is an ELBO implementation that supports exhaustive enumeration over discrete sample sites and local parallel sampling, providing much lower variance gradient estimates than sampling-based approaches for models with discrete latent variables.
Description
The TraceEnum_ELBO class extends the base ELBO to handle discrete latent variables through exact marginalization rather than sampling. It supports two enumeration modes:
- Parallel enumeration (`infer={"enumerate": "parallel"}`): evaluates all discrete values simultaneously by expanding tensor dimensions, leveraging GPU parallelism for efficient computation.
- Sequential enumeration (`infer={"enumerate": "sequential"}`): iterates over discrete values one at a time, consuming less memory at the cost of speed.
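Both modes compute the same exact marginal; they differ only in how the sum over discrete values is organized. As a plain-Python sketch (illustrative numbers, not Pyro code) of the quantity enumeration computes:

```python
import math

# Toy mixture: P(x) = sum_z P(z) * P(x | z), with z in {0, 1, 2}.
# The numbers below are hypothetical, chosen only for illustration.
prior = [0.2, 0.3, 0.5]        # P(z)
likelihood = [0.1, 0.4, 0.7]   # P(x | z) for one fixed observation x

# "Sequential" style: visit each discrete value one at a time.
seq = 0.0
for z in range(3):
    seq += prior[z] * likelihood[z]

# "Parallel" style: evaluate all values at once, then reduce.
par = sum(p * l for p, l in zip(prior, likelihood))

# Both orderings give the exact marginal (0.49 here); a Monte Carlo
# estimate over sampled z values would only approximate it.
log_marginal = math.log(par)
assert abs(seq - par) < 1e-12
```

In Pyro the reduction runs over tensor dimensions rather than a Python loop, which is where the variance reduction and GPU parallelism come from.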
Internally, TraceEnum_ELBO uses EnumMessenger to allocate dedicated tensor dimensions for each enumerated variable, and contract_tensor_tree (tensor variable elimination) to efficiently compute the marginal log probabilities. The Dice estimator handles the interaction between enumerated discrete variables and score function terms from non-reparameterizable continuous variables.
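The contraction performed by tensor variable elimination can be pictured with a toy two-variable chain in plain Python (illustrative numbers, not the actual `contract_tensor_tree` API): summing variables out one at a time gives the same answer as enumerating the full joint, while keeping intermediate terms small.

```python
import math

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

# Toy chain z1 -> z2, each binary; log-factors are illustrative.
log_p_z1 = [math.log(0.6), math.log(0.4)]             # log P(z1)
log_p_z2_given_z1 = [[math.log(0.7), math.log(0.3)],  # log P(z2 | z1)
                     [math.log(0.2), math.log(0.8)]]

# Naive contraction: enumerate the full joint (2 * 2 terms).
joint = logsumexp([log_p_z1[a] + log_p_z2_given_z1[a][b]
                   for a in range(2) for b in range(2)])

# Variable elimination: sum out z2 first, then z1 -- same answer,
# but no intermediate ever holds more than one variable's values.
msg = [logsumexp([log_p_z2_given_z1[a][b] for b in range(2)])
       for a in range(2)]
elim = logsumexp([log_p_z1[a] + msg[a] for a in range(2)])

# Both equal log 1.0 = 0.0 here, since the factors are normalized.
assert abs(joint - elim) < 1e-12
```

Pyro performs this contraction over allocated enumeration dimensions with batched tensor ops, so the cost scales with the largest intermediate factor rather than the full joint.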
Critical requirement: When using parallel enumeration, max_plate_nesting must be set correctly to indicate how many rightmost tensor dimensions are reserved for pyro.plate contexts. Enumeration dimensions are allocated to the left of these plate dimensions.
Beyond the standard ELBO methods, this class provides:
- `compute_marginals()`: computes marginal distributions at each model-enumerated sample site
- `sample_posterior()`: samples from the joint posterior of all model-enumerated sites using forward filtering / backward sampling
Usage
Import TraceEnum_ELBO for variational inference in models with discrete latent variables. Use @config_enumerate to globally enable enumeration on guide sites, or set infer={"enumerate": "parallel"} on individual sample sites.
Code Reference
Source Location
- Repository: `pyro-ppl/pyro`
- File: `pyro/infer/traceenum_elbo.py`
- Lines: L316–L521
- Base class: `pyro/infer/elbo.py` (L30–L239)
Signature
```python
class TraceEnum_ELBO(ELBO):
    def __init__(
        self,
        num_particles=1,
        max_plate_nesting=float('inf'),
        vectorize_particles=False,
        strict_enumeration_warning=True,
        ignore_jit_warnings=False,
        jit_options=None,
        retain_graph=None,
        tail_adaptive_beta=-1.0,
    ):
```
Import
```python
from pyro.infer import TraceEnum_ELBO
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| `num_particles` | int | No | Number of particles for the ELBO estimator (default: 1) |
| `max_plate_nesting` | int or float | Yes (for parallel enumeration) | Bound on the max number of nested `pyro.plate` contexts; must be set correctly for parallel enumeration |
| `vectorize_particles` | bool | No | Whether to vectorize the ELBO computation over particles (default: False) |
| `strict_enumeration_warning` | bool | No | Whether to warn if no sites are configured for enumeration (default: True) |
| `ignore_jit_warnings` | bool | No | Flag to ignore JIT tracer warnings (default: False) |
| `jit_options` | dict or None | No | Options to pass to `torch.jit.trace` (default: None) |
| `retain_graph` | bool or None | No | Whether to retain the autograd graph during the backward pass (default: None) |
| `tail_adaptive_beta` | float | No | Exponent for the tail-adaptive ELBO variant (default: -1.0) |
Outputs
| Method | Return Type | Description |
|---|---|---|
| `loss(model, guide, *args, **kwargs)` | float | Scalar ELBO estimate with enumerated discrete variables marginalized out |
| `differentiable_loss(model, guide, *args, **kwargs)` | torch.Tensor | Differentiable ELBO estimate for autograd |
| `loss_and_grads(model, guide, *args, **kwargs)` | float | Scalar ELBO estimate; performs a backward pass on each particle |
| `compute_marginals(model, guide, *args, **kwargs)` | OrderedDict | Dict mapping site name to marginal Distribution for model-enumerated sites |
| `sample_posterior(model, guide, *args, **kwargs)` | trace | Joint posterior sample from all model-enumerated sites via backward sampling |
Usage Examples
Discrete Mixture Model
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

@config_enumerate
def model(data):
    weights = pyro.param("weights", torch.ones(3) / 3,
                         constraint=dist.constraints.simplex)
    locs = pyro.param("locs", torch.randn(3))
    with pyro.plate("data", len(data)):
        assignment = pyro.sample("assignment", dist.Categorical(weights))
        pyro.sample("obs", dist.Normal(locs[assignment], 1.0), obs=data)

def guide(data):
    pass  # discrete sites are marginalized, not guided

# max_plate_nesting=1 because the model has one plate ("data")
svi = SVI(model, guide, Adam({"lr": 0.01}),
          loss=TraceEnum_ELBO(max_plate_nesting=1))
for step in range(1000):
    loss = svi.step(data)
```
Computing Marginals
```python
elbo = TraceEnum_ELBO(max_plate_nesting=1)
marginals = elbo.compute_marginals(model, guide, data)
# marginals["assignment"] is a Categorical distribution over cluster assignments
```