
Implementation:Pyro ppl Pyro Contract Tensor Tree

From Leeroopedia


Metadata

Field | Value
Implementation ID | Pyro_ppl_Pyro_Contract_Tensor_Tree
Title | contract_tensor_tree
Project | Pyro (pyro-ppl/pyro)
File | pyro/ops/contract.py, lines 163-203
Implements | Pyro_ppl_Pyro_Tensor_Variable_Elimination
Repository | https://github.com/pyro-ppl/pyro

Summary

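contract_tensor_tree is the core function that implements Tensor Variable Elimination (TVE) for plated factor graphs. It contracts out sum dimensions in a tree of tensors organized by plate ordinals: it first splits the factors into independent connected components, then contracts each component by bottom-up message passing from the most deeply nested plates outward. Plate dimensions are only partially contracted, since each component keeps the plates shared by all of its factors. This is the function that powers TraceEnum_ELBO's exact marginalization of discrete variables.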

Signature

def contract_tensor_tree(tensor_tree, sum_dims, cache=None, ring=None)

Also related:

def contract_to_tensor(tensor_tree, sum_dims, target_ordinal=None,
                       target_dims=None, cache=None, ring=None)

Import

from pyro.ops.contract import contract_tensor_tree

Parameters

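Parameter | Type | Default | Description
tensor_tree | OrderedDict | required | A dictionary mapping ordinals to lists of tensors. An ordinal is a frozenset identifying the plates that enclose a factor (documented as plate frames; in the tensor operations these are the plate dimension symbols that appear in _pyro_dims). Each tensor is a log-probability factor annotated with _pyro_dims, a string with one dimension symbol per tensor dimension.
sum_dims | set | required | The complete set of sum-contraction dimension symbols to marginalize out; these correspond to enumerated discrete variables. Any factor dimension not in sum_dims is treated as a plate (product) dimension, which is why the set must be complete.
cache | dict or None | None | An optional opt_einsum shared_intermediates cache for sharing intermediate computations across multiple contraction calls.
ring | Ring or None | None | An optional algebraic ring defining the tensor operations. Defaults to LogRing(cache) if not provided.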

Returns

Type | Description
OrderedDict | A contracted version of the tensor tree: each ordinal maps to a list of tensors with all sum dimensions eliminated. Each connected component is contracted to a single tensor, placed at that component's minimum ordinal (the intersection of the ordinals of its factors).

Internal Mechanism

Step 1: Ring Selection (lines 185-186)

If no ring is provided, defaults to LogRing(cache), which performs operations in log space (log-sum-exp for sums, addition for products).
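The log-space conventions can be checked numerically. The following is a minimal stdlib-only sketch, not Pyro's LogRing; the helper names logsumexp, log_sum and log_product are hypothetical. It confirms that log-sum-exp corresponds to adding probabilities and addition corresponds to multiplying them, and shows why the max shift matters for very small log-probabilities.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_sum(log_values):
    # "Sum" in the log semiring: log-sum-exp of the inputs.
    return logsumexp(log_values)


def log_product(log_values):
    # "Product" in the log semiring: ordinary addition of the inputs.
    return sum(log_values)


probs = [0.2, 0.5, 0.3]
logs = [math.log(p) for p in probs]

total = log_sum(logs)       # log(0.2 + 0.5 + 0.3) = log(1.0)
joint = log_product(logs)   # log(0.2 * 0.5 * 0.3) = log(0.03)

# exp(-1000) underflows to 0.0 in float64, but the shifted form stays finite.
tiny = log_sum([-1000.0, -1000.0])  # -1000 + log(2)
```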

Step 2: Term Indexing (lines 188-189)

Creates a reverse mapping from each tensor to its ordinal, and flattens all tensors into a single list.

ordinals = {term: t for t, terms in tensor_tree.items() for term in terms}
all_terms = [term for terms in tensor_tree.values() for term in terms]

Step 3: Partition into Connected Components (line 193)

Calls _partition_terms(ring, all_terms, sum_dims) which:

  1. Constructs a bipartite graph between tensors and sum dimensions
  2. Finds connected components by graph traversal (a tensor joins a component whenever it shares a sum dimension with one of its members)
  3. Returns pairs of (component_terms, component_dims)

This is critical for efficiency: independent groups of variables can be contracted separately, avoiding unnecessary broadcasting.
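The efficiency argument can be made concrete with two factors that share no sum dimension. In this stdlib sketch (plain lists stand in for tensors; the factor values are made up), contracting both factors jointly first materializes their 4 x 5 broadcast, while contracting each component separately touches only 4 + 5 entries. The log-space results agree because the joint sum factorizes.

```python
import itertools
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


# Two hypothetical log-factors with no shared sum dim:
# f depends only on dim 'a' (size 4), g only on dim 'b' (size 5).
f = [0.1, -0.3, 0.7, 0.0]
g = [-1.2, 0.4, 0.2, -0.5, 0.9]

# Contracting them as one component broadcasts to an a x b table first.
joint_table = [fa + gb for fa, gb in itertools.product(f, g)]
joint = logsumexp(joint_table)

# Contracting each connected component on its own avoids the broadcast.
separate = logsumexp(f) + logsumexp(g)
```

With 1,000 values per variable, the joint table would hold a million entries versus two thousand for the separate contractions.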

Step 4: Contract Each Component (lines 194-200)

For each connected component:

  1. Rebuilds the local tensor tree for the component
  2. Calls _contract_component(ring, component, dims, set())
  3. The component contraction (lines 79-160) performs bottom-up message passing:
    • Selects a leaf: the ordinal with the deepest plate nesting
    • Performs a sumproduct contraction to eliminate the sum dimensions local to that ordinal
    • Contracts plate dimensions via product when moving to the parent ordinal
    • Continues until a single tensor remains at the component's minimum (root) ordinal
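For a model with the same shape as the Complete Example (a root factor over an enumerated dimension and a factor inside one plate), the bottom-up pass takes only a few lines. The sketch below uses plain lists instead of tensors, made-up probabilities, and the hypothetical symbols 'a' (enumerated variable) and 'i' (plate); it checks the message-passing result against brute-force enumeration in linear space.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


# Hypothetical log-factors: z takes 3 values (sum dim 'a'), plate 'i' has 4 entries.
log_pz = [math.log(p) for p in (0.5, 0.3, 0.2)]  # dims: 'a'
log_px_given_z = [                                # dims: 'a', 'i'
    [math.log(p) for p in (0.9, 0.8, 0.7, 0.9)],
    [math.log(p) for p in (0.1, 0.3, 0.2, 0.4)],
    [math.log(p) for p in (0.5, 0.5, 0.5, 0.5)],
]

# Leaf ordinal {'i'}: it has no local sum dims, so the plate dim 'i' is
# product-contracted (a sum in log space), giving a message over 'a'.
message = [sum(row) for row in log_px_given_z]

# Root ordinal {}: combine the message with log p(z) (addition in log space),
# then eliminate the sum dim 'a' with log-sum-exp.
tve = logsumexp([lp + msg for lp, msg in zip(log_pz, message)])

# Brute force in linear space: log sum_z p(z) * prod_n p(x_n | z).
brute = math.log(sum(
    math.exp(lp) * math.prod(math.exp(v) for v in row)
    for lp, row in zip(log_pz, log_px_given_z)
))
```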

Step 5: Reassemble Result (lines 199-202)

contracted_tree = OrderedDict()
for terms, dims in _partition_terms(ring, all_terms, sum_dims):
    component = OrderedDict()
    for term in terms:
        component.setdefault(ordinals[term], []).append(term)
    ordinal, term = _contract_component(ring, component, dims, set())
    contracted_tree.setdefault(ordinal, []).append(term)
return contracted_tree

The _partition_terms Helper (lines 38-76)

This function partitions terms into independent connected components by constructing a bipartite graph between tensors and their sum dimensions:

from collections import OrderedDict

def _partition_terms(ring, terms, dims):
    # Build bipartite graph: tensor <-> sum_dim
    neighbors = OrderedDict(
        [(t, []) for t in terms] + [(d, []) for d in sorted(dims)]
    )
    for term in terms:
        for dim in term._pyro_dims:
            if dim in dims:
                neighbors[term].append(dim)
                neighbors[dim].append(term)
    # Find connected components by graph traversal (elided here)
    ...
    return components  # list of (component_terms, component_dims)
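The elided traversal can be sketched as follows. This is a standalone illustration, not Pyro's code: Factor is a hypothetical stand-in for an annotated tensor, and the order in which components are returned may differ from Pyro's. A factor with no sum dimensions ends up in a component of its own with an empty dim set.

```python
from collections import OrderedDict


class Factor:
    """Hypothetical stand-in for a tensor annotated with _pyro_dims."""

    def __init__(self, name, dims):
        self.name = name
        self._pyro_dims = dims  # string of single-character dim symbols

    def __repr__(self):
        return self.name


def partition_terms(terms, sum_dims):
    """Split terms into connected components of the term <-> sum-dim graph."""
    # Bipartite adjacency: factors and dim symbols are both graph nodes.
    neighbors = OrderedDict(
        [(t, []) for t in terms] + [(d, []) for d in sorted(sum_dims)]
    )
    for term in terms:
        for dim in term._pyro_dims:
            if dim in sum_dims:
                neighbors[term].append(dim)
                neighbors[dim].append(term)

    components = []
    seen = set()
    for start in neighbors:
        if start in seen:
            continue
        # Stack-based traversal collecting one connected component.
        seen.add(start)
        stack, component = [start], [start]
        while stack:
            node = stack.pop()
            for nbr in neighbors[node]:
                if nbr not in seen:
                    seen.add(nbr)
                    component.append(nbr)
                    stack.append(nbr)
        comp_terms = [v for v in component if isinstance(v, Factor)]
        comp_dims = {v for v in component if not isinstance(v, Factor)}
        if comp_terms:  # drop components made only of unused dims
            components.append((comp_terms, comp_dims))
    return components


f_ab = Factor("f_ab", "ab")
f_b = Factor("f_b", "bi")
f_c = Factor("f_c", "c")
f_i = Factor("f_i", "i")  # only a plate dim, no sum dims
parts = partition_terms([f_ab, f_b, f_c, f_i], {"a", "b", "c"})
```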

The _contract_component Helper (lines 79-160)

This function performs the core bottom-up message passing for a single connected component:

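  1. Groups sum dims by ordinal: each sum dim is assigned to the intersection of the ordinals of the factors that use it, i.e. the plates shared by every use of that variable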
  2. Iteratively selects the deepest leaf ordinal
  3. At each leaf: performs sumproduct to eliminate local sum dims
  4. Contracts plate dimensions when moving to parent
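  5. Handles global-local decomposition for target dims that must be preserved (contract_tensor_tree always passes an empty set here, so this only matters for callers that request target dims, such as contract_to_tensor)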
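Step 1 and the deepest-first leaf order can be sketched in plain Python. This follows the description above rather than Pyro's source: the tree, the plate symbols 'i' and 'j', and the dim strings are hypothetical, and the sketch sorts the ordinals once, whereas the loop described above re-selects a leaf after every contraction.

```python
from collections import defaultdict

# Hypothetical tensor tree: ordinal (frozenset of plate symbols) -> list of
# dim strings standing in for each factor's _pyro_dims.
tree = {
    frozenset(): ["a"],
    frozenset("i"): ["ai", "abi"],
    frozenset("ij"): ["bij"],
}
sum_dims = {"a", "b"}

# Each sum dim belongs to the intersection of the ordinals of every factor
# that mentions it: the plates shared by all uses of that variable.
dim_to_ordinal = {}
for ordinal, factors in tree.items():
    for dims in factors:
        for dim in sum_dims.intersection(dims):
            dim_to_ordinal[dim] = dim_to_ordinal.get(dim, ordinal) & ordinal

# Invert the map: ordinal -> set of sum dims eliminated there.
dims_tree = defaultdict(set)
for dim, ordinal in dim_to_ordinal.items():
    dims_tree[ordinal].add(dim)

# Leaves are visited deepest first: the ordinal with the most plates.
visit_order = sorted(tree, key=len, reverse=True)
```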

contract_to_tensor (lines 205-273)

A related function that contracts the entire tensor tree down to a single tensor at a specified target ordinal, optionally preserving specified target dimensions:

def contract_to_tensor(tensor_tree, sum_dims, target_ordinal=None,
                       target_dims=None, cache=None, ring=None):
    # Contracts tensor tree to a single result tensor
    ...

This is used by TraceEnum_ELBO when computing the final scalar loss value.
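As a rough illustration of what preserving target dims means (plain Python, not a call to contract_to_tensor; numbers and names are made up): keeping the enumerated dimension of a one-plate model at the root ordinal stops the contraction before the final log-sum-exp, leaving one unnormalized log-weight per value of z. Normalizing those weights gives the posterior p(z | x).

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


# Hypothetical one-plate model: z has 3 values ('a'), plate 'i' has 4 entries.
log_pz = [math.log(p) for p in (0.5, 0.3, 0.2)]
log_px_given_z = [
    [math.log(p) for p in (0.9, 0.8, 0.7, 0.9)],
    [math.log(p) for p in (0.1, 0.3, 0.2, 0.4)],
    [math.log(p) for p in (0.5, 0.5, 0.5, 0.5)],
]

# Treat 'a' as a target dim: contract the plate and combine with log p(z),
# but do not log-sum-exp over 'a'. One log-weight per value of z remains:
# log p(z) + sum_n log p(x_n | z).
log_weights = [lp + sum(row) for lp, row in zip(log_pz, log_px_given_z)]

# Normalizing over 'a' turns the log-weights into the posterior p(z | x).
log_evidence = logsumexp(log_weights)
posterior = [math.exp(w - log_evidence) for w in log_weights]
```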

Ring Implementations

The ring parameter controls the algebraic semantics of contraction:

Ring | Module | sumproduct() | product() | Use Case
LogRing | pyro.ops.rings | Adds the terms (a product in log space), then log-sum-exps over the sum dims | Sums over the given plate dims in log space (a product in linear space) | ELBO computation
SampleRing | pyro.ops.rings | Like LogRing in the forward pass, but records information for backward sampling | Same as LogRing | Posterior sampling
MapRing | pyro.ops.rings | Like LogRing, but with log-max instead of log-sum-exp | Same as LogRing | MAP/Viterbi decoding
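The point of the ring parameter is that the elimination order stays fixed while the "sum" operation is swapped. The toy sketch below uses a hypothetical ToyRing that is not Pyro's Ring class, with made-up probabilities; replacing log-sum-exp with max turns the marginal likelihood into the score of the single best assignment, as in MAP/Viterbi decoding.

```python
import math
from dataclasses import dataclass
from typing import Callable, List


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


@dataclass(frozen=True)
class ToyRing:
    """Hypothetical two-operation ring over log-values (not Pyro's Ring API)."""
    add: Callable[[List[float]], float]  # eliminates a sum dim
    mul: Callable[[List[float]], float]  # combines factors or plate entries


log_ring = ToyRing(add=logsumexp, mul=sum)  # analogous to LogRing
max_ring = ToyRing(add=max, mul=sum)        # analogous to MapRing's forward pass


def eliminate(ring, log_pz, log_px_given_z):
    """Same elimination order for any ring: contract plate 'i', then sum dim 'a'."""
    scores = [ring.mul([lp, ring.mul(row)]) for lp, row in zip(log_pz, log_px_given_z)]
    return ring.add(scores)


log_pz = [math.log(p) for p in (0.5, 0.3, 0.2)]
log_px_given_z = [
    [math.log(p) for p in (0.9, 0.8, 0.7, 0.9)],
    [math.log(p) for p in (0.1, 0.3, 0.2, 0.4)],
    [math.log(p) for p in (0.5, 0.5, 0.5, 0.5)],
]

evidence = eliminate(log_ring, log_pz, log_px_given_z)   # log p(x)
map_score = eliminate(max_ring, log_pz, log_px_given_z)  # log max_z p(z, x)
```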

Complete Example

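from collections import OrderedDict
import torch
from pyro.ops.contract import contract_tensor_tree

# Simulate a simple tensor tree:
# two factors, one at the root (no plates) and one inside a data plate.
# Pyro indexes dimensions by single-character symbols: each tensor's
# _pyro_dims is a string with one symbol per tensor dimension, and ordinals
# hold the plate symbols that appear in those strings.
# Here 'a' is the enumerated variable z and 'i' is the data plate.
root_ordinal = frozenset()
data_ordinal = frozenset("i")

# Factor outside plates: log p(z) with enumerated z dim 'a'
factor1 = torch.randn(3)  # 3 possible values of z
factor1._pyro_dims = "a"

# Factor inside the data plate: log p(x_n | z) with enum dim 'a' and plate dim 'i'
factor2 = torch.randn(3, 10)  # 3 z-values x 10 data points
factor2._pyro_dims = "ai"

tensor_tree = OrderedDict([
    (root_ordinal, [factor1]),
    (data_ordinal, [factor2]),
])
sum_dims = {"a"}  # marginalize out the discrete variable z

result = contract_tensor_tree(tensor_tree, sum_dims)
# Both factors share 'a', so they form one connected component, which is
# contracted to a single tensor at the root ordinal frozenset(). That tensor
# should match (factor1 + factor2.sum(-1)).logsumexp(0).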

Related Pages

Implements Principle

Related Implementations
