Principle:Pyro ppl Pyro Enumeration Configuration

From Leeroopedia


Metadata

Principle ID: Pyro_ppl_Pyro_Enumeration_Configuration
Title: Enumeration Configuration
Project: Pyro (pyro-ppl/pyro)
Domains: Discrete_Inference, Variational_Inference
Implementation: Pyro_ppl_Pyro_Config_Enumerate
Repository: https://github.com/pyro-ppl/pyro

Summary

Enumeration Configuration is the principle of configuring discrete sample sites for exhaustive enumeration during variational inference in Pyro. By annotating guide sites with an enumeration strategy, Pyro can automatically enumerate over all possible values of discrete latent variables, enabling exact marginalization rather than relying on noisy Monte Carlo estimates.

Motivation

In probabilistic models that contain discrete latent variables (e.g., mixture assignments, hidden Markov states, discrete switches), standard gradient-based variational inference faces a fundamental challenge: the REINFORCE gradient estimator for discrete variables suffers from high variance, leading to slow and unstable training. Enumeration configuration addresses this by instructing the inference engine to sum over all possible values of discrete variables rather than sampling from them, thereby computing exact expected values and eliminating sampling variance entirely for those sites.

Core Concepts

Enumeration Strategies

Pyro supports two enumeration strategies for discrete variables:

  • "parallel" (default): All possible values of a discrete variable are evaluated simultaneously by expanding the tensor along a new dimension. This is computationally efficient because it leverages PyTorch's tensor parallelism. Each enumerated value occupies a distinct position along a dedicated enumeration dimension, which is allocated to the left of the plate dimensions. Dimensions are counted from the right, which is why TraceEnum_ELBO needs max_plate_nesting to know where enumeration dimensions may start.
  • "sequential": Each possible value of the discrete variable is evaluated one at a time in a loop. This uses less memory than parallel enumeration but is slower. Sequential enumeration is useful when the support of a distribution is very large and parallel expansion would exhaust GPU memory.

Eligible Sample Sites

Only distributions that have has_enumerate_support = True can be exhaustively enumerated. This includes:

  • dist.Categorical -- discrete choices over K categories
  • dist.Bernoulli -- binary 0/1 choices
  • dist.OneHotCategorical -- one-hot encoded categorical
  • dist.Binomial (with small total_count)

Continuous distributions cannot be exhaustively enumerated. Setting the num_samples parameter of config_enumerate instead replaces exhaustive enumeration with local Monte Carlo sampling, which applies to continuous and discrete sites alike.

Integration with TraceEnum_ELBO

Enumeration configuration is designed to work with TraceEnum_ELBO, which is the ELBO loss function that performs exact marginalization over enumerated discrete variables. The workflow is:

  1. Annotate the guide with enumeration configuration using config_enumerate
  2. Use TraceEnum_ELBO as the loss function in SVI
  3. The ELBO computation sums over all discrete configurations, weighted by their probabilities

This combination makes the discrete part of the ELBO a deterministic computation rather than a stochastic estimate, which typically improves convergence.

Expand vs. Non-Expand Mode

When expand=False (default), enumerated values are not expanded across plates, saving memory. When expand=True, enumerated values are explicitly broadcast, which is required in certain advanced use cases involving dependent discrete variables across plates.

How It Works

The enumeration configuration mechanism operates at the level of Pyro's effect handling system (poutines):

  1. The config_enumerate function wraps a guide with an infer_config poutine
  2. This poutine inspects each unobserved sample site and, if the site's distribution has has_enumerate_support == True, updates the site's infer dict with {"enumerate": "parallel"} (or "sequential") together with the chosen expand setting
  3. During ELBO computation, the EnumMessenger intercepts these annotated sites and replaces the sampled value with the full support of the distribution
  4. The resulting tensors carry extra enumeration dimensions, which TraceEnum_ELBO contracts out using the tensor variable elimination algorithm

Example

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

def model(data):
    # Mixture model with K=3 components
    weights = pyro.sample("weights", dist.Dirichlet(torch.ones(3)))
    # Component means as learnable parameters (illustrative initial values)
    locs = pyro.param("locs", torch.tensor([-1.0, 0.0, 1.0]))
    with pyro.plate("data", len(data)):
        # Discrete assignment (to be enumerated)
        assignment = pyro.sample("assignment", dist.Categorical(weights))
        pyro.sample("obs", dist.Normal(locs[assignment], 0.1), obs=data)

# Option 1: Use as a decorator
@config_enumerate
def guide(data):
    # Dirichlet concentration must be positive (it need not lie on the simplex)
    weights_q = pyro.param("weights_q", torch.ones(3),
                           constraint=constraints.positive)
    pyro.sample("weights", dist.Dirichlet(weights_q))
    with pyro.plate("data", len(data)):
        # Per-datapoint assignment probabilities, one simplex row per data point
        probs = pyro.param("probs", torch.ones(len(data), 3) / 3,
                           constraint=constraints.simplex)
        pyro.sample("assignment", dist.Categorical(probs))

# Option 2: Use as a function
# guide = config_enumerate(guide, default="parallel")

# Use TraceEnum_ELBO for exact discrete marginalization;
# max_plate_nesting=1 reserves dim -1 for the "data" plate
svi = SVI(model, guide, Adam({"lr": 0.01}),
          loss=TraceEnum_ELBO(max_plate_nesting=1))

Relationship to Other Principles

  • Pyro_ppl_Pyro_Tensor_Variable_Elimination -- The underlying algorithm that contracts enumerated tensor factors. Enumeration configuration creates the annotated tensors; TVE performs the efficient contraction.
  • Pyro_ppl_Pyro_Markov_Dependency -- For sequential models like HMMs, Markov dependency declarations enable memory-efficient enumeration by limiting the history window.
  • Pyro_ppl_Pyro_Discrete_Posterior_Decoding -- After training with enumerated inference, infer_discrete uses the same enumeration mechanism to decode the posterior of discrete latent variables.
