Principle: Enumeration Configuration (pyro-ppl/pyro)
Metadata
| Field | Value |
|---|---|
| Principle ID | Pyro_ppl_Pyro_Enumeration_Configuration |
| Title | Enumeration Configuration |
| Project | Pyro (pyro-ppl/pyro) |
| Domains | Discrete_Inference, Variational_Inference |
| Implementation | Pyro_ppl_Pyro_Config_Enumerate |
| Repository | https://github.com/pyro-ppl/pyro |
Summary
Enumeration Configuration is the principle of configuring discrete sample sites for exhaustive enumeration during variational inference in Pyro. By annotating guide sites with an enumeration strategy, Pyro can automatically enumerate over all possible values of discrete latent variables, enabling exact marginalization rather than relying on noisy Monte Carlo estimates.
Motivation
In probabilistic models that contain discrete latent variables (e.g., mixture assignments, hidden Markov states, discrete switches), standard gradient-based variational inference faces a fundamental challenge: the REINFORCE gradient estimator for discrete variables suffers from high variance, leading to slow and unstable training. Enumeration configuration addresses this by instructing the inference engine to sum over all possible values of discrete variables rather than sampling from them, thereby computing exact expected values and eliminating sampling variance entirely for those sites.
Core Concepts
Enumeration Strategies
Pyro supports two enumeration strategies for discrete variables:
- "parallel" (default): All possible values of a discrete variable are evaluated simultaneously by expanding the tensor along a new dimension. This is computationally efficient because it leverages PyTorch's tensor parallelism. Each enumerated value occupies a distinct position along a dedicated enumeration dimension (counting from the left of the tensor).
- "sequential": Each possible value of the discrete variable is evaluated one at a time in a loop. This uses less memory than parallel enumeration but is slower. Sequential enumeration is useful when the support of a distribution is very large and parallel expansion would exhaust GPU memory.
Eligible Sample Sites
Only distributions with `has_enumerate_support = True` can be exhaustively enumerated. This includes:
- `dist.Categorical` -- discrete choices over K categories
- `dist.Bernoulli` -- binary 0/1 choices
- `dist.OneHotCategorical` -- one-hot encoded categorical choices
- `dist.Binomial` (with small `total_count`)
Continuous distributions cannot be exhaustively enumerated, though they can be locally approximated via Monte Carlo sampling with the num_samples parameter.
Integration with TraceEnum_ELBO
Enumeration configuration is designed to work with TraceEnum_ELBO, which is the ELBO loss function that performs exact marginalization over enumerated discrete variables. The workflow is:
- Annotate the guide with enumeration configuration using `config_enumerate`
- Use `TraceEnum_ELBO` as the loss function in SVI
- The ELBO computation sums over all discrete configurations, weighted by their probabilities
This combination transforms the variational inference problem for discrete variables from a stochastic optimization into a deterministic computation (for the discrete part), dramatically improving convergence.
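A toy calculation (plain PyTorch, not the Pyro API) illustrates what this buys: summing over the support yields the exact expectation deterministically, while Monte Carlo only approximates it with sampling noise.

```python
import torch

# q: guide probabilities over a 3-valued discrete latent; f: an arbitrary
# integrand evaluated at each support value (stand-in for a log-joint term).
q = torch.tensor([0.2, 0.3, 0.5])
f = torch.tensor([1.0, -2.0, 4.0])

# Enumeration: exact, deterministic expectation sum_k q_k * f(k).
exact = (q * f).sum()
print(exact.item())  # 1.6

# Monte Carlo: unbiased but noisy estimate of the same quantity.
z = torch.distributions.Categorical(q).sample((1000,))
mc = f[z].mean()
```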
Expand vs. Non-Expand Mode
When expand=False (default), enumerated values are not expanded across plates, saving memory. When expand=True, enumerated values are explicitly broadcast, which is required in certain advanced use cases involving dependent discrete variables across plates.
How It Works
The enumeration configuration mechanism operates at the level of Pyro's effect handling system (poutines):
- The `config_enumerate` function wraps a guide with an `infer_config` poutine
- This poutine inspects each sample site and, if the site's distribution has `has_enumerate_support == True`, annotates the site's `infer` dict with `{"enumerate": "parallel"}` (or `"sequential"`)
- During ELBO computation, the `EnumMessenger` intercepts these annotated sites and replaces the sampled value with the full support of the distribution
- The resulting tensors carry extra enumeration dimensions, which `TraceEnum_ELBO` contracts out using the tensor variable elimination algorithm
Example
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

def model(data):
    # Mixture model with K=3 components
    weights = pyro.sample("weights", dist.Dirichlet(torch.ones(3)))
    locs = pyro.param("locs", torch.randn(3))  # component locations
    with pyro.plate("data", len(data)):
        # Discrete assignment (to be enumerated)
        assignment = pyro.sample("assignment", dist.Categorical(weights))
        pyro.sample("obs", dist.Normal(locs[assignment], 0.1), obs=data)
# Option 1: Use as a decorator
@config_enumerate
def guide(data):
    weights_q = pyro.param("weights_q", torch.ones(3),
                           constraint=dist.constraints.positive)  # Dirichlet concentration must be positive
    pyro.sample("weights", dist.Dirichlet(weights_q))
    with pyro.plate("data", len(data)):
        probs = pyro.param("probs", torch.ones(len(data), 3) / 3,
                           constraint=dist.constraints.simplex)
        pyro.sample("assignment", dist.Categorical(probs))
# Option 2: Use as a function
# guide = config_enumerate(guide, default="parallel")
# Use TraceEnum_ELBO for exact discrete marginalization
svi = SVI(model, guide, Adam({"lr": 0.01}),
loss=TraceEnum_ELBO(max_plate_nesting=1))
Relationship to Other Principles
- Pyro_ppl_Pyro_Tensor_Variable_Elimination -- The underlying algorithm that contracts enumerated tensor factors. Enumeration configuration creates the annotated tensors; TVE performs the efficient contraction.
- Pyro_ppl_Pyro_Markov_Dependency -- For sequential models like HMMs, Markov dependency declarations enable memory-efficient enumeration by limiting the history window.
- Pyro_ppl_Pyro_Discrete_Posterior_Decoding -- After training with enumerated inference, `infer_discrete` uses the same enumeration mechanism to decode the posterior of discrete latent variables.
References
- Pyro enumeration tutorial: https://pyro.ai/examples/enumeration.html
- Obermeyer et al., "Tensor Variable Elimination for Plated Factor Graphs", 2019 (https://arxiv.org/abs/1902.03210)