Heuristic:Pyro ppl Pyro Enumeration Plate Nesting
| Knowledge Sources | |
|---|---|
| Domains | Discrete_Inference, Debugging |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Critical configuration tips for `max_plate_nesting` when using discrete enumeration, plus constraints on how enumeration interacts with plates and `poutine.scale`.
Description
When `TraceEnum_ELBO` performs parallel discrete enumeration, Pyro allocates tensor dimensions for enumerated variables to the left of the plate dimensions. The `max_plate_nesting` parameter tells Pyro how many rightmost dimensions are reserved for plates, so that enumeration dimensions can be placed correctly. Getting this wrong can cause shape errors or a silently incorrect ELBO.
Usage
Apply this heuristic when using `TraceEnum_ELBO` with parallel enumeration, debugging shape mismatches in enumerated models, or configuring an ELBO for models with nested plates. It is essential for any model that combines discrete latent variables with `pyro.plate`.
The Insight (Rule of Thumb)
- Action: Set `max_plate_nesting` to the maximum depth of nested `pyro.plate` contexts in your model when using parallel enumeration.
- Value: Default is `float("inf")`, which triggers auto-guessing. Auto-guessing works for static models but may fail for models with dynamic structure.
- Trade-off: Auto-guessing runs the model and guide once to detect plate depth, adding overhead; setting the value explicitly avoids this but requires knowing the model structure.
- Constraint 1: Only scalar `poutine.scale` values are compatible with enumeration; tensor-valued scales raise `ValueError`.
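- Constraint 2: Model enumeration must be "no more global" than guide enumeration; that is, a model-enumerated site must not sit outside a plate that contains a guide-enumerated site.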
- Tip: If using `vectorize_particles=True` with `num_particles > 1`, `max_plate_nesting` is automatically incremented by 1.
Reasoning
Pyro's enumeration strategy uses negative tensor dimensions (counting from the right) for plate dimensions and places enumeration dimensions further to the left. If `max_plate_nesting` is too small, enumeration dimensions overlap with plate dimensions, causing incorrect computation. If it's too large, unnecessary broadcasting overhead is introduced.
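The layout described above can be sketched in plain Python. The helper names `enum_dims` and `overlaps_plates` are hypothetical; the only Pyro behavior assumed is that the first enumeration dimension is allocated at `-1 - max_plate_nesting` and later ones further to the left.

```python
def enum_dims(max_plate_nesting, num_enumerated):
    """Dims assigned to enumerated sites, assuming allocation starts at
    -1 - max_plate_nesting and proceeds leftward."""
    first = -1 - max_plate_nesting
    return [first - i for i in range(num_enumerated)]


def overlaps_plates(max_plate_nesting, plate_dims, num_enumerated):
    """True if any enumeration dim lands on a dim already used by a plate."""
    used = set(enum_dims(max_plate_nesting, num_enumerated))
    return bool(used & set(plate_dims))


# Two nested plates occupy dims -1 and -2.
plate_dims = [-1, -2]

print(enum_dims(2, 2))                    # [-3, -4]
print(overlaps_plates(2, plate_dims, 2))  # False: correct setting
print(overlaps_plates(1, plate_dims, 2))  # True: dim -2 is claimed twice
print(overlaps_plates(4, plate_dims, 2))  # False, but dims -3 and -4 become unused padding
```

The last case is the "too large" setting: nothing collides, but every enumerated tensor carries two extra size-1 dimensions.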
The auto-guessing mechanism runs the model/guide pair once without enumeration to detect the plate structure. This is optimistic and assumes static model structure. For models with data-dependent plate sizes or conditional plates, explicit specification is safer.
Code evidence for auto-guessing from `pyro/infer/elbo.py:146-186`:
```python
def _guess_max_plate_nesting(self, model, guide, args, kwargs):
    """
    Guesses max_plate_nesting by running the (model,guide) pair once
    without enumeration. This optimistically assumes static model
    structure.
    """
    with poutine.block():
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(model, trace=guide_trace)
        ).get_trace(*args, **kwargs)
    # ...
    dims = [
        frame.dim
        for site in sites
        for frame in site["cond_indep_stack"]
        if frame.vectorized
    ]
    self.max_plate_nesting = -min(dims) if dims else 0
```
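To make the final step of the excerpt concrete, here is a standalone sketch of the same computation on hand-built plate stacks. `Frame` and `guess_nesting` are hypothetical stand-ins for Pyro's internal frame objects and the method above; only the `dim` and `vectorized` fields are modeled.

```python
from collections import namedtuple

# Hypothetical stand-in for the frames found in site["cond_indep_stack"].
Frame = namedtuple("Frame", ["name", "dim", "vectorized"])


def guess_nesting(sites):
    """Mirror of the excerpt's last step: deepest vectorized plate dim, negated."""
    dims = [
        frame.dim
        for site in sites
        for frame in site["cond_indep_stack"]
        if frame.vectorized
    ]
    return -min(dims) if dims else 0


outer = Frame("outer", -1, True)
inner = Frame("inner", -2, True)

# A run in which a conditional branch enters both plates.
deep_run = [{"cond_indep_stack": (outer,)}, {"cond_indep_stack": (outer, inner)}]
# A run in which the branch containing the inner plate is skipped.
shallow_run = [{"cond_indep_stack": (outer,)}]

print(guess_nesting(deep_run))     # 2
print(guess_nesting(shallow_run))  # 1
print(guess_nesting([]))           # 0
```

Because the excerpt stores the result on `self.max_plate_nesting`, a guess of 1 taken from a shallow first run would be too small for a later run that enters the inner plate. This is the case where an explicit value is the safer choice.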
Vectorize particles increment from `pyro/infer/elbo.py:132-133`:
```python
if self.vectorize_particles and self.num_particles > 1:
    self.max_plate_nesting += 1
```
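The arithmetic can be sketched in plain Python. `effective_nesting` is a hypothetical helper that mirrors the two lines above; the comment about where the particle dimension sits is an assumption about how Pyro uses the extra slot, not something shown in the excerpt.

```python
def effective_nesting(max_plate_nesting, num_particles=1, vectorize_particles=False):
    """Mirror of the excerpt: reserve one extra plate dim for vectorized particles."""
    if vectorize_particles and num_particles > 1:
        max_plate_nesting += 1
    return max_plate_nesting


# Model with two nested plates, so the user passes max_plate_nesting=2.
print(effective_nesting(2))                                             # 2
print(effective_nesting(2, num_particles=1, vectorize_particles=True))  # 2 (single particle)
print(effective_nesting(2, num_particles=8, vectorize_particles=True))  # 3

# Assumption: the extra slot (dim -3 here) holds the particle dimension,
# so enumeration dims would start one position further left, at -4.
first_enum_dim = -1 - effective_nesting(2, num_particles=8, vectorize_particles=True)
print(first_enum_dim)  # -4
```

The practical consequence is that the value passed by the user should count only the model's own plates; the particle dimension is accounted for automatically.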
Default value from `pyro/infer/elbo.py:113`:
```python
max_plate_nesting=float("inf"),
```