

Implementation:Pyro ppl Pyro Rings

From Leeroopedia


Property      Value
Module        pyro.ops.rings
Source        pyro/ops/rings.py
Lines         338
Classes       Ring, LinearRing, LogRing, MapRing, SampleRing, MarginalRing
Dependencies  torch, pyro.ops.einsum, pyro.util

Overview

This module implements a hierarchy of tensor ring classes used for sum-product computations in Pyro's variable elimination and belief propagation algorithms. Each ring defines the algebraic operations (sum-product, product, inverse, broadcast) on "packed" tensors. A packed tensor carries a ._pyro_dims attribute: a string holding one single-character dimension name per tensor dimension.
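To make the packed-tensor convention concrete, here is a minimal sketch. It assumes only what the paragraph above states: the attribute is named ._pyro_dims and holds one character per dimension. The helper `pack` and its validation rules (matching length, unique names) are hypothetical and not part of Pyro.

```python
import torch


def pack(tensor: torch.Tensor, dims: str) -> torch.Tensor:
    """Hypothetical helper: attach a ._pyro_dims string to a tensor.

    Each character names one tensor dimension, in order, so the string
    length must equal tensor.dim() and no name may repeat.
    """
    if len(dims) != tensor.dim():
        raise ValueError(f"expected {tensor.dim()} dim names, got {dims!r}")
    if len(set(dims)) != len(dims):
        raise ValueError(f"dim names must be unique, got {dims!r}")
    tensor._pyro_dims = dims  # torch tensors accept extra Python attributes
    return tensor


# A 3 x 4 factor over variables "a" and "b"
x = pack(torch.randn(3, 4), "ab")
print(x._pyro_dims, tuple(x.shape))  # ab (3, 4)
```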

The ring hierarchy supports different semiring semantics:

  • LinearRing: Standard sum-product in linear space.
  • LogRing: Sum-product in log space (logsumexp for sum, sum for product).
  • MapRing: Forward max-sum, backward argmax (MAP inference).
  • SampleRing: Forward sum-product, backward sampling.
  • MarginalRing: Forward sum-product, backward marginalization using inclusion-exclusion.
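The semiring semantics listed above can be compared directly in plain torch, without using Pyro. The sketch eliminates a shared dimension b from two log-space factors in three ways:

- a linear-space sum-product, as in LinearRing;
- a logsumexp-based sum-product, as in LogRing (SampleRing and MarginalRing share this forward pass);
- a max-sum, as in MapRing's forward pass.

The factor shapes are arbitrary illustrative choices.

```python
import torch

torch.manual_seed(0)
# Two log-space factors: x over dims "ab" (3 x 4), y over dims "bc" (4 x 5).
x = torch.randn(3, 4)
y = torch.randn(4, 5)

# LinearRing semantics: ordinary sum-product on exp'd (linear-space) values.
linear = torch.einsum("ab,bc->ac", x.exp(), y.exp())

# LogRing semantics: "multiply" is +, "sum" is logsumexp over the eliminated dim b.
log_space = torch.logsumexp(x[:, :, None] + y[None, :, :], dim=1)

# MapRing forward semantics: "sum" is replaced by max (max-sum).
max_sum, argmax_b = (x[:, :, None] + y[None, :, :]).max(dim=1)

# The log-space result is the log of the linear-space result.
print(torch.allclose(log_space, linear.log(), atol=1e-5))  # True
```

The max-sum result never exceeds the logsumexp result, because a max over terms is bounded by their log-sum-exp.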

The abstract Ring base class defines this interface and accepts an optional cache dict in which rings memoize intermediate results. This makes the rings compatible with opt_einsum.shared_intermediates.
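The caching idea can be sketched as a memo dict keyed by an operation name plus the id() of the input tensor, shared between ring instances. `CachingRingSketch` is a hypothetical stand-in, not Pyro's Ring. The key layout and the choice to keep the input tensor alive in the cache are assumptions made for this illustration.

```python
import torch


class CachingRingSketch:
    """Hypothetical stand-in showing how a ring can memoize results.

    Keys combine an operation name with id() of the input tensor. Because
    id() values can be reused after an object is freed, the input is
    stored in the cache alongside the result to keep its id valid.
    """

    def __init__(self, cache=None):
        # Callers may pass a shared dict so several rings reuse results.
        self._cache = {} if cache is None else cache

    def inv(self, term):
        key = ("inv", id(term))
        if key in self._cache:
            return self._cache[key][1]
        result = term.reciprocal()
        self._cache[key] = (term, result)
        return result


shared = {}
t = torch.tensor([1.0, 2.0, 4.0])
r1, r2 = CachingRingSketch(shared), CachingRingSketch(shared)
print(r1.inv(t) is r2.inv(t))  # True: the second call hits the shared cache
```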

Code Reference

Class: Ring (Abstract Base)

Defines the ring interface with a cache for memoizing intermediate results.

  • sumproduct(terms, dims): Abstract. Multiply terms, then sum-contract out dims.
  • product(term, ordinal): Abstract. Product-contract along plate dimensions in ordinal.
  • broadcast(term, ordinal): Implemented in the base class. Expands term along any plate dimensions in ordinal that term does not already have.
  • inv(term): Abstract. Compute reciprocal of a term for inclusion-exclusion.
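  • global_local(term, dims, ordinal): Computes the global and local terms for tensor message passing using inclusion-exclusion. The local part is term / sum(term, dims); the global part is product(sum(term, dims), ordinal). Together they form term / sum(term, dims) * product(sum(term, dims), ordinal).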
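The global/local split above can be illustrated in linear space with plain torch. The sketch assumes a term with packed dims "ia", where "i" is a plate in the ordinal (size 3) and "a" is the summed dim (size 4). Sizes and values are illustrative only; Pyro's global_local works through a ring's sumproduct, product, and inv rather than these direct tensor calls.

```python
import torch

torch.manual_seed(0)
# Hypothetical positive linear-space term with packed dims "ia".
term = torch.rand(3, 4) + 0.1

# sum(term, dims) with dims = {"a"}: contract out the summed dim.
term_sum = term.sum(dim=1)  # dims "i"

# Global part: product(sum(term, dims), ordinal) with ordinal = {"i"}.
global_part = term_sum.prod(dim=0)  # dims "" (a scalar)

# Local part: term / sum(term, dims), broadcast back over "a".
local_part = term / term_sum[:, None]  # dims "ia"

# Each plate slice of the local part is normalized over "a", and
# multiplying it back by the per-plate sums recovers the original term.
print(torch.allclose(local_part.sum(dim=1), torch.ones(3)))  # True
print(torch.allclose(local_part * term_sum[:, None], term))  # True
```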

Class: LinearRing

Standard linear-space ring using torch backend for einsum.

  • sumproduct: Uses contract(equation, *terms, backend="torch").
  • product: Applies .prod(pos) along plate dimensions.
  • inv: Computes .reciprocal(), clamped to the largest finite value of the dtype to avoid NaN from inf/inf.
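The three LinearRing operations above can be mirrored in a simplified class. `ToyLinearRing` is hypothetical. It omits caching and calls torch.einsum directly instead of Pyro's contract wrapper. The sumproduct output keeps the remaining dims in sorted order, which this sketch assumes matches the real ring's convention.

```python
import torch


class ToyLinearRing:
    """Hypothetical, simplified mirror of LinearRing (no caching)."""

    def sumproduct(self, terms, dims):
        inputs = [t._pyro_dims for t in terms]
        # Keep every input dim that is not summed out, in sorted order.
        output = "".join(sorted(set("".join(inputs)) - set(dims)))
        equation = ",".join(inputs) + "->" + output
        result = torch.einsum(equation, *terms)
        result._pyro_dims = output
        return result

    def product(self, term, ordinal):
        dims = term._pyro_dims
        for dim in sorted(ordinal, reverse=True):
            pos = dims.find(dim)
            if pos != -1:  # only contract plates the term actually has
                term = term.prod(pos)
                dims = dims.replace(dim, "")
        term._pyro_dims = dims
        return term

    def inv(self, term):
        result = term.reciprocal()
        # 1/0 = inf is clamped to the largest finite value of the dtype.
        result = result.clamp(max=torch.finfo(result.dtype).max)
        result._pyro_dims = term._pyro_dims
        return result


ring = ToyLinearRing()
x = torch.ones(3, 4)
x._pyro_dims = "ab"
y = torch.ones(4, 5)
y._pyro_dims = "bc"
z = ring.sumproduct([x, y], dims="b")  # dims "ac", every entry 4.0
p = ring.product(z, frozenset("a"))    # dims "c", every entry 4.0 ** 3
print(z._pyro_dims, p._pyro_dims)      # ac c
```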

Class: LogRing

Log-space ring using pyro.ops.einsum.torch_log backend.

  • sumproduct: Uses logsumexp-based einsum.
  • product: Applies .sum(pos) (sum in log space = product in linear space).
  • inv: Negation in log space (the reciprocal in linear space), clamped to the largest finite value of the dtype.
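One way to get a numerically stable log-space sum-product like LogRing's is to shift each operand by its maximum, run an ordinary einsum, then take the log and add the shifts back. The function `log_einsum` below is a hypothetical sketch of that trick. It uses a simple global-max shift, which may differ from how pyro.ops.einsum.torch_log stabilizes. In the same log-space convention, product is a .sum() over the plate dim and inv is negation.

```python
import torch


def log_einsum(equation, *operands):
    """Hypothetical log-space einsum: 'sum' is logsumexp, 'product' is +.

    Each operand is shifted by its global max so exp() cannot overflow;
    the shifts are added back after the log.
    """
    shifts = [op.detach().max() for op in operands]
    exps = [(op - s).exp() for op, s in zip(operands, shifts)]
    return torch.einsum(equation, *exps).log() + sum(shifts)


torch.manual_seed(0)
x = torch.randn(3, 4) + 100.0  # exp(100) overflows float32
y = torch.randn(4, 5) + 100.0

z = log_einsum("ab,bc->ac", x, y)
naive = torch.einsum("ab,bc->ac", x.exp(), y.exp()).log()  # all inf

# Reference: explicit logsumexp over the eliminated dim b.
ref = torch.logsumexp(x[:, :, None] + y[None, :, :], dim=1)
print(torch.allclose(z, ref, atol=1e-3))  # True
```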

Class: MapRing

Extends LogRing for MAP inference using the pyro.ops.einsum.torch_map backend. Its product attaches a _SampleProductBackward handler to the result, so the backward pass recovers argmax assignments instead of random samples.
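The forward max-sum / backward argmax pattern can be worked by hand on a tiny chain a - b - c with two log-space factors. This is a plain Viterbi-style sketch in torch. It does not use Pyro's torch_map backend or _SampleProductBackward, and the factor sizes are illustrative.

```python
import itertools

import torch

torch.manual_seed(0)
x = torch.randn(3, 4)  # log factor over dims "ab"
y = torch.randn(4, 5)  # log factor over dims "bc"

# Forward (max-sum): eliminate c, then b, then a, recording argmaxes.
msg_c, best_c = y.max(dim=1)            # max over c, indexed by b
msg_b, best_b = (x + msg_c).max(dim=1)  # max over b, indexed by a
map_value, a_star = msg_b.max(dim=0)    # max over a

# Backward (argmax): follow the recorded argmaxes back down the chain.
a = a_star.item()
b = best_b[a].item()
c = best_c[b].item()

# Brute force over all 3 * 4 * 5 assignments for comparison.
brute = max(itertools.product(range(3), range(4), range(5)),
            key=lambda abc: (x[abc[0], abc[1]] + y[abc[1], abc[2]]).item())
print((a, b, c) == brute)  # True
```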

Class: SampleRing

Extends LogRing for sampling using the pyro.ops.einsum.torch_sample backend. Like MapRing, its product attaches a _SampleProductBackward handler. Here the backward pass draws random samples instead of taking the argmax.
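Forward sum-product followed by backward sampling can be sketched on the same kind of tiny chain a - b - c. This is forward-filtering backward-sampling in plain torch, not Pyro's torch_sample backend. The logsumexp messages stored on the forward pass define the conditionals sampled on the way back.

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 4)  # log factor over dims "ab"
y = torch.randn(4, 5)  # log factor over dims "bc"

# Forward (sum-product in log space): eliminate c, then b, then a.
msg_c = torch.logsumexp(y, dim=1)          # indexed by b
msg_b = torch.logsumexp(x + msg_c, dim=1)  # indexed by a
log_z = torch.logsumexp(msg_b, dim=0)      # log partition function

# Backward (sampling): draw a, then b | a, then c | b.
a = torch.distributions.Categorical(logits=msg_b).sample()
b = torch.distributions.Categorical(logits=x[a] + msg_c).sample()
c = torch.distributions.Categorical(logits=y[b]).sample()

# Sanity check: p(a) implied by the forward messages matches brute force.
brute_log_pa = (x[:, :, None] + y[None, :, :]).flatten(1).logsumexp(1)
print(torch.allclose(msg_b.softmax(0), brute_log_pa.softmax(0), atol=1e-5))
```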

Class: MarginalRing

Extends LogRing for marginal computation with pyro.ops.einsum.torch_marginal backend. Uses _MarginalProductBackward for inclusion-exclusion-based backward marginalization.
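The quantity MarginalRing's backward pass targets, the marginal of each factor's variables, can be cross-checked with a different, swapped-in technique: the gradient of log Z with respect to a log-space factor equals that factor's marginal distribution. The sketch below uses torch autograd for this. It is not the inclusion-exclusion _MarginalProductBackward approach the section describes.

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 4, requires_grad=True)  # log factor over dims "ab"
y = torch.randn(4, 5, requires_grad=True)  # log factor over dims "bc"

# Forward sum-product in log space: the log partition function log Z.
log_z = torch.logsumexp((x[:, :, None] + y[None, :, :]).flatten(), dim=0)

# d log Z / d x[a, b] is the marginal p(a, b); likewise y gives p(b, c).
p_ab, p_bc = torch.autograd.grad(log_z, [x, y])

# Brute-force marginals from the normalized joint, for comparison.
joint = (x[:, :, None] + y[None, :, :] - log_z).exp().detach()
print(torch.allclose(p_ab, joint.sum(dim=2), atol=1e-5))  # True
print(torch.allclose(p_bc, joint.sum(dim=0), atol=1e-5))  # True
```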

BACKEND_TO_RING Dictionary

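Maps backend string names to ring classes for dispatch:

  • "torch" to LinearRing
  • "pyro.ops.einsum.torch_log" to LogRing
  • "pyro.ops.einsum.torch_map" to MapRing
  • "pyro.ops.einsum.torch_sample" to SampleRing
  • "pyro.ops.einsum.torch_marginal" to MarginalRing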
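Dispatch through a backend-to-class dictionary can be sketched without Pyro. The stub classes, `BACKEND_TO_RING_SKETCH`, and `make_ring` are all hypothetical stand-ins; only the backend strings come from this page. The sketch shows the pattern of looking up a class by backend name, passing a shared cache, and failing clearly on an unknown backend.

```python
class LinearRingStub:
    """Hypothetical stand-in for a ring class; stores an optional cache."""

    def __init__(self, cache=None):
        self._cache = {} if cache is None else cache


class LogRingStub(LinearRingStub):
    """Hypothetical stand-in for a log-space ring class."""


BACKEND_TO_RING_SKETCH = {
    "torch": LinearRingStub,
    "pyro.ops.einsum.torch_log": LogRingStub,
}


def make_ring(backend, cache=None):
    """Hypothetical helper: build the ring registered for `backend`."""
    try:
        ring_class = BACKEND_TO_RING_SKETCH[backend]
    except KeyError:
        raise ValueError(f"no ring registered for backend {backend!r}") from None
    return ring_class(cache)


shared_cache = {}
ring = make_ring("pyro.ops.einsum.torch_log", shared_cache)
print(type(ring).__name__, ring._cache is shared_cache)  # LogRingStub True
```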

I/O Contract

  • sumproduct(terms, dims): takes a list of packed tensors and an iterable of dim characters; returns a packed tensor with those dims contracted out (remaining dims in sorted order).
  • product(term, ordinal): takes a packed tensor and a frozenset of plate dim characters; returns a packed tensor with those plates product-contracted out.
  • broadcast(term, ordinal): takes a packed tensor and a frozenset of dim characters; returns a packed tensor expanded to include every dim in ordinal.
  • inv(term): takes a packed tensor; returns a packed tensor holding its reciprocal (LinearRing) or its negation (log-space rings).
  • global_local(term, dims, ordinal): takes a packed tensor, the dims to sum out, and an ordinal; returns a tuple (global_part, local_part).
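The broadcast contract above (expand a term so it carries every dim in ordinal) can be sketched as a standalone function. `broadcast_sketch` is hypothetical. The `dim_to_size` mapping from dim characters to sizes is an assumption of this sketch. So is the choice to prepend the missing dims in sorted order.

```python
import torch


def broadcast_sketch(term, ordinal, dim_to_size):
    """Hypothetical broadcast: expand `term` along every dim in `ordinal`
    that it lacks, prepending the missing dims in sorted order."""
    dims = term._pyro_dims
    missing = "".join(sorted(set(ordinal) - set(dims)))
    if not missing:
        return term
    shape = tuple(dim_to_size[d] for d in missing)
    result = term.expand(shape + term.shape)  # a view; no data is copied
    result._pyro_dims = missing + dims
    return result


x = torch.arange(4.0)
x._pyro_dims = "b"
y = broadcast_sketch(x, frozenset("ib"), {"i": 3, "b": 4})
print(y._pyro_dims, tuple(y.shape))  # ib (3, 4)
```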

Usage Examples

import torch
from pyro.ops.rings import LogRing, BACKEND_TO_RING

# Create a ring with caching
cache = {}
ring = LogRing(cache=cache)

# Packed tensors have ._pyro_dims attributes
x = torch.randn(3, 4)
x._pyro_dims = "ab"

y = torch.randn(4, 5)
y._pyro_dims = "bc"

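# Sum-product contraction: LogRing treats values as log-weights, so "b" is
# eliminated with logsumexp; the output keeps the remaining dims, sorted
result = ring.sumproduct([x, y], dims="b")
print(result._pyro_dims)  # "ac"
print(result.shape)  # torch.Size([3, 5])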

# Look up a ring class by backend name
RingClass = BACKEND_TO_RING["pyro.ops.einsum.torch_log"]
assert RingClass is LogRing
