
Principle: Pyro-ppl Discrete Marginal ELBO

From Leeroopedia


Metadata
Sources: Tensor Variable Elimination for Plated Factor Graphs
Domains: Variational_Inference, Discrete_Inference
Last Updated: 2026-02-09 12:00 GMT

Overview

Exact marginalization of discrete latent variables in the ELBO, replacing high-variance sampling with exhaustive enumeration over all possible discrete values using tensor variable elimination for efficient computation.

Description

Standard Monte Carlo estimation of the ELBO samples discrete latent variables, which yields high-variance gradient estimates that slow convergence. The discrete marginal ELBO instead enumerates all possible values of discrete latent variables and sums their contributions analytically. This exact marginalization eliminates the variance from discrete sampling entirely, producing much lower variance gradient estimates.

The key insight is that discrete variables with finite support can be summed out exactly rather than sampled. For a model with a discrete latent variable d taking values in a finite set, instead of drawing a sample d ~ q(d), we compute:

\[ \mathrm{ELBO} = \sum_{d} q(d)\,\bigl[\log p(x, z, d) - \log q(d) - \log q(z)\bigr] \]

In Pyro, this is accomplished by marking discrete sample sites with infer={"enumerate": "parallel"}, which instructs the inference engine to create a tensor dimension for each enumerated variable. Parallel enumeration evaluates all values simultaneously along dedicated tensor dimensions, leveraging GPU parallelism. Sequential enumeration iterates over values one at a time, using less memory at the cost of speed.

The computation is made efficient through tensor variable elimination, which exploits the conditional independence structure (expressed via pyro.plate) to avoid computing the full joint over all discrete variables. Instead, it contracts tensors in an order that minimizes intermediate computation, analogous to the variable elimination algorithm in graphical models but operating on batched tensors.

Usage

Use this principle when your probabilistic model contains discrete latent variables with finite support (e.g., categorical, Bernoulli, or discrete mixture components). It is especially valuable for:

  • Mixture models with discrete cluster assignments
  • Hidden Markov models with discrete hidden states
  • Models with discrete switching variables

Discrete sites in the model that do not appear in the guide are automatically enumerated. Discrete sites in the guide must be explicitly marked with infer={"enumerate": "parallel"} or configured globally via config_enumerate. The max_plate_nesting parameter must be set correctly to inform the enumerator how many tensor dimensions are reserved for pyro.plate contexts.

Theoretical Basis

For a model with continuous latents z and discrete latents d:

\[ \mathrm{ELBO} = \mathbb{E}_{q(z)}\!\left[\sum_{d} q(d)\,\bigl(\log p(x, z, d) - \log q(z) - \log q(d)\bigr)\right] \]

The inner sum over d is computed exactly, while the outer expectation over z is still estimated via Monte Carlo. The tensor variable elimination algorithm computes the inner sum efficiently by:

  1. Building a factor graph from the log probability tensors
  2. Identifying an elimination ordering that minimizes the cost of tensor contractions
  3. Performing the contractions using einsum-like operations
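The contraction step above can be sketched with numpy (the factor shapes and values are illustrative, not Pyro internals): eliminating a shared variable via one einsum avoids ever materializing the full joint tensor.

```python
import numpy as np

# Two hypothetical factors over discrete variables a, b, c (each of size K):
# f1[a, b] and f2[b, c]. Variable elimination sums out b first.
K = 3
rng = np.random.default_rng(0)
f1 = rng.random((K, K))  # factor over (a, b)
f2 = rng.random((K, K))  # factor over (b, c)

# Eliminate b in a single einsum contraction: O(K^3) work,
# O(K^2) memory for the result.
marginal_ac = np.einsum("ab,bc->ac", f1, f2)

# Brute force for comparison: build the joint over (a, b, c),
# an O(K^3)-memory tensor, then sum out b.
joint = f1[:, :, None] * f2[None, :, :]
brute = joint.sum(axis=1)
```

With many discrete variables, a good elimination ordering keeps every intermediate tensor small, which is the point of step 2 above.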

The Dice estimator is used to correctly handle the interaction between enumerated discrete variables and the score function terms from non-reparameterizable continuous variables.
