Implementation:Pyro ppl Pyro Rings
| Property | Value |
|---|---|
| Module | pyro.ops.rings |
| Source | pyro/ops/rings.py |
| Lines | 338 |
| Classes | Ring, LinearRing, LogRing, MapRing, SampleRing, MarginalRing |
| Dependencies | torch, pyro.ops.einsum, pyro.util |
Overview
This module implements a hierarchy of tensor ring classes used for sum-product computations in Pyro's variable elimination and belief propagation algorithms. Each ring defines algebraic operations (sum-product, product, inverse, broadcast) on "packed" tensors: tensors annotated with a ._pyro_dims attribute, a string holding one dimension-name character per tensor dimension.
The ring hierarchy supports different semiring semantics:
- LinearRing: Standard sum-product in linear space.
- LogRing: Sum-product in log space (logsumexp for sum, sum for product).
- MapRing: Forward max-sum, backward argmax (MAP inference).
- SampleRing: Forward sum-product, backward sampling.
- MarginalRing: Forward sum-product, backward marginalization using inclusion-exclusion.
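These semirings differ only in which operation plays "sum" and which plays "product". The minimal pure-Python sketch below makes that concrete on scalars; the `SEMIRINGS` table and the `sumproduct` helper are hypothetical illustrations, not Pyro's API. SampleRing and MarginalRing share LogRing's forward pass, so they are not listed separately.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


# Hypothetical table of (ring sum, ring product) pairs acting on plain floats.
SEMIRINGS = {
    "linear": (sum, math.prod),  # ordinary sum-product, as in LinearRing
    "log": (logsumexp, sum),     # log-space sum-product, as in LogRing
    "max": (max, sum),           # max-sum, as in MapRing's forward pass
}


def sumproduct(kind, rows):
    """Multiply the factors within each row, then 'sum' over the rows."""
    add, mul = SEMIRINGS[kind]
    return add([mul(row) for row in rows])


# Two configurations, each a product of two factors (linear-space values).
rows = [[0.5, 0.2], [0.1, 0.4]]
log_rows = [[math.log(v) for v in row] for row in rows]
print(sumproduct("linear", rows))             # total mass: 0.5*0.2 + 0.1*0.4
print(math.exp(sumproduct("log", log_rows)))  # same total, computed in log space
print(math.exp(sumproduct("max", log_rows)))  # mass of the best configuration
```

The log and max variants take log-space inputs, which is why their results are exponentiated before comparing with the linear one.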
The abstract Ring base class defines the interface with caching support for use with opt_einsum.shared_intermediates.
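Caching pays off because the same intermediate can be requested repeatedly during message passing. One plausible way to memoize a ring operation by operand identity is sketched below with NumPy. `CachedInv` is a hypothetical class, and storing the operand itself in the cache (so its `id()` cannot be reused by another object while the entry lives) is an assumption of this sketch, not a documented detail of Ring.

```python
import numpy as np


class CachedInv:
    """Hypothetical memoizer for a ring operation, keyed by operand identity."""

    def __init__(self, cache=None):
        # Several instances may share one dict, like passing the same cache to rings.
        self._cache = {} if cache is None else cache
        self.misses = 0

    def _key(self, op, array):
        # Keep the operand alive inside the cache so its id() stays unique.
        self._cache.setdefault(("tensor", id(array)), array)
        return (op, id(array))

    def inv(self, array):
        key = self._key("inv", array)
        if key not in self._cache:
            self.misses += 1
            self._cache[key] = np.reciprocal(array)
        return self._cache[key]


shared = {}
a = np.array([2.0, 4.0])
first, second = CachedInv(shared), CachedInv(shared)
print(first.inv(a) is second.inv(a))  # the second call reuses the cached result
```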
Code Reference
Class: Ring (Abstract Base)
Defines the ring interface with a cache for memoizing intermediate results.
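- `sumproduct(terms, dims)`: Abstract. Multiply all terms together, then sum-contract out `dims`.
- `product(term, ordinal)`: Abstract. Product-contract `term` along the plate dimensions in `ordinal`.
- `broadcast(term, ordinal)`: Expand `term` along the plate dimensions that are in `ordinal` but not already in `term`.
- `inv(term)`: Abstract. Compute the reciprocal of a term, for use in inclusion-exclusion.
- `global_local(term, dims, ordinal)`: Compute global and local terms for tensor message passing using inclusion-exclusion, `term / sum(term, dims) * product(sum(term, dims), ordinal)`. The local part is `term / sum(term, dims)` and the global part is `product(sum(term, dims), ordinal)`; the method returns `(global_part, local_part)`.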
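Written out by hand in linear space, the global/local split separates a term into a part normalized over the summed dims and a part that collects the plate-wise totals. The NumPy sketch below assumes one plate dim `i` and one sum dim `a`, and uses plain arrays rather than packed tensors or any `Ring` method.

```python
import numpy as np

# Assumed layout: a strictly positive linear-space term with dims "ia",
# where "i" is a plate dim (in the ordinal) and "a" is a sum dim (in dims).
rng = np.random.default_rng(0)
term = rng.random((3, 4)) + 0.1

term_sum = term.sum(axis=1)            # sum(term, dims), dims "i"
global_part = term_sum.prod(axis=0)    # product(sum(term, dims), ordinal), a scalar
local_part = term / term_sum[:, None]  # term / sum(term, dims), dims "ia"

# The local part is normalized over "a" within every plate slice ...
print(local_part.sum(axis=1))
# ... and rescaling it by the per-slice totals recovers the original term.
print(np.allclose(local_part * term_sum[:, None], term))
```

Because `global_part` is the product over `i` of per-slice totals, it equals a brute-force sum over every joint choice of `a` across the three plate slices.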
Class: LinearRing
Standard linear-space ring using the torch backend for einsum.
- sumproduct: Uses `contract(equation, *terms, backend="torch")`.
- product: Applies `.prod(pos)` along each plate dimension.
- inv: Computes `.reciprocal()`, clamped to avoid NaN from inf/inf.
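A NumPy re-creation of the three LinearRing operations, offered as a sketch: the `Packed` tuple is hypothetical (NumPy arrays cannot carry a `_pyro_dims` attribute), `np.einsum` stands in for `contract(..., backend="torch")`, clamping to the dtype's largest finite value is this sketch's choice, and the ring's caching is omitted.

```python
from typing import NamedTuple

import numpy as np


class Packed(NamedTuple):
    """Hypothetical stand-in for a packed tensor."""
    data: np.ndarray
    dims: str  # one character per axis, aligned with data.shape


def sumproduct(terms, dims):
    inputs = [t.dims for t in terms]
    # Output dims: every input dim that is not summed out, in sorted order.
    output = "".join(sorted(set("".join(inputs)) - set(dims)))
    equation = ",".join(inputs) + "->" + output
    return Packed(np.einsum(equation, *(t.data for t in terms)), output)


def product(term, ordinal):
    data, dims = term
    # Contract plate dims one at a time, locating each in the current dims string.
    for dim in sorted(ordinal, reverse=True):
        pos = dims.find(dim)
        if pos != -1:
            data = data.prod(axis=pos)
            dims = dims.replace(dim, "")
    return Packed(data, dims)


def inv(term):
    with np.errstate(divide="ignore"):
        result = np.reciprocal(term.data)
    # Clamp +inf (from 1/0) to the largest finite value so a later inf/inf cannot give NaN.
    return Packed(np.minimum(result, np.finfo(result.dtype).max), term.dims)


x = Packed(np.arange(12.0).reshape(3, 4), "ab")
y = Packed(np.ones((4, 5)), "bc")
print(sumproduct([x, y], "b").dims)
```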
Class: LogRing
Log-space ring using the pyro.ops.einsum.torch_log backend.
- sumproduct: Uses a logsumexp-based einsum.
- product: Applies `.sum(pos)` (a sum in log space is a product in linear space).
- inv: Negation in log space (the reciprocal in linear space), clamped to avoid NaN from inf - inf.
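Rather than reproducing the `torch_log` einsum backend, the sketch below gets the same log-space semantics the simple way: it broadcasts every operand to a shared sorted dim order, adds them (the log-space product), and applies a max-shifted logsumexp over the summed dims. This materializes the full joint, so it only suits small tensors; all function names are hypothetical, and terms are assumed to be `(array, dims_string)` pairs.

```python
import numpy as np


def align(data, dims, all_dims):
    """Transpose `data` into `all_dims` order and add size-1 axes for absent dims."""
    present = "".join(d for d in all_dims if d in dims)
    data = np.einsum(f"{dims}->{present}", data)  # pure axis permutation
    shape = [data.shape[present.index(d)] if d in present else 1 for d in all_dims]
    return data.reshape(shape)


def log_sumproduct(terms, dims):
    """Log-space sumproduct; returns (log_values, output_dims)."""
    all_dims = "".join(sorted(set("".join(d for _, d in terms))))
    # A product in linear space is a sum in log space; broadcasting builds the joint.
    joint = sum(align(data, d, all_dims) for data, d in terms)
    axes = tuple(i for i, d in enumerate(all_dims) if d in dims)
    output = "".join(d for d in all_dims if d not in dims)
    if not axes:
        return joint, output
    # Stable logsumexp: shift by the max before exponentiating.
    shift = joint.max(axis=axes, keepdims=True)
    shift = np.where(np.isfinite(shift), shift, 0.0)  # avoid -inf - -inf = NaN
    with np.errstate(divide="ignore"):
        total = np.log(np.exp(joint - shift).sum(axis=axes))
    return total + shift.squeeze(axis=axes), output


def log_product(data, dims, ordinal):
    """Plate product in log space: sum out each plate dim present in `dims`."""
    for dim in sorted(ordinal, reverse=True):
        pos = dims.find(dim)
        if pos != -1:
            data, dims = data.sum(axis=pos), dims.replace(dim, "")
    return data, dims


def log_inv(data):
    """Reciprocal in log space is negation; clamp +inf so inf - inf cannot give NaN."""
    return np.minimum(-data, np.finfo(data.dtype).max)
```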
Class: MapRing
Extends LogRing for MAP inference with the pyro.ops.einsum.torch_map backend. Attaches backward handlers (_SampleProductBackward) for the argmax backward pass.
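Forward max-sum followed by a backward argmax can be seen on a two-factor chain in plain NumPy. This sketch illustrates the idea only; it does not use the `torch_map` backend or `_SampleProductBackward`, and the function names are hypothetical.

```python
import numpy as np


def max_sum(x, y):
    """Max-sum over the shared dim b of log-factors x[a, b] and y[b, c].

    Returns max_b (x[a, b] + y[b, c]) with dims "ac", plus the maximizing b
    for every (a, c), which the backward pass reads off later."""
    scores = x[:, :, None] + y[None, :, :]  # dims "abc"; log-space product is +
    return scores.max(axis=1), scores.argmax(axis=1)


def map_assignment(x, y):
    """Forward: max out b, then a and c. Backward: recover the argmax of b."""
    value, best_b = max_sum(x, y)
    a, c = np.unravel_index(value.argmax(), value.shape)
    return (int(a), int(best_b[a, c]), int(c)), float(value[a, c])


rng = np.random.default_rng(1)
x, y = rng.normal(size=(3, 4)), rng.normal(size=(4, 5))
print(map_assignment(x, y))
```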
Class: SampleRing
Extends LogRing for sampling with the pyro.ops.einsum.torch_sample backend. Attaches backward handlers that draw samples, rather than taking the argmax, during the backward pass.
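The forward-sum-product, backward-sample pattern amounts to exact ancestral sampling from the contracted messages. A NumPy sketch on a two-factor chain follows; it illustrates the idea rather than the `torch_sample` backend, and `sample_chain` is a hypothetical name.

```python
import numpy as np


def sample_chain(x, y, rng):
    """Exact sample of (a, b, c) from p(a, b, c) proportional to exp(x[a, b] + y[b, c]).

    Forward: log-space sum-product (sum out b, then a and c) gives log Z.
    Backward: sample (a, c) from the forward message, then b given (a, c)."""
    scores = x[:, :, None] + y[None, :, :]                     # dims "abc"
    shift = scores.max()                                       # numerical stability
    msg = np.log(np.exp(scores - shift).sum(axis=1)) + shift   # dims "ac"
    log_z = np.log(np.exp(msg - shift).sum()) + shift
    p_ac = np.exp(msg - log_z).ravel()
    flat = rng.choice(p_ac.size, p=p_ac / p_ac.sum())
    a, c = np.unravel_index(flat, msg.shape)
    p_b = np.exp(scores[a, :, c] - msg[a, c])                  # p(b | a, c)
    b = rng.choice(p_b.size, p=p_b / p_b.sum())
    return (int(a), int(b), int(c)), float(log_z)


rng = np.random.default_rng(2)
x, y = rng.normal(size=(3, 4)), rng.normal(size=(4, 5))
print(sample_chain(x, y, rng))
```

Sampling the outer dims first and the eliminated dim last mirrors how the backward pass revisits contractions in reverse order.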
Class: MarginalRing
Extends LogRing for marginal computation with pyro.ops.einsum.torch_marginal backend. Uses _MarginalProductBackward for inclusion-exclusion-based backward marginalization.
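One way to see why a backward pass can produce marginals: the gradient of log Z with respect to a log-space factor equals that factor's marginal. The NumPy sketch below checks this on a two-factor chain using brute force and finite differences. It does not reproduce `torch_marginal` or the inclusion-exclusion bookkeeping in `_MarginalProductBackward`; the function names are hypothetical.

```python
import numpy as np


def chain_log_z(x, y):
    """log Z for p(a, b, c) proportional to exp(x[a, b] + y[b, c])."""
    scores = x[:, :, None] + y[None, :, :]
    m = scores.max()
    return float(np.log(np.exp(scores - m).sum()) + m)


def factor_marginal(x, y):
    """p(a, b): sum-product out c, then normalize."""
    scores = x[:, :, None] + y[None, :, :]
    unnorm = np.exp(scores - scores.max()).sum(axis=2)  # dims "ab"
    return unnorm / unnorm.sum()


def grad_log_z(x, y, eps=1e-6):
    """Central finite differences of log Z with respect to each entry of x."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        grad[idx] = (chain_log_z(xp, y) - chain_log_z(xm, y)) / (2 * eps)
    return grad


rng = np.random.default_rng(3)
x, y = rng.normal(size=(3, 4)), rng.normal(size=(4, 5))
print(np.abs(grad_log_z(x, y) - factor_marginal(x, y)).max())  # tiny
```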
BACKEND_TO_RING Dictionary
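Maps backend names to ring classes for dispatch, matching the backends listed above: "torch" to LinearRing, "pyro.ops.einsum.torch_log" to LogRing, "pyro.ops.einsum.torch_map" to MapRing, "pyro.ops.einsum.torch_sample" to SampleRing, and "pyro.ops.einsum.torch_marginal" to MarginalRing.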
I/O Contract
| Method | Input | Output |
|---|---|---|
| sumproduct(terms, dims) | List of packed tensors, iterable of dim chars | Packed tensor with dims contracted out |
| product(term, ordinal) | Packed tensor, frozenset of plate dim chars | Packed tensor with plates contracted |
| broadcast(term, ordinal) | Packed tensor, frozenset of dim chars | Packed tensor expanded to include all dims in ordinal |
| inv(term) | Packed tensor | Packed tensor (reciprocal/negation) |
| global_local(term, dims, ordinal) | Packed tensor, dims, ordinal | Tuple of (global_part, local_part) |
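The broadcast contract can be sketched with NumPy's `np.broadcast_to`, which, like an expand, returns a read-only view without copying. The sizes of newly added plate dims have to come from somewhere; here a `sizes` dict is an assumed input, and prepending the missing dims in sorted order is this sketch's convention.

```python
import numpy as np


def broadcast(data, dims, ordinal, sizes):
    """Expand `data` (dims string `dims`) along plate dims in `ordinal` it lacks.

    `sizes` (dim char to length) is an assumed input for this sketch."""
    missing = "".join(sorted(set(ordinal) - set(dims)))
    if not missing:
        return data, dims
    shape = tuple(sizes[d] for d in missing) + data.shape
    # New dims are prepended, so existing axes keep their relative order.
    return np.broadcast_to(data, shape), missing + dims


t, d = broadcast(np.arange(4.0), "a", frozenset("ij"), {"i": 2, "j": 3})
print(d, t.shape)
```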
Usage Examples
```python
import torch
from pyro.ops.rings import LogRing, BACKEND_TO_RING

# Create a ring with caching
cache = {}
ring = LogRing(cache=cache)

# Packed tensors carry a ._pyro_dims string, one character per tensor dim
x = torch.randn(3, 4)
x._pyro_dims = "ab"
y = torch.randn(4, 5)
y._pyro_dims = "bc"

# Sum-product contraction in log space, i.e. logsumexp over "b" of x + y
# (equivalent to torch.logsumexp(x.unsqueeze(-1) + y, dim=1))
result = ring.sumproduct([x, y], dims="b")
print(result._pyro_dims)  # "ac"

# Look up a ring class by backend name
RingClass = BACKEND_TO_RING["pyro.ops.einsum.torch_log"]
assert RingClass is LogRing
```
Related Pages
- Pyro_ppl_Pyro_PackedTensor -- Pack/unpack utilities for dimension-named tensors
- Pyro_ppl_Pyro_Gaussian -- Gaussian operations that use similar contraction patterns