
Principle:Pyro ppl Pyro Tensor Variable Elimination

From Leeroopedia


Metadata

Field           Value
Principle ID    Pyro_ppl_Pyro_Tensor_Variable_Elimination
Title           Tensor Variable Elimination
Project         Pyro (pyro-ppl/pyro)
Domains         Discrete_Inference, Tensor_Computation
Implementation  Pyro_ppl_Pyro_Contract_Tensor_Tree
Repository      https://github.com/pyro-ppl/pyro

Summary

Tensor Variable Elimination (TVE) is the principle of performing exact marginalization of discrete latent variables via tensor contraction in plated factor graphs. It generalizes the classical sum-product algorithm (belief propagation) to handle plate structure -- the repeated, conditionally independent substructures that are ubiquitous in probabilistic models. TVE is the core algorithm that makes TraceEnum_ELBO efficient for models with discrete latent variables.

Motivation

Many probabilistic models contain both discrete latent variables and plate structure (e.g., a mixture model with N data points, each assigned to one of K clusters). Naively marginalizing discrete variables requires summing over all possible configurations, which is exponential in the number of variables. Classical variable elimination algorithms handle this by exploiting conditional independence to factor the sum into smaller local computations.

However, standard variable elimination does not understand plate structure -- the fact that many variables share the same local structure and only differ in their data. TVE extends variable elimination to recognize that:

  • Variables inside plates can be represented as tensors (batches of variables)
  • Contraction within a plate corresponds to product (not sum) operations
  • Contraction across plates follows specific rules based on the nesting structure

This allows TVE to achieve polynomial complexity in problems where naive enumeration would be exponential.

Core Concepts

Plated Factor Graphs

A plated factor graph is a graphical model where:

  • Factors are tensors representing log-probability terms (log p(x) or log q(x))
  • Plates are sets of conditionally independent repetitions (e.g., data points)
  • Sum dimensions correspond to enumerated discrete variables to be marginalized
  • Ordinals are frozensets of plate frames that identify the plate context of each factor

Each factor (tensor) has two types of dimensions:

  • Plate dimensions (batch dimensions): shared across repetitions, contracted via product
  • Enumeration dimensions (sum dimensions): correspond to discrete variable values, contracted via sum
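The two dimension types contract differently. A minimal PyTorch sketch, with shapes chosen purely for illustration:

```python
import torch

K, N = 3, 4  # enumeration size (sum dim) and plate size (batch dim)

# A single factor inside a plate: dim 0 is an enumeration dimension
# (size K), dim 1 is a plate dimension (size N). Hypothetical values.
log_factor = torch.randn(K, N)

# Enumeration dims are contracted via sum -- log-sum-exp in log space:
summed = torch.logsumexp(log_factor, dim=0)  # shape (N,)

# Plate dims are contracted via product -- plain addition in log space:
contracted = summed.sum(dim=0)               # 0-dim scalar tensor
```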

The Tensor Tree

TVE organizes factors into a tensor tree: an OrderedDict mapping ordinals (frozensets of plate frames) to lists of tensors. The tree structure arises from the nesting of plates:

  • The root ordinal frozenset() represents factors outside all plates
  • Leaf ordinals represent factors inside the deepest plate nesting
  • The tree forms a hierarchy based on subset relationships between ordinals
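A toy tensor tree can be sketched directly. Plate frames are shown as plain strings here for simplicity; this is an illustration of the data structure, not Pyro's internal representation:

```python
from collections import OrderedDict

import torch

data_plate = "data"  # stand-in for a plate frame

# Ordinals (frozensets of plate frames) map to lists of factor tensors.
tensor_tree = OrderedDict([
    (frozenset(), [torch.randn(3)]),               # root: outside all plates
    (frozenset({data_plate}), [torch.randn(3, 10)]),  # inside the data plate
])

# The hierarchy follows subset relations between ordinals:
is_ancestor = frozenset() <= frozenset({data_plate})  # True
```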

Contraction Algorithm

The algorithm processes the tensor tree bottom-up:

  1. Partition all factors into connected components based on shared sum dimensions
  2. For each connected component, process leaf to root:
    1. At each leaf ordinal, perform a sumproduct contraction to eliminate local sum dimensions
    2. Contract plate dimensions via product when moving from a child ordinal to a parent
    3. Merge the resulting term into the parent ordinal
  3. At the root, perform final contraction to produce the result
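For a two-level tree, the bottom-up pass reduces to a few tensor operations. A sketch under toy values, not Pyro's internal API:

```python
import torch

torch.manual_seed(0)
K, N = 3, 10

# Root factor (outside all plates) and leaf factor (inside one plate);
# both carry the same enumerated sum dimension of size K.
log_weights = torch.randn(K).log_softmax(dim=0)  # shape (K,)
log_probs = torch.randn(K, N)                    # shape (K, N)

# Step 2a: the root factor shares the sum dimension with the leaf, so it
# is broadcast into the leaf ordinal before the sumproduct contraction.
per_datum = torch.logsumexp(log_weights[:, None] + log_probs, dim=0)  # (N,)

# Step 2b: moving from the leaf ordinal toward the root, the plate
# dimension is contracted via product, i.e. addition in log space.
log_marginal = per_datum.sum()  # scalar: the log-marginal likelihood
```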

Algebraic Rings

TVE is parameterized by an algebraic ring that defines the basic operations:

Ring        Sum operation           Product operation        Use case
LogRing     log-sum-exp             addition (in log space)  standard ELBO computation
SampleRing  log-sum-exp + sampling  addition (in log space)  posterior sampling (temperature=1)
MapRing     log-max                 addition (in log space)  MAP/Viterbi decoding (temperature=0)
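The rings differ only in their sum operation; the product across a plate is always addition of log terms. A toy PyTorch illustration:

```python
import torch

# A toy log-space factor: one enumeration dim (K=2), one plate dim (N=2)
log_factor = torch.tensor([[-1.0, -2.0],
                           [-0.5, -3.0]])

# LogRing: sum = log-sum-exp (exact marginalization)
log_sum = torch.logsumexp(log_factor, dim=0)

# MapRing: sum = log-max (temperature=0, MAP/Viterbi decoding)
log_max, argmax = log_factor.max(dim=0)

# In every ring, the product across the plate is addition of log terms:
elbo_term = log_sum.sum()
map_term = log_max.sum()
```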

Contraction Path Optimization

TVE uses opt_einsum for planning efficient contraction paths. The shared_intermediates context manager enables caching and reuse of intermediate results across multiple contraction operations, which is particularly beneficial when computing multiple outputs (e.g., ELBO loss + posterior samples).
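A minimal sketch of intermediate sharing with opt_einsum. This operates in ordinary multiply-sum space for clarity; Pyro plugs its own log-space backends into the same machinery:

```python
import numpy as np
from opt_einsum import contract, shared_intermediates

rng = np.random.default_rng(0)
a = rng.random((10, 20))
b = rng.random((20, 30))
c = rng.random((30, 5))
d = rng.random((30, 8))

# Inside shared_intermediates, intermediates such as the a@b contraction
# can be cached and reused across both outputs instead of recomputed.
with shared_intermediates():
    x = contract('ij,jk,kl->il', a, b, c)
    y = contract('ij,jk,km->im', a, b, d)
```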

How It Works

Consider a mixture model with K=3 components and N data points:

Tensor tree:
  ordinal={}           -> [log_weights]  shape: (3,)
  ordinal={data_plate} -> [log_probs]    shape: (3, N)

Sum dimensions: {assignment_dim}   (the discrete enumeration dim)

TVE contracts this as follows:

  1. Both factors share the assignment sum dimension, so the root term log_weights is broadcast into the leaf ordinal {data_plate} and added to log_probs, giving a combined factor of shape (3, N)
  2. Sum out the assignment dimension within the plate via log-sum-exp: the result has shape (N,)
  3. Product-contract the plate dimension by summing the per-datum log-marginals: the result is a scalar, the model's log-marginal likelihood

This achieves O(N * K) complexity instead of O(K^N).
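The equivalence of the two computations can be checked numerically on a tiny problem with hypothetical toy factors:

```python
import itertools

import torch

torch.manual_seed(0)
K, N = 3, 5  # kept tiny so enumeration over K**N configs is feasible
log_weights = torch.randn(K).log_softmax(dim=0)  # log p(z_n = k)
log_probs = torch.randn(K, N)                    # log p(x_n | z_n = k)

# TVE-style contraction: O(N * K)
tve = torch.logsumexp(log_weights[:, None] + log_probs, dim=0).sum()

# Naive enumeration over all K**N joint assignments: O(K**N)
terms = [
    sum(log_weights[k] + log_probs[k, n] for n, k in enumerate(cfg))
    for cfg in itertools.product(range(K), repeat=N)
]
naive = torch.logsumexp(torch.stack(terms), dim=0)
# tve and naive agree up to floating-point error
```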

Example

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

K = 5  # components
N = 1000  # data points

def model(data):
    weights = pyro.sample("weights", dist.Dirichlet(torch.ones(K)))
    with pyro.plate("components", K):
        locs = pyro.sample("locs", dist.Normal(0.0, 10.0))
    with pyro.plate("data", N):
        # Discrete latent; TraceEnum_ELBO marginalizes it exactly via TVE
        assignment = pyro.sample("assignment", dist.Categorical(weights))
        pyro.sample("obs", dist.Normal(locs[assignment], 1.0), obs=data)

@config_enumerate  # enumerate the guide's discrete sample sites
def guide(data):
    weights_q = pyro.param("weights_q", torch.ones(K),
                           constraint=dist.constraints.simplex)
    pyro.sample("weights", dist.Dirichlet(weights_q))
    with pyro.plate("components", K):
        loc_q = pyro.param("loc_q", torch.randn(K))
        pyro.sample("locs", dist.Normal(loc_q, 1.0))
    with pyro.plate("data", N):
        probs_q = pyro.param("probs_q", torch.ones(N, K) / K,
                             constraint=dist.constraints.simplex)
        pyro.sample("assignment", dist.Categorical(probs_q))

# Synthetic bimodal data so the example is runnable end to end
data = torch.cat([torch.randn(N // 2) - 2.0, torch.randn(N // 2) + 2.0])

# TraceEnum_ELBO uses TVE internally to marginalize discrete variables
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=elbo)
loss = svi.step(data)


References

  • Obermeyer, F., Bingham, E., Jankowiak, M., Phan, D., Chen, J.P., "Tensor Variable Elimination for Plated Factor Graphs", 2019. https://arxiv.org/abs/1902.03210
  • Koller, D. & Friedman, N., Probabilistic Graphical Models: Principles and Techniques, 2009.
