Implementation: Pyro ppl Pyro Contract Tensor Tree
Metadata
| Field | Value |
|---|---|
| Implementation ID | Pyro_ppl_Pyro_Contract_Tensor_Tree |
| Title | contract_tensor_tree |
| Project | Pyro (pyro-ppl/pyro) |
| File | pyro/ops/contract.py, lines 163-203 |
| Implements | Pyro_ppl_Pyro_Tensor_Variable_Elimination |
| Repository | https://github.com/pyro-ppl/pyro |
Summary
contract_tensor_tree is the core function that implements Tensor Variable Elimination (TVE) for plated factor graphs. It contracts out sum dimensions in a tree of tensors organized by plate ordinals, using message passing between connected components. This is the function that powers TraceEnum_ELBO's exact marginalization of discrete variables.
Signature
```python
def contract_tensor_tree(tensor_tree, sum_dims, cache=None, ring=None)
```
Also related:
```python
def contract_to_tensor(tensor_tree, sum_dims, target_ordinal=None,
                       target_dims=None, cache=None, ring=None)
```
Import
```python
from pyro.ops.contract import contract_tensor_tree
```
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| tensor_tree | OrderedDict | required | A dictionary mapping ordinals (frozensets of plate frames) to lists of tensors. Each tensor represents a log-probability factor annotated with _pyro_dims. |
| sum_dims | set | required | The complete set of sum-contraction dimensions (characters) to marginalize out. These correspond to enumerated discrete variables. |
| cache | dict or None | None | An optional opt_einsum.shared_intermediates cache for sharing intermediate computations across multiple contraction calls. |
| ring | Ring or None | None | An optional algebraic ring defining the tensor operations. Defaults to LogRing(cache) if not provided. |
Returns
| Type | Description |
|---|---|
| OrderedDict | A contracted version of the tensor tree, where each ordinal maps to a list of tensors with sum dimensions eliminated. Each connected component is contracted to a single tensor at the minimum (root) ordinal. |
Internal Mechanism
Step 1: Ring Selection (lines 185-186)
If no ring is provided, defaults to LogRing(cache), which performs operations in log space (log-sum-exp for sums, addition for products).
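To make the default concrete: in the log ring, "product" is addition and "sum" is log-sum-exp. A minimal pure-Python sketch of these semantics (not Pyro's LogRing API):

```python
import math

def log_sum_exp(xs):
    """Numerically stable logsumexp: the log ring's 'sum' operation."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

# 'Product' of factors in log space is plain addition:
# log(p * q) == log p + log q
log_p, log_q = math.log(0.5), math.log(0.4)
log_pq = log_p + log_q  # == log(0.2)

# 'Sum' (marginalization) in log space is logsumexp:
# log(0.2 + 0.5 + 0.3) == log(1.0) == 0.0
total = log_sum_exp([math.log(0.2), math.log(0.5), math.log(0.3)])
```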
Step 2: Term Indexing (lines 188-189)
Creates a reverse mapping from each tensor to its ordinal, and flattens all tensors into a single list.
```python
ordinals = {term: t for t, terms in tensor_tree.items() for term in terms}
all_terms = [term for terms in tensor_tree.values() for term in terms]
```
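A standalone illustration of these two comprehensions, with strings standing in for tensors:

```python
from collections import OrderedDict

# A toy tree: one term at the root ordinal, two inside a plate 'i'
tensor_tree = OrderedDict([
    (frozenset(), ["f1"]),
    (frozenset("i"), ["f2", "f3"]),
])

# Reverse map: each term -> the ordinal it lives at
ordinals = {term: t for t, terms in tensor_tree.items() for term in terms}
# Flat list of every term in the tree, in tree order
all_terms = [term for terms in tensor_tree.values() for term in terms]
```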
Step 3: Partition into Connected Components (line 193)
Calls _partition_terms(ring, all_terms, sum_dims) which:
- Constructs a bipartite graph between tensors and sum dimensions
- Finds connected components via BFS
- Returns pairs of (component_terms, component_dims)
This is critical for efficiency: independent groups of variables can be contracted separately, avoiding unnecessary broadcasting.
Step 4: Contract Each Component (lines 194-200)
For each connected component:
- Rebuilds the local tensor tree for the component
- Calls _contract_component(ring, component, dims, set())

The component contraction itself (lines 79-160) performs bottom-up message passing:
- Finds the leaf (deepest plate nesting) ordinal
- Performs sumproduct contraction to eliminate local sum dimensions
- Contracts plate dimensions via product when moving to parent ordinals
- Continues until a single tensor remains at the root ordinal
Step 5: Reassemble Result (lines 199-202)
```python
contracted_tree = OrderedDict()
for terms, dims in _partition_terms(ring, all_terms, sum_dims):
    component = OrderedDict()
    for term in terms:
        component.setdefault(ordinals[term], []).append(term)
    ordinal, term = _contract_component(ring, component, dims, set())
    contracted_tree.setdefault(ordinal, []).append(term)
return contracted_tree
```
The _partition_terms Helper (lines 38-76)
This function partitions terms into independent connected components by constructing a bipartite graph between tensors and their sum dimensions:
```python
def _partition_terms(ring, terms, dims):
    # Build bipartite graph: tensor <-> sum_dim
    neighbors = OrderedDict(
        [(t, []) for t in terms] + [(d, []) for d in sorted(dims)]
    )
    for term in terms:
        for dim in term._pyro_dims:
            if dim in dims:
                neighbors[term].append(dim)
                neighbors[dim].append(term)
    # Find connected components via BFS
    ...
    return components  # list of (component_terms, component_dims)
```
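The elided BFS can be sketched end to end as follows. This is a standalone re-implementation of the idea, not Pyro's code; terms are represented as (name, dims) pairs instead of annotated tensors:

```python
from collections import OrderedDict, deque

def partition_terms(terms, dims):
    """Split terms into connected components of the bipartite
    term <-> sum-dim graph (a sketch of _partition_terms)."""
    # Bipartite adjacency over both terms and dims
    neighbors = OrderedDict(
        [(t, []) for t in terms] + [(d, []) for d in sorted(dims)]
    )
    for term in terms:
        _name, term_dims = term
        for dim in term_dims:
            if dim in dims:
                neighbors[term].append(dim)
                neighbors[dim].append(term)

    components, seen = [], set()
    for node in neighbors:
        if node in seen:
            continue
        # BFS from an unvisited node to collect one component
        component, queue = set(), deque([node])
        seen.add(node)
        while queue:
            v = queue.popleft()
            component.add(v)
            for u in neighbors[v]:
                if u not in seen:
                    seen.add(u)
                    queue.append(u)
        component_terms = [t for t in terms if t in component]
        component_dims = {d for d in dims if d in component}
        components.append((component_terms, component_dims))
    return components

# f1 and f2 share dim 'a', so they form one component;
# f3 only touches 'c' and is independent
parts = partition_terms([("f1", "a"), ("f2", "ab"), ("f3", "c")],
                        {"a", "b", "c"})
```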
The _contract_component Helper (lines 79-160)
This function performs the core bottom-up message passing for a single connected component:
- Groups sum dims by ordinal (which plate context they belong to)
- Iteratively selects the deepest leaf ordinal
- At each leaf: performs sumproduct to eliminate local sum dims
- Contracts plate dimensions when moving to parent
- Handles global-local decomposition for target dims that must be preserved
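On a two-level tree (one root factor, one factor inside a single plate) this pass reduces to: contract the plate dim by a product (a sum in log space), then eliminate the shared sum dim at the root by logsumexp. A numeric sketch with nested lists standing in for tensors (illustrative only, not Pyro's helper):

```python
import math

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

# Leaf factor inside a plate of size N: log p(x_n | z), shape (Z, N)
Z, N = 3, 4
leaf = [[-0.5 * (z - n) ** 2 for n in range(N)] for z in range(Z)]
# Root factor: log p(z), shape (Z,)
root = [math.log(1.0 / Z)] * Z

# Plate product: sum over the plate dim n -> message of shape (Z,)
message = [sum(leaf[z]) for z in range(Z)]
# Sumproduct at the root: add the message, eliminate z by logsumexp
log_evidence = logsumexp([root[z] + message[z] for z in range(Z)])

# Brute-force check: sum over z of p(z) * prod_n p(x_n | z)
brute = math.log(sum(
    (1.0 / Z) * math.prod(math.exp(leaf[z][n]) for n in range(N))
    for z in range(Z)
))
assert abs(log_evidence - brute) < 1e-9
```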
contract_to_tensor (lines 205-273)
A related function that contracts the entire tensor tree down to a single tensor at a specified target ordinal, optionally preserving specified target dimensions:
```python
def contract_to_tensor(tensor_tree, sum_dims, target_ordinal=None,
                       target_dims=None, cache=None, ring=None):
    # Contracts tensor tree to a single result tensor
    ...
```
This is used by TraceEnum_ELBO when computing the final scalar loss value.
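The difference between eliminating a dim and preserving it via target_dims can be sketched numerically (plain Python, not a call into Pyro; the model here is a hypothetical two-value z with two observations):

```python
import math

# Factors from a tiny model: log p(z) and log p(x_n | z) for n = 0, 1
log_pz = [math.log(0.3), math.log(0.7)]
log_px_given_z = [[-1.0, -2.0], [-0.5, -3.0]]  # indexed [z][n]

# Preserving the enum dim (target_dims={'a'} in spirit): sum out the
# plate but keep z -> unnormalized log p(z, x) over the kept dim
kept = [log_pz[z] + sum(log_px_given_z[z]) for z in range(2)]

# Eliminating everything (target_dims=None): logsumexp over z as well
# -> a single scalar, the log evidence used in the loss
m = max(kept)
scalar = m + math.log(sum(math.exp(k - m) for k in kept))
```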
Ring Implementations
The ring parameter controls the algebraic semantics of contraction:
| Ring | Module | sumproduct() | product() | Use Case |
|---|---|---|---|---|
| LogRing | pyro.ops.rings | Log-sum-exp over sum dims, sum over terms | Sum over plate dims (a product in log space) | ELBO computation |
| SampleRing | pyro.ops.rings | Like LogRing forward, but records info for backward sampling | Same as LogRing | Posterior sampling |
| MapRing | pyro.ops.rings | Log-max instead of log-sum-exp | Same as LogRing | MAP/Viterbi decoding |
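The difference between the rings is essentially which reduction replaces log-sum-exp. A schematic comparison in plain Python (not Pyro's Ring API):

```python
import math

# Log-weights over a single sum dim
scores = [math.log(0.1), math.log(0.6), math.log(0.3)]

# LogRing: marginalize -> logsumexp (used for the ELBO)
m = max(scores)
log_ring = m + math.log(sum(math.exp(s - m) for s in scores))  # log(1.0)

# MapRing: maximize -> max (used for MAP/Viterbi decoding)
map_ring = max(scores)  # log(0.6)

# SampleRing: same forward value as LogRing, but the normalized
# weights are kept so the dim can be sampled in a backward pass
weights = [math.exp(s - log_ring) for s in scores]  # [0.1, 0.6, 0.3]
```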
Complete Example
```python
from collections import OrderedDict

import torch

from pyro.ops.contract import contract_tensor_tree

# Simulate a simple tensor tree:
# two factors, one at the root (no plates), one inside a plate dim 'i'
# (contract.py operates on packed tensors, whose _pyro_dims is a string
# of single-character einsum symbols, and ordinals of plate dim symbols)
root_ordinal = frozenset()
data_ordinal = frozenset('i')

# Factor outside all plates: log p(z) with enumerated z dim 'a'
factor1 = torch.randn(3)  # 3 possible values of z
factor1._pyro_dims = 'a'

# Factor inside the plate: log p(x_n | z) with enum dim 'a', plate dim 'i'
factor2 = torch.randn(3, 10)  # 3 z-values x 10 data points
factor2._pyro_dims = 'ai'

tensor_tree = OrderedDict([
    (root_ordinal, [factor1]),
    (data_ordinal, [factor2]),
])
sum_dims = {'a'}  # marginalize out the discrete variable z

result = contract_tensor_tree(tensor_tree, sum_dims)
# result contains contracted tensors with dim 'a' eliminated
```
Related Pages
Implements Principle
Related Implementations
- Pyro_ppl_Pyro_Config_Enumerate -- Produces the annotated sample sites whose log-probabilities form the tensor tree input.
- Pyro_ppl_Pyro_Infer_Discrete -- Calls contract_tensor_tree with SampleRing or MapRing for posterior decoding.
- Pyro_ppl_Pyro_Pyro_Markov -- Provides the Markov scoping that determines which sum dims can be contracted at each step in sequential models.
- Environment:Pyro_ppl_Pyro_Funsor_Backend