

Implementation:Pyro ppl Pyro Config Enumerate

From Leeroopedia


Metadata

Field              Value
Implementation ID  Pyro_ppl_Pyro_Config_Enumerate
Title              config_enumerate
Project            Pyro (pyro-ppl/pyro)
File               pyro/infer/enum.py, lines 138-220
Implements         Pyro_ppl_Pyro_Enumeration_Configuration
Repository         https://github.com/pyro-ppl/pyro

Summary

config_enumerate is a function (and decorator) that annotates all relevant sample sites in a Pyro guide with enumeration configuration. It is the primary user-facing API for enabling exhaustive enumeration of discrete latent variables during variational inference with TraceEnum_ELBO.

Signature

def config_enumerate(
    guide=None,
    default="parallel",
    expand=False,
    num_samples=None,
    tmc="diagonal"
)

Import

from pyro.infer import config_enumerate

Parameters

Parameter    Type              Default      Description
guide        callable or None  None         A Pyro guide (variational distribution) callable. If None, a decorator is returned.
default      str or None       "parallel"   Enumeration strategy: "parallel" (tensor-based, efficient), "sequential" (loop-based), or None (disable).
expand       bool              False        Whether to expand enumerated values to the full batch shape (see Distribution.enumerate_support). False saves memory by keeping singleton plate dimensions instead of broadcasting explicitly. Only applies to exhaustive enumeration (num_samples=None).
num_samples  int or None       None         If not None, use local Monte Carlo sampling with this many samples instead of exhaustive enumeration. Applies to both continuous and discrete distributions.
tmc          str or None       "diagonal"   Tensor Monte Carlo strategy: "diagonal" or "mixture". Only relevant when num_samples is not None.

Returns

Type      Description
callable  The annotated guide callable, wrapped with an infer_config poutine that sets enumeration metadata on each eligible sample site.

Usage Patterns

As a Decorator

@config_enumerate
def guide(data):
    # All discrete sites with has_enumerate_support will be
    # annotated with enumerate="parallel"
    assignment = pyro.sample("assignment", dist.Categorical(probs))
    ...

As a Decorator with Arguments

@config_enumerate(default="sequential", expand=True)
def guide(data):
    assignment = pyro.sample("assignment", dist.Categorical(probs))
    ...

As a Function

guide = config_enumerate(guide, default="parallel")

With Local Monte Carlo Sampling

@config_enumerate(default="parallel", num_samples=10)
def guide(data):
    # Every unobserved sample site, continuous ones included, is annotated
    # for local parallel Monte Carlo sampling with 10 samples
    z = pyro.sample("z", dist.Normal(loc, scale))
    ...

Internal Mechanism

The function delegates to poutine.infer_config with a configuration function built by _config_enumerate. The internal _config_fn (lines 114-131) inspects each sample site and:

  1. Skips non-sample sites, observed sites, and subsample sites
  2. If num_samples is set, annotates all unobserved sample sites with {"enumerate": default, "num_samples": num_samples, "expand": expand, "tmc": tmc}
  3. If num_samples is None, only annotates sites whose distribution has has_enumerate_support == True with {"enumerate": default, "expand": expand}
  4. Does not overwrite existing per-site infer={"enumerate": ...} annotations

Validation

The function performs input validation (lines 184-210):

  • default must be one of "sequential", "parallel", "flat", or None
  • expand must be True or False
  • num_samples must be None or a positive integer
  • "sequential" is not supported with num_samples (local Monte Carlo requires parallel evaluation)
  • When tmc="full" and num_samples > 1, expand is forced to True

Complete Example

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

# Gaussian Mixture Model
K = 3  # number of components

def model(data):
    weights = pyro.sample("weights", dist.Dirichlet(0.5 * torch.ones(K)))
    with pyro.plate("components", K):
        locs = pyro.sample("locs", dist.Normal(0.0, 10.0))
        scales = pyro.sample("scales", dist.LogNormal(0.0, 2.0))
    with pyro.plate("data", len(data)):
        assignment = pyro.sample("assignment", dist.Categorical(weights))
        pyro.sample("obs", dist.Normal(locs[assignment], scales[assignment]), obs=data)

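@config_enumerate
def guide(data):
    # Dirichlet concentration only needs to be positive (not on the simplex)
    weights_q = pyro.param("weights_q", torch.ones(K),
                           constraint=dist.constraints.positive)
    pyro.sample("weights", dist.Dirichlet(weights_q))
    with pyro.plate("components", K):
        # Separate variational parameters for locs and for scales
        loc_q = pyro.param("loc_q", torch.randn(K))
        loc_scale_q = pyro.param("loc_scale_q", torch.ones(K),
                                 constraint=dist.constraints.positive)
        pyro.sample("locs", dist.Normal(loc_q, loc_scale_q))
        log_scale_loc_q = pyro.param("log_scale_loc_q", torch.zeros(K))
        log_scale_scale_q = pyro.param("log_scale_scale_q", 0.1 * torch.ones(K),
                                       constraint=dist.constraints.positive)
        pyro.sample("scales", dist.LogNormal(log_scale_loc_q, log_scale_scale_q))
    with pyro.plate("data", len(data)):
        # Categorical has enumerate support, so config_enumerate marks this site
        # for parallel enumeration and TraceEnum_ELBO computes its expectation exactly
        assignment_probs = pyro.param("assignment_probs",
                                      torch.ones(len(data), K) / K,
                                      constraint=dist.constraints.simplex)
        pyro.sample("assignment", dist.Categorical(assignment_probs))

# Train with exact enumeration of discrete variables
pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, Adam({"lr": 0.005}), loss=elbo)

data = torch.randn(100)
for step in range(1000):
    loss = svi.step(data)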

Source Code Reference

The implementation is located in pyro/infer/enum.py at lines 138-220. The core logic is:

def config_enumerate(
    guide=None, default="parallel", expand=False, num_samples=None, tmc="diagonal"
):
    # ... validation ...

    # Support usage as a decorator:
    if guide is None:
        return lambda guide: config_enumerate(
            guide, default=default, expand=expand, num_samples=num_samples, tmc=tmc
        )

    return poutine.infer_config(
        guide, config_fn=_config_enumerate(default, expand, num_samples, tmc)
    )
