Implementation:Pyro ppl Pyro Einsum Example

From Leeroopedia


Property             Value
Implementation Type  Pattern Doc
Source File          examples/einsum.py
Module               pyro.ops.contract, pyro.ops.einsum
Pyro Features        pyro.ops.contract.einsum, pyro.ops.einsum.adjoint, JIT compilation, various einsum backends (torch_log, torch_map, torch_sample, torch_marginal)
Pattern              Profiling plated einsum operations for undirected graphical models

Overview

This file demonstrates how to use Pyro's plated einsum with different backends to compute various quantities from undirected graphical models. The plated einsum extends standard einsum notation with "plate" dimensions that represent independent replication.
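As a rough illustration of what a plate dimension means, the sketch below evaluates the small equation "a,abi->" with plate "i" by brute force in plain Python. It assumes the usual sum-product reading, in which ordinary dimensions are summed out and plate dimensions are multiplied across their independent replicas. The helper name plated_sum_product and the factor values are hypothetical, the factors are in linear (not log) space, and this is not how pyro.ops.contract.einsum is implemented.

```python
from math import prod


def plated_sum_product(f, g):
    """Brute-force value of the plated equation "a,abi->" with plate "i".

    f[a] is an unplated factor; g[a][b][i] is a factor replicated over plate i.
    The variable b appears only inside the plate, so it is summed out
    separately within each replica.
    """
    total = 0.0
    for a, f_a in enumerate(f):
        # Inside the plate: sum out b for each replica i ...
        per_replica = [
            sum(g[a][b][i] for b in range(len(g[a])))
            for i in range(len(g[a][0]))
        ]
        # ... then multiply across replicas, which are independent given a.
        total += f_a * prod(per_replica)
    return total


f = [0.5, 0.5]
g = [
    [[1.0, 2.0], [3.0, 4.0]],  # a = 0, indexed as g[0][b][i]
    [[1.0, 1.0], [1.0, 1.0]],  # a = 1
]
print(plated_sum_product(f, g))  # 0.5 * (4 * 6) + 0.5 * (2 * 2) = 14.0
```

The product over i is what distinguishes a plate from an ordinary summed dimension: treating i as a regular einsum index would add the replicas instead of multiplying them.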

Six computational modes are demonstrated:

  • prob: Partition function computation using standard einsum
  • logprob: Log partition function using the torch_log backend (more numerically stable)
  • gradient: Gradient of the log partition function (for training)
  • map: MAP estimates via the adjoint algorithm with torch_map backend
  • sample: Posterior samples via the adjoint algorithm with torch_sample backend
  • marginal: Marginal probabilities via the adjoint algorithm with torch_marginal backend
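The stability claim for the logprob mode can be seen without any tensor library. The sketch below is a plain-Python stand-in, not the torch_log backend itself: it compares summing exponentiated weights directly against a log-sum-exp that shifts by the maximum. The function name logsumexp and the sample log-weights are illustrative choices.

```python
import math


def logsumexp(values):
    """Stable log(sum(exp(v) for v in values)), computed by shifting by the max."""
    m = max(values)
    if m == float("-inf"):
        return m
    return m + math.log(sum(math.exp(v - m) for v in values))


# Log-weights of two configurations, far below the range of float64 in linear space.
log_weights = [-1000.0, -1001.0]

naive = sum(math.exp(w) for w in log_weights)  # each term underflows to 0.0
stable = logsumexp(log_weights)                # finite log partition value

print(naive)   # 0.0, so taking log(naive) would fail
print(stable)  # about -999.687
```

Working in log space throughout keeps the partition function representable even when every individual configuration has negligible weight.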

The adjoint algorithm interface requires four steps: (1) call require_backward() on each input operand, (2) run einsum with a nonstandard backend (torch_map, torch_sample, or torch_marginal), (3) call _pyro_backward() on each output, and (4) read the result from the ._pyro_backward_result attribute of each input operand.
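The forward-then-backward structure of these steps resembles classic dynamic programming. As a loose analogue only, the sketch below finds a MAP assignment for the two-factor model f[a] * g[a][b] in plain Python: a forward pass eliminates variables by maximization while recording which value won, and a backward pass follows the recorded pointers. The function map_chain and its inputs are hypothetical, it assumes non-negative factors, and it does not use or reproduce Pyro's adjoint API.

```python
def map_chain(f, g):
    """Return (a, b, score) maximizing f[a] * g[a][b], assuming non-negative factors."""
    # Forward pass, step 1: eliminate b for each a, remembering the winning b.
    best_b = [max(range(len(row)), key=row.__getitem__) for row in g]
    message = [g[a][best_b[a]] for a in range(len(f))]

    # Forward pass, step 2: eliminate a using the message from b.
    scores = [f[a] * message[a] for a in range(len(f))]
    a_star = max(range(len(f)), key=scores.__getitem__)

    # Backward pass: follow the stored pointer to recover the matching b.
    return a_star, best_b[a_star], scores[a_star]


f = [0.6, 0.4]
g = [
    [0.5, 0.5],  # a = 0
    [0.1, 0.9],  # a = 1
]
a, b, score = map_chain(f, g)
print(a, b)  # 1 1
```

Note that the greedy choice a = 0 (the larger f value) is not the MAP assignment here, which is why the backward pass has to start from the jointly best score rather than from each factor in isolation.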

All operations are JIT-compiled with torch.jit.trace and cached in _CACHE, so repeated profiling iterations reuse the traced function.

Code Reference

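import torch
from pyro.ops.contract import einsum
from pyro.ops.einsum.adjoint import require_backward

_CACHE = {}  # maps (method, equation, plates) keys to traced functions

def jit_logprob(equation, *operands, **kwargs):
    """Runs einsum to compute the log partition function."""
    key = "logprob", equation, kwargs["plates"]
    if key not in _CACHE:
        # Wrap einsum so that torch.jit.trace sees a function of tensors only.
        def _einsum(*operands):
            return einsum(equation, *operands,
                          backend="pyro.ops.einsum.torch_log", **kwargs)
        _CACHE[key] = torch.jit.trace(_einsum, operands, check_trace=False)
    return _CACHE[key](*operands)

def _jit_adjoint(equation, *operands, **kwargs):
    """Runs einsum in forward-backward mode using pyro.ops.einsum.adjoint."""
    # Assumption: this excerpt omits where `backend` is set; here it is
    # taken from kwargs, and the caching below mirrors jit_logprob.
    backend = kwargs.pop("backend")
    key = backend, equation, kwargs["plates"]
    if key not in _CACHE:
        def _forward_backward(*operands):
            # Step 1: request backward results on every input operand.
            for operand in operands:
                require_backward(operand)
            # Step 2: forward pass with the nonstandard backend.
            results = einsum(equation, *operands, backend=backend, **kwargs)
            # Step 3: backward pass from each output.
            for result in results:
                result._pyro_backward()
            # Step 4: read the results stored on the inputs.
            return tuple(x._pyro_backward_result for x in operands)
        _CACHE[key] = torch.jit.trace(_forward_backward, operands, check_trace=False)
    return _CACHE[key](*operands)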

I/O Contract

Parameter               Type  Description
-e / --equation         str   Einsum equation (default: "a,abi,bcij,adj,deij->")
-p / --plates           str   Plate dimensions (default: "ij")
-d / --dim-size         int   Non-plate dimension size (default: 32)
-s / --max-plate-size   int   Maximum plate dimension size (default: 32)
-m / --method           str   Method: prob, logprob, gradient, marginal, map, sample, or all
-n / --iters            int   Profiling iterations (default: 10)
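The parameters above map naturally onto a standard argparse parser. The following is a minimal sketch reconstructed from the table, not the script's actual parser: the function name build_parser is hypothetical, the default for --method is a guess because the table does not state one, and --cuda is taken from the usage examples and modeled here as a boolean flag.

```python
import argparse


def build_parser():
    """Build a parser matching the documented flags (a sketch, not the original)."""
    parser = argparse.ArgumentParser(description="Plated einsum profiler (sketch)")
    parser.add_argument("-e", "--equation", default="a,abi,bcij,adj,deij->")
    parser.add_argument("-p", "--plates", default="ij")
    parser.add_argument("-d", "--dim-size", default=32, type=int)
    parser.add_argument("-s", "--max-plate-size", default=32, type=int)
    # Assumption: no default is documented for --method; "all" is a guess.
    parser.add_argument(
        "-m", "--method", default="all",
        choices=["prob", "logprob", "gradient", "marginal", "map", "sample", "all"],
    )
    parser.add_argument("-n", "--iters", default=10, type=int)
    # Assumption: --cuda is taken from the usage examples; modeled as a flag.
    parser.add_argument("--cuda", action="store_true")
    return parser


# Parse an explicit argument list so the sketch runs without a command line.
args = build_parser().parse_args(
    ["-e", "ab,bc,cd->", "-p", "", "-d", "64", "-s", "128", "-m", "logprob"]
)
print(args.equation, repr(args.plates), args.dim_size, args.max_plate_size, args.method)
```

argparse converts the hyphenated long options to underscored attribute names, so --dim-size is read as args.dim_size and --max-plate-size as args.max_plate_size.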

Output:

  • Time per iteration (in ms) for each plate size
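A per-iteration timing of this kind can be produced with a small helper. The sketch below is one plausible way to do it, not the script's own timing code: time_per_iteration_ms and workload are hypothetical names, the workload is a pure-Python stand-in for an einsum call, and the plate sizes 8, 16, and 32 are illustrative.

```python
import time


def time_per_iteration_ms(fn, iters=10):
    """Average wall-clock time of fn() in milliseconds over `iters` calls."""
    fn()  # warm-up call so one-time costs (such as JIT tracing) are not timed
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e3


def workload(plate_size):
    # Stand-in for a plated einsum call; its cost grows with the plate size.
    return sum(i * i for i in range(plate_size * 1000))


for plate_size in (8, 16, 32):
    ms = time_per_iteration_ms(lambda: workload(plate_size), iters=10)
    print(f"plate_size={plate_size:>3}  {ms:.3f} ms/iter")
```

When timing GPU work, CUDA kernels launch asynchronously, so a measurement like this is only meaningful if the device is synchronized (for example with torch.cuda.synchronize()) before reading the clock.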

Usage Examples

# Profile all methods on the default equation
python einsum.py -m all

# Profile a specific equation with larger dimension and plate sizes
python einsum.py -e "ab,bc,cd->" -p "" -d 64 -s 128 -m logprob

# Profile on GPU
python einsum.py --cuda -m gradient
