Implementation:Pyro ppl Pyro Config Enumerate
Metadata
| Field | Value |
|---|---|
| Implementation ID | Pyro_ppl_Pyro_Config_Enumerate |
| Title | config_enumerate |
| Project | Pyro (pyro-ppl/pyro) |
| File | pyro/infer/enum.py, lines 138-220 |
| Implements | Pyro_ppl_Pyro_Enumeration_Configuration |
| Repository | https://github.com/pyro-ppl/pyro |
Summary
`config_enumerate` is a function (and decorator) that annotates all relevant sample sites in a Pyro guide with enumeration configuration. It is the primary user-facing API for enabling exhaustive enumeration of discrete latent variables during variational inference with `TraceEnum_ELBO`.
Signature
```python
def config_enumerate(
    guide=None,
    default="parallel",
    expand=False,
    num_samples=None,
    tmc="diagonal",
)
```
Import
```python
from pyro.infer import config_enumerate
```
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `guide` | callable or None | None | A Pyro guide (variational distribution) callable. If None, returns a decorator. |
| `default` | str or None | `"parallel"` | Enumeration strategy: `"parallel"` (tensor-based, efficient), `"sequential"` (loop-based), or None (disable). |
| `expand` | bool | False | Whether to expand enumerated sample values. When False, saves memory by avoiding explicit broadcasting across plates. |
| `num_samples` | int or None | None | If not None, use local Monte Carlo sampling with this many samples instead of exhaustive enumeration. Applies to both continuous and discrete distributions. |
| `tmc` | str or None | `"diagonal"` | Tensor Monte Carlo strategy: `"diagonal"` or `"mixture"`. Only relevant when `num_samples` is not None. |
Returns
| Type | Description |
|---|---|
| callable | The annotated guide callable, wrapped with an `infer_config` poutine that sets enumeration metadata on each eligible sample site. |
Usage Patterns
As a Decorator
```python
@config_enumerate
def guide(data):
    # All discrete sites with has_enumerate_support will be
    # annotated with enumerate="parallel"
    assignment = pyro.sample("assignment", dist.Categorical(probs))
    ...
```
As a Decorator with Arguments
```python
@config_enumerate(default="sequential", expand=True)
def guide(data):
    assignment = pyro.sample("assignment", dist.Categorical(probs))
    ...
```
As a Function
```python
guide = config_enumerate(guide, default="parallel")
```
With Local Monte Carlo Sampling
```python
@config_enumerate(default="parallel", num_samples=10)
def guide(data):
    # All sites (including continuous) will be locally sampled
    z = pyro.sample("z", dist.Normal(loc, scale))
    ...
```
Internal Mechanism
The function delegates to `poutine.infer_config` with a configuration function built by `_config_enumerate`. The internal `_config_fn` (lines 114-131) inspects each sample site and:
- Skips observed sites and subsample sites
- If `num_samples` is set, annotates all unobserved sample sites with `{"enumerate": default, "num_samples": num_samples, "expand": expand, "tmc": tmc}`
- If `num_samples` is None, only annotates sites whose distribution has `has_enumerate_support == True` with `{"enumerate": default, "expand": expand}`
- Does not overwrite existing per-site `infer={"enumerate": ...}` annotations
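The per-site logic above can be sketched in plain Python without a Pyro dependency. The site dicts below are simplified stand-ins for Pyro's message dicts, and `fn_has_enumerate_support` is a hypothetical field standing in for the distribution's `has_enumerate_support` attribute:

```python
def make_config_fn(default="parallel", expand=False, num_samples=None, tmc="diagonal"):
    """Sketch of _config_fn's per-site decision logic (simplified site dicts)."""
    def config_fn(site):
        # Skip observed sites, subsample sites, and sites already annotated.
        if site.get("is_observed") or site.get("subsample"):
            return {}
        if "enumerate" in site.get("infer", {}):
            return {}  # never overwrite an explicit per-site annotation
        if num_samples is not None:
            # Local Monte Carlo: annotate every unobserved sample site.
            return {"enumerate": default, "num_samples": num_samples,
                    "expand": expand, "tmc": tmc}
        if site["fn_has_enumerate_support"]:
            # Exhaustive enumeration: only discrete sites qualify.
            return {"enumerate": default, "expand": expand}
        return {}
    return config_fn

cfg = make_config_fn()
discrete = {"is_observed": False, "infer": {}, "fn_has_enumerate_support": True}
continuous = {"is_observed": False, "infer": {}, "fn_has_enumerate_support": False}
print(cfg(discrete))    # discrete site gets an enumeration annotation
print(cfg(continuous))  # continuous site is left alone when num_samples is None
```

Note how the `num_samples` branch annotates both site kinds, which matches the behavior described in the Parameters table.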
Validation
The function performs input validation (lines 184-210):
- `default` must be one of `"sequential"`, `"parallel"`, `"flat"`, or `None`
- `expand` must be `True` or `False`
- `num_samples` must be `None` or a positive integer
- `"sequential"` is not supported with `num_samples` (local Monte Carlo requires parallel evaluation)
- When `tmc="full"` and `num_samples > 1`, `expand` is forced to `True`
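These checks can be sketched as a standalone function. This is a paraphrase of the rules listed above, not the library's exact code, and the error messages are illustrative:

```python
import numbers

def validate_config(default="parallel", expand=False, num_samples=None, tmc="diagonal"):
    """Sketch of config_enumerate's argument validation (simplified)."""
    if default not in ("sequential", "parallel", "flat", None):
        raise ValueError(f"Invalid default value: {default}")
    if expand not in (True, False):
        raise ValueError(f"Invalid expand value: {expand}")
    if num_samples is not None:
        if not (isinstance(num_samples, numbers.Number) and num_samples > 0):
            raise ValueError(f"Invalid num_samples: {num_samples}")
        if default == "sequential":
            raise ValueError("Local sampling does not support sequential enumeration")
    if tmc == "full" and num_samples is not None and num_samples > 1:
        expand = True  # forced so each TMC slice carries a distinct sample
    return default, expand, num_samples, tmc
```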
Complete Example
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

# Gaussian Mixture Model
K = 3  # number of components

def model(data):
    weights = pyro.sample("weights", dist.Dirichlet(0.5 * torch.ones(K)))
    with pyro.plate("components", K):
        locs = pyro.sample("locs", dist.Normal(0.0, 10.0))
        scales = pyro.sample("scales", dist.LogNormal(0.0, 2.0))
    with pyro.plate("data", len(data)):
        assignment = pyro.sample("assignment", dist.Categorical(weights))
        pyro.sample("obs", dist.Normal(locs[assignment], scales[assignment]), obs=data)

@config_enumerate
def guide(data):
    weights_q = pyro.param("weights_q", torch.ones(K),
                           constraint=dist.constraints.simplex)
    pyro.sample("weights", dist.Dirichlet(weights_q))
    with pyro.plate("components", K):
        loc_q = pyro.param("loc_q", torch.randn(K))
        scale_q = pyro.param("scale_q", torch.ones(K),
                             constraint=dist.constraints.positive)
        pyro.sample("locs", dist.Normal(loc_q, scale_q))
        pyro.sample("scales", dist.LogNormal(loc_q, scale_q))
    with pyro.plate("data", len(data)):
        assignment_probs = pyro.param("assignment_probs",
                                      torch.ones(len(data), K) / K,
                                      constraint=dist.constraints.simplex)
        pyro.sample("assignment", dist.Categorical(assignment_probs))

# Train with exact enumeration of discrete variables
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, Adam({"lr": 0.005}), loss=elbo)

data = torch.randn(100)
for step in range(1000):
    loss = svi.step(data)
```
Source Code Reference
The implementation is located in pyro/infer/enum.py at lines 138-220. The core logic is:
```python
def config_enumerate(
    guide=None, default="parallel", expand=False, num_samples=None, tmc="diagonal"
):
    # ... validation ...
    # Support usage as a decorator:
    if guide is None:
        return lambda guide: config_enumerate(
            guide, default=default, expand=expand, num_samples=num_samples, tmc=tmc
        )
    return poutine.infer_config(
        guide, config_fn=_config_enumerate(default, expand, num_samples, tmc)
    )
```
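The `guide is None` branch is the standard pattern for a decorator that works both bare and with arguments. A self-contained sketch of that pattern, using generic names (`annotate`, `tag`) rather than Pyro's API:

```python
import functools

def annotate(fn=None, tag="parallel"):
    """Decorator usable bare (@annotate) or with arguments (@annotate(tag=...))."""
    if fn is None:
        # Called with arguments only: return a decorator that re-enters with fn.
        return lambda fn: annotate(fn, tag=tag)

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapped.tag = tag  # stand-in for the metadata infer_config attaches
    return wrapped

@annotate            # bare form: fn is passed directly
def f():
    return 1

@annotate(tag="sequential")  # argument form: fn is None on the first call
def g():
    return 2
```

This is why both `@config_enumerate` and `@config_enumerate(default="sequential", expand=True)` work in the Usage Patterns above.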
Related Pages
Implements Principle
Related Implementations
- Pyro_ppl_Pyro_Contract_Tensor_Tree -- The tensor contraction algorithm that processes the enumerated tensor factors produced by sites configured via config_enumerate.
- Pyro_ppl_Pyro_Infer_Discrete -- Uses enumeration configuration to decode the discrete posterior after training.
- Pyro_ppl_Pyro_Pyro_Markov -- Declares Markov structure to make enumeration memory-efficient in sequential models.
- Heuristic:Pyro_ppl_Pyro_Enumeration_Plate_Nesting