

Workflow: Pyro PPL Discrete Enumeration

From Leeroopedia


Knowledge Sources
Domains: Probabilistic_Programming, Discrete_Inference, Hidden_Markov_Models
Last Updated: 2026-02-09 09:00 GMT

Overview

End-to-end process for performing inference in models with discrete latent variables by exhaustively enumerating their support and analytically marginalizing them out using Pyro's TraceEnum_ELBO.

Description

This workflow describes the procedure for training models containing discrete latent variables, which standard reparameterization-based gradient estimators cannot handle. Rather than sampling discrete variables (which produces high-variance REINFORCE-style gradients), Pyro's enumeration machinery exactly marginalizes over all possible values of discrete variables by computing a weighted sum over their support. This is implemented through the EnumMessenger effect handler and the TraceEnum_ELBO loss function, which together perform tensor variable elimination, an efficient algorithm for computing marginals in plated factor graphs. The technique is essential for Hidden Markov Models, mixture models, topic models, and any model with categorical or discrete latent structure.

Usage

Execute this workflow when your model contains discrete latent variables (Categorical, Bernoulli, OneHotCategorical) whose support is small enough to enumerate. This is appropriate for Hidden Markov Models, Gaussian mixture models, topic models, semi-supervised classification, and any model where discrete variables can be analytically summed out. The enumeration approach produces zero-variance gradient estimates for the discrete components, dramatically improving training stability and speed compared to REINFORCE-style estimators.

Execution Steps

Step 1: Define Model with Enumerable Discrete Variables

Write a probabilistic model containing discrete latent variables. Each discrete sample site that should be enumerated must be annotated with infer={"enumerate": "parallel"} in its pyro.sample call. The "parallel" option creates a new tensor dimension for each enumerated value, allowing vectorized computation across all configurations. Use pyro.plate for batch dimensions and pyro.markov for sequential dependencies (enabling efficient message-passing in chain-structured models like HMMs).

Key considerations:

  • Annotate discrete sites with infer={"enumerate": "parallel"}
  • "parallel" enumeration adds tensor dimensions; "sequential" loops over values
  • Use pyro.markov() context manager for Markov chain dependencies
  • Plate dimensions must be specified with explicit dim= arguments to avoid conflicts with enumeration dimensions
  • Enumeration dimensions grow from the left (negative indices)

Step 2: Define the Guide for Continuous Variables

Write a guide that provides variational distributions only for the continuous latent variables in the model. Discrete variables marked for enumeration are handled automatically by TraceEnum_ELBO and must not appear in the guide. Use AutoGuide classes (AutoNormal, AutoDelta) or write a custom guide for the continuous parameters. For models with no continuous latent variables (e.g., fully-discrete HMMs with point-estimate parameters), use an empty guide.

Key considerations:

  • Do NOT include enumerated discrete variables in the guide
  • Guide should only cover continuous latent variables and global parameters
  • config_enumerate() decorator can be used to set enumeration globally
  • For semi-supervised models, labeled and unlabeled data may require different enumeration patterns

Step 3: Configure TraceEnum_ELBO

Select TraceEnum_ELBO (or JitTraceEnum_ELBO for compiled execution) as the loss function. This specialized ELBO estimator integrates with the enumeration messenger to compute exact marginal log-likelihoods for the discrete components. Configure max_plate_nesting to indicate the maximum depth of nested plate contexts, which is needed to correctly distinguish plate dimensions from enumeration dimensions.

Key considerations:

  • max_plate_nesting must equal or exceed the actual plate nesting depth
  • JitTraceEnum_ELBO provides speedup but requires stable tensor shapes
  • The strict_enumeration_warning flag helps debug enumeration mismatches
  • num_particles controls Monte Carlo samples for continuous variables

Step 4: Set Up Optimizer and SVI

Create a Pyro optimizer and SVI instance as in the standard SVI workflow. The optimizer updates parameters of both the model (e.g., transition probabilities in an HMM) and the guide (variational parameters for continuous latents). ClippedAdam is commonly used for stability when training discrete models.

Key considerations:

  • Learning rate may need tuning for discrete models
  • Gradient clipping helps prevent instability from sharp discrete boundaries
  • Clear the param store before training

Step 5: Run Training with Enumerated ELBO

Execute the SVI training loop by calling svi.step() on each data batch. Behind the scenes, TraceEnum_ELBO runs the model with EnumMessenger active, creating tensor dimensions for each discrete configuration, then contracts (sums) over these dimensions using the tensor variable elimination algorithm. The result is an exact marginal likelihood contribution from discrete variables combined with sampled estimates for continuous variables.

Key considerations:

  • Training is typically slower per step than Trace_ELBO due to enumeration overhead
  • Memory usage scales with the product of enumerated variable cardinalities
  • For HMMs, pyro.markov() keeps memory O(state_dim^2) per timestep rather than exponential
  • Monitor loss convergence; enumeration removes gradient variance from the discrete sites, so loss curves are typically smoother than with sampling-based estimators

Step 6: Perform Discrete Inference on Trained Model

After training, use infer_discrete to compute MAP or marginal assignments for the discrete latent variables. Pass the trained model and posterior parameters through infer_discrete with first_available_dim set appropriately. This uses the Viterbi algorithm (for MAP) or forward-backward algorithm (for marginals) to decode the most likely discrete variable assignments.

Key considerations:

  • infer_discrete(model, first_available_dim=..., temperature=0) gives MAP (Viterbi)
  • infer_discrete with temperature=1 gives marginal samples
  • first_available_dim must not overlap with plate or enumeration dimensions
  • Results are the decoded discrete assignments, recorded at the corresponding sample sites in the returned trace

GitHub URL

Workflow Repository