Principle: Pyro_ppl_Pyro_Tensor_Variable_Elimination
Metadata
| Field | Value |
|---|---|
| Principle ID | Pyro_ppl_Pyro_Tensor_Variable_Elimination |
| Title | Tensor Variable Elimination |
| Project | Pyro (pyro-ppl/pyro) |
| Domains | Discrete_Inference, Tensor_Computation |
| Implementation | Pyro_ppl_Pyro_Contract_Tensor_Tree |
| Repository | https://github.com/pyro-ppl/pyro |
Summary
Tensor Variable Elimination (TVE) is the principle of performing exact marginalization of discrete latent variables via tensor contraction in plated factor graphs. It generalizes the classical sum-product algorithm (belief propagation) to handle plate structure -- the repeated, conditionally independent substructures that are ubiquitous in probabilistic models. TVE is the core algorithm that makes TraceEnum_ELBO efficient for models with discrete latent variables.
Motivation
Many probabilistic models contain both discrete latent variables and plate structure (e.g., a mixture model with N data points, each assigned to one of K clusters). Naively marginalizing discrete variables requires summing over all possible configurations, which is exponential in the number of variables. Classical variable elimination algorithms handle this by exploiting conditional independence to factor the sum into smaller local computations.
However, standard variable elimination does not understand plate structure -- the fact that many variables share the same local structure and only differ in their data. TVE extends variable elimination to recognize that:
- Variables inside plates can be represented as tensors (batches of variables)
- Contraction within a plate corresponds to product (not sum) operations
- Contraction across plates follows specific rules based on the nesting structure
This allows TVE to achieve polynomial complexity in problems where naive enumeration would be exponential.
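The complexity gap can be made concrete with a toy mixture in plain Python. The values of `log_w` and `log_like` below are made-up stand-ins for a model's log-factors; the point is only that the plate lets the sum over assignments distribute across data points.

```python
import itertools
import math

K, N = 2, 3
log_w = [math.log(0.3), math.log(0.7)]                   # log mixture weights
log_like = [[-0.5, -1.0, -2.0],
            [-1.5, -0.2, -0.7]]                          # log p(x_n | z_n = k)

def logsumexp(vals):
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))

# Naive marginalization: enumerate all K**N joint assignments (exponential).
naive = logsumexp([
    sum(log_w[z] + log_like[z][n] for n, z in enumerate(zs))
    for zs in itertools.product(range(K), repeat=N)
])

# Factored marginalization: the plate lets the sum distribute (O(N*K)).
factored = sum(
    logsumexp([log_w[k] + log_like[k][n] for k in range(K)]) for n in range(N)
)

assert abs(naive - factored) < 1e-9
```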
Core Concepts
Plated Factor Graphs
A plated factor graph is a graphical model where:
- Factors are tensors representing log-probability terms (log p(x) or log q(x))
- Plates are sets of conditionally independent repetitions (e.g., data points)
- Sum dimensions correspond to enumerated discrete variables to be marginalized
- Ordinals are frozensets of plate frames that identify the plate context of each factor
Each factor (tensor) has two types of dimensions:
- Plate dimensions (batch dimensions): index the conditionally independent repetitions, contracted via product
- Enumeration dimensions (sum dimensions): correspond to discrete variable values, contracted via sum
The Tensor Tree
TVE organizes factors into a tensor tree: an OrderedDict mapping ordinals (frozensets of plate frames) to lists of tensors. The tree structure arises from the nesting of plates:
- The root ordinal `frozenset()` represents factors outside all plates
- Leaf ordinals represent factors inside the deepest plate nesting
- The tree forms a hierarchy based on subset relationships between ordinals
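A minimal sketch of the tensor-tree container described above. Plain strings stand in for Pyro's plate frames, and shape-annotated strings stand in for the tensor factors; this is an illustration of the data structure, not Pyro's internal code.

```python
from collections import OrderedDict

# Hypothetical plate frame; Pyro uses CondIndepStackFrame objects here.
data_plate = "data"

# Ordinals (frozensets of plate frames) map to lists of log-probability factors.
tree = OrderedDict([
    (frozenset(),             ["log_weights (3,)"]),   # root: outside all plates
    (frozenset({data_plate}), ["log_probs (3, N)"]),   # leaf: inside the data plate
])

# The hierarchy follows strict-subset relations between ordinals.
root, leaf = tree.keys()
assert root < leaf  # frozenset() is a strict subset of {data_plate}
```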
Contraction Algorithm
The algorithm processes the tensor tree bottom-up:
- Partition all factors into connected components based on shared sum dimensions
- For each connected component, process ordinals from the leaves up to the root:
  - At each leaf ordinal, perform a sum-product contraction to eliminate local sum dimensions
  - Contract plate dimensions via product when moving from a child ordinal to a parent
  - Merge the resulting term into the parent ordinal
- At the root, perform final contraction to produce the result
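The partitioning step above can be sketched as a simple set-merging pass. The factor names and sum-dimension labels are illustrative stand-ins; Pyro's actual implementation differs.

```python
# Hypothetical factors mapped to the sum (enumeration) dims they mention.
factor_dims = {
    "f1": {"a"},        # f1 mentions sum dim a
    "f2": {"a", "b"},   # f2 links a and b, so f1 and f2 share a component
    "f3": {"c"},        # f3 shares no dim with the others: its own component
}

# Merge factors into connected components linked by shared sum dims.
components = []
for name, dims in factor_dims.items():
    merged = {"factors": {name}, "dims": set(dims)}
    for comp in [c for c in components if c["dims"] & dims]:
        merged["factors"] |= comp["factors"]
        merged["dims"] |= comp["dims"]
        components.remove(comp)
    components.append(merged)

assert sorted(sorted(c["factors"]) for c in components) == [["f1", "f2"], ["f3"]]
```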
Algebraic Rings
TVE is parameterized by an algebraic ring that defines the basic operations:
| Ring | Sum Operation | Product Operation | Use Case |
|---|---|---|---|
| LogRing | log-sum-exp | addition (in log space) | Standard ELBO computation |
| SampleRing | log-sum-exp + sampling | addition (in log space) | Posterior sampling (temperature=1) |
| MapRing | log-max | addition (in log space) | MAP/Viterbi decoding (temperature=0) |
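The ring operations reduce to a few lines in plain Python (illustrative stand-ins; Pyro's ring classes implement these over tensors):

```python
import math

# A toy log-space factor column along one enum dim (values are illustrative).
column = [-1.2, -0.5, -2.0]

# LogRing sum: log-sum-exp, i.e. exact marginalization in log space.
m = max(column)
log_ring_sum = m + math.log(sum(math.exp(v - m) for v in column))

# MapRing sum: log-max, i.e. keep only the best value (temperature=0).
map_ring_sum = max(column)

# In every ring, "product" is addition in log space.
product = column[0] + column[1]

# log-max always lower-bounds log-sum-exp over the same values.
assert map_ring_sum <= log_ring_sum
```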
Contraction Path Optimization
TVE uses opt_einsum for planning efficient contraction paths. The shared_intermediates context manager enables caching and reuse of intermediate results across multiple contraction operations, which is particularly beneficial when computing multiple outputs (e.g., ELBO loss + posterior samples).
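Pyro delegates path planning to opt_einsum; NumPy's built-in `einsum_path` exposes the same idea and is used below purely for illustration (the index string and shapes are made up):

```python
import numpy as np

a = np.random.rand(10, 20)
b = np.random.rand(20, 30)
c = np.random.rand(30, 5)

# Plan a pairwise contraction order that minimizes intermediate cost...
path, info = np.einsum_path("ij,jk,kl->il", a, b, c, optimize="greedy")

# ...then execute the contraction along that precomputed path.
result = np.einsum("ij,jk,kl->il", a, b, c, optimize=path)

assert result.shape == (10, 5)
assert np.allclose(result, a @ b @ c)
```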
How It Works
Consider a mixture model with K=3 components and N data points:
```
Tensor tree:
  ordinal = {}           -> [log_weights]  shape: (3,)
  ordinal = {data_plate} -> [log_probs]    shape: (3, N)

Sum dimensions: {assignment_dim}  (the discrete enumeration dim)
```
TVE contracts this as follows:
- The leaf ordinal `{data_plate}` holds `log_probs` with shape `(3, N)`
- The root term `log_weights` shares the assignment sum dimension, so it is broadcast into the leaf ordinal: `log_weights[:, None] + log_probs`
- Sum out the assignment dimension within the plate via log-sum-exp: result has shape `(N,)`
- Product-contract the plate dimension (addition in log space): the result is a scalar, the sum of per-datum log-marginals
This achieves O(N * K) complexity instead of O(K^N).
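The steps above can be checked numerically with NumPy (toy log-space values; Pyro performs the same contraction with its ring machinery):

```python
import numpy as np

rng = np.random.default_rng(0)
K, N = 3, 4
log_weights = rng.normal(size=(K,))      # root factor, shape (3,)
log_probs = rng.normal(size=(K, N))      # leaf factor, shape (3, N)

# Broadcast the root term into the leaf ordinal (shared assignment dim).
leaf = log_weights[:, None] + log_probs                 # (3, N)

# Log-sum-exp out the assignment dim within the plate.
m = leaf.max(axis=0)
per_datum = m + np.log(np.exp(leaf - m).sum(axis=0))    # (N,)

# Product-contract the plate dim: addition in log space, giving a scalar.
log_marginal = per_datum.sum()

assert per_datum.shape == (N,)
assert np.ndim(log_marginal) == 0
```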
Example
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

K = 5     # components
N = 1000  # data points

def model(data):
    weights = pyro.sample("weights", dist.Dirichlet(torch.ones(K)))
    with pyro.plate("components", K):
        locs = pyro.sample("locs", dist.Normal(0.0, 10.0))
    with pyro.plate("data", N):
        # This discrete variable will be enumerated and marginalized via TVE
        assignment = pyro.sample("assignment", dist.Categorical(weights))
        pyro.sample("obs", dist.Normal(locs[assignment], 1.0), obs=data)

@config_enumerate
def guide(data):
    # The guide's discrete assignment site is enumerated, not sampled
    weights_q = pyro.param("weights_q", torch.ones(K),
                           constraint=dist.constraints.positive)
    pyro.sample("weights", dist.Dirichlet(weights_q))
    with pyro.plate("components", K):
        loc_q = pyro.param("loc_q", torch.randn(K))
        pyro.sample("locs", dist.Normal(loc_q, 1.0))
    with pyro.plate("data", N):
        probs_q = pyro.param("probs_q", torch.ones(N, K) / K,
                             constraint=dist.constraints.simplex)
        pyro.sample("assignment", dist.Categorical(probs_q))

# TraceEnum_ELBO uses TVE internally to marginalize discrete variables
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=elbo)

# A few gradient steps on toy data (illustrative; loop to convergence in practice)
data = torch.randn(N)
for step in range(10):
    loss = svi.step(data)
```
Relationship to Other Principles
- Pyro_ppl_Pyro_Enumeration_Configuration -- TVE operates on the enumerated tensors produced by `config_enumerate` and `EnumMessenger`.
- Pyro_ppl_Pyro_Markov_Dependency -- For sequential models, Markov annotations allow TVE to contract incrementally along the sequence, avoiding the need to hold the entire sequence in memory.
- Pyro_ppl_Pyro_Discrete_Posterior_Decoding -- `infer_discrete` uses TVE with different rings (SampleRing, MapRing) for posterior inference.
References
- Obermeyer, F., Bingham, E., Jankowiak, M., Chiu, J., Pradhan, N., Rush, A., Goodman, N., "Tensor Variable Elimination for Plated Factor Graphs", ICML 2019. https://arxiv.org/abs/1902.03210
- Koller, D. & Friedman, N., Probabilistic Graphical Models: Principles and Techniques, 2009.