Implementation:Pyro ppl Pyro Einsum Example
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/einsum.py |
| Module | pyro.ops.contract, pyro.ops.einsum |
| Pyro Features | pyro.ops.contract.einsum, pyro.ops.einsum.adjoint, JIT compilation, einsum backends (torch_log, torch_map, torch_sample, torch_marginal) |
| Pattern | Profiling plated einsum operations for undirected graphical models |
Overview
This example demonstrates how to use Pyro's plated einsum with different backends to compute various quantities from undirected graphical models. The plated einsum extends standard einsum notation with "plate" dimensions that represent conditionally independent replication: an ordinary dimension is summed out, while a plate dimension is multiplied out.
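To make the sum-versus-product distinction concrete, here is a brute-force, pure-Python sketch for the small equation `a,abi->` with plate `i`. The variable `a` is global and summed once. The variable `b` appears only with `i`, so it is local to the plate and summed inside each plate slice. Then the plate itself is multiplied out. This is only an intuition aid, not how `pyro.ops.contract.einsum` is implemented. The factor names `f` and `g` and both helper functions are hypothetical.

```python
import itertools
import math


def plated_contract(f, g):
    """Z = sum_a f[a] * prod_i sum_b g[a][b][i]  for equation 'a,abi->' with plates 'i'."""
    num_b, num_i = len(g[0]), len(g[0][0])
    total = 0.0
    for a, fa in enumerate(f):
        # b is local to plate i: sum it out separately within each plate slice ...
        per_slice = [sum(g[a][b][i] for b in range(num_b)) for i in range(num_i)]
        # ... then multiply (not sum) over the plate dimension i.
        total += fa * math.prod(per_slice)
    return total


def brute_force(f, g):
    """Enumerate every joint assignment: each plate slice i gets its own copy b_i of b."""
    num_b, num_i = len(g[0]), len(g[0][0])
    total = 0.0
    for a, fa in enumerate(f):
        for bs in itertools.product(range(num_b), repeat=num_i):
            total += fa * math.prod(g[a][b][i] for i, b in enumerate(bs))
    return total


f = [0.3, 0.7]                                # f[a], |a| = 2
g = [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],      # g[a][b][i], |b| = 2, plate size 3
     [[0.9, 0.8, 0.7], [0.6, 0.5, 0.4]]]

print(plated_contract(f, g), brute_force(f, g))  # both approximately 1.596
```

The two functions agree because the plate slices are independent given `a`. This lets the per-slice sums be computed separately instead of enumerating all |b|^(plate size) joint assignments.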
Six computational modes are demonstrated:
- prob: partition function computed with standard einsum
- logprob: log partition function computed with the `torch_log` backend (more numerically stable)
- gradient: gradient of the log partition function (for training)
- map: MAP estimates via the adjoint algorithm with the `torch_map` backend
- sample: posterior samples via the adjoint algorithm with the `torch_sample` backend
- marginal: marginal probabilities via the adjoint algorithm with the `torch_marginal` backend
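The gap between the prob and logprob modes, and the link between the gradient and marginal modes, can be shown without Pyro. The following plain-Python sketch uses the equation `a,ai->` with plate `i` and a large, hypothetical plate size. The probability-space product underflows to zero, but the log-space version stays finite. A finite-difference check then illustrates a standard identity: the gradient of log Z with respect to a log-factor equals the corresponding marginal. This is a hand-rolled illustration, not Pyro's `torch_log` backend.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def prob_partition(f, g):
    """Equation 'a,ai->' with plate 'i', evaluated directly in probability space."""
    return sum(fa * math.prod(g_a) for fa, g_a in zip(f, g))


def log_partition(log_f, log_g):
    """Same contraction in log space: the plate product becomes a sum of logs."""
    return logsumexp([lfa + sum(lg_a) for lfa, lg_a in zip(log_f, log_g)])


plate_size = 2000                                  # hypothetical, chosen to force underflow
f = [0.4, 0.6]                                     # f[a]
g = [[0.01] * plate_size, [0.01001] * plate_size]  # g[a][i]
log_f = [math.log(x) for x in f]
log_g = [[math.log(x) for x in row] for row in g]

z = prob_partition(f, g)             # underflows: 0.01 ** 2000 is below the float64 range
log_z = log_partition(log_f, log_g)  # finite

# d(log Z) / d(log f[a]) should equal the marginal p(a); check with a forward difference.
marginal = [math.exp(lfa + sum(lg_a) - log_z) for lfa, lg_a in zip(log_f, log_g)]
eps = 1e-6
grad = []
for a in range(len(log_f)):
    bumped = list(log_f)
    bumped[a] += eps
    grad.append((log_partition(bumped, log_g) - log_z) / eps)

print(z, log_z)
print(marginal, grad)
```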
The adjoint algorithm interface requires four steps: (1) call require_backward() on inputs, (2) run einsum with a nonstandard backend, (3) call _pyro_backward() on outputs, (4) retrieve results from ._pyro_backward_result.
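The four-step forward/backward structure can be illustrated with a hand-rolled MAP computation for a tiny chain `ab,bc->`, kept in log space. The forward pass is a max-sum contraction that records back-pointers, which corresponds loosely to requesting backward results and running the nonstandard forward einsum. The backward pass starts from the scalar output and follows those pointers back to an assignment for every input variable. This is a sketch of the idea only, not Pyro's `torch_map` backend, and `map_chain` is a hypothetical helper.

```python
def map_chain(log_f, log_g):
    """MAP assignment for the chain 'ab,bc->' with log-factors log_f[a][b] and log_g[b][c]."""
    num_a, num_b, num_c = len(log_f), len(log_g), len(log_g[0])

    # Forward: eliminate c with max, remembering the best c for each b (back-pointers).
    best_c = [max(range(num_c), key=lambda c: log_g[b][c]) for b in range(num_b)]
    msg_b = [log_g[b][best_c[b]] for b in range(num_b)]

    # Forward: eliminate b, remembering the best b for each a.
    best_b = [max(range(num_b), key=lambda b: log_f[a][b] + msg_b[b]) for a in range(num_a)]
    msg_a = [log_f[a][best_b[a]] + msg_b[best_b[a]] for a in range(num_a)]

    # Forward: eliminate a, giving the scalar maximum log-score (the einsum output).
    a_star = max(range(num_a), key=lambda a: msg_a[a])
    max_log_score = msg_a[a_star]

    # Backward: walk the back-pointers from the output to every input variable.
    b_star = best_b[a_star]
    c_star = best_c[b_star]
    return max_log_score, (a_star, b_star, c_star)


log_f = [[0.0, -1.0], [-0.5, -2.0]]  # log f[a][b]
log_g = [[-3.0, -0.1], [0.0, -1.0]]  # log g[b][c]
print(map_chain(log_f, log_g))  # (-0.1, (0, 0, 1))
```

Swapping `max` for sampling from the normalized local scores in the backward pass would give the analogous shape for the sample mode.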
All operations are JIT-compiled with `torch.jit.trace` before profiling, and each traced function is cached so that it is traced only once.
Code Reference
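```python
import torch

from pyro.ops.contract import einsum
from pyro.ops.einsum.adjoint import require_backward

_CACHE = {}  # traced functions, keyed by method/backend, equation and plates


def jit_logprob(equation, *operands, **kwargs):
    """Runs einsum to compute the log partition function."""
    key = "logprob", equation, kwargs["plates"]
    if key not in _CACHE:
        def _einsum(*operands):
            return einsum(equation, *operands,
                          backend="pyro.ops.einsum.torch_log", **kwargs)
        _CACHE[key] = torch.jit.trace(_einsum, operands, check_trace=False)
    return _CACHE[key](*operands)


def _jit_adjoint(equation, *operands, **kwargs):
    """Runs einsum in forward-backward mode using pyro.ops.einsum.adjoint."""
    backend = kwargs.pop("backend")  # e.g. "pyro.ops.einsum.torch_map"
    key = backend, equation, kwargs["plates"]
    if key not in _CACHE:
        def _forward_backward(*operands):
            for operand in operands:
                require_backward(operand)  # (1) request backward results on inputs
            results = einsum(equation, *operands, backend=backend, **kwargs)  # (2) forward
            for result in results:
                result._pyro_backward()  # (3) backward pass from each output
            return tuple(x._pyro_backward_result for x in operands)  # (4) collect results
        _CACHE[key] = torch.jit.trace(_forward_backward, operands, check_trace=False)
    return _CACHE[key](*operands)
```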
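The trace-once-then-reuse pattern behind `_CACHE` can be isolated in plain Python. In this sketch, `fake_trace` is a hypothetical stand-in for an expensive compile step such as `torch.jit.trace`, and it counts how often compilation actually happens. One design point is worth checking in real code. A tracer records the operations executed for the example inputs, so if the traced graph depends on anything outside the key (operand shapes, for instance), that information would also need to be part of the key.

```python
from typing import Callable, Dict, Hashable

_CACHE: Dict[Hashable, Callable] = {}
compile_count = 0


def fake_trace(fn: Callable) -> Callable:
    """Hypothetical stand-in for an expensive compile step such as torch.jit.trace."""
    global compile_count
    compile_count += 1
    return fn


def cached_call(method: str, equation: str, plates: str, fn: Callable, *operands):
    """Compile fn once per (method, equation, plates) key, then reuse the compiled version."""
    key = method, equation, plates
    if key not in _CACHE:
        _CACHE[key] = fake_trace(fn)
    return _CACHE[key](*operands)


def add(x, y):
    return x + y


for _ in range(10):
    cached_call("prob", "a,a->", "", add, 1, 2)  # compiled on the first call only
cached_call("prob", "ab,b->", "", add, 3, 4)      # new equation, so a second compile
print(compile_count)  # 2
```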
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| -e / --equation | str | Einsum equation (default: "a,abi,bcij,adj,deij->") |
| -p / --plates | str | Plate dimensions (default: "ij") |
| -d / --dim-size | int | Non-plate dimension size (default: 32) |
| -s / --max-plate-size | int | Maximum plate dimension size (default: 32) |
| -m / --method | str | Method to profile: prob, logprob, gradient, marginal, map, sample, or all |
| -n / --iters | int | Number of profiling iterations (default: 10) |
| --cuda | flag | Run on GPU |
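A command-line interface with these parameters could be declared with `argparse` as follows. This is a reconstruction from the parameter table, not the script's actual parser. The `--method` default of `"all"` is an assumption, because the table gives none, and `--cuda` is inferred from the GPU usage example.

```python
import argparse

METHODS = ["prob", "logprob", "gradient", "marginal", "map", "sample", "all"]


def make_parser():
    parser = argparse.ArgumentParser(description="Profile plated einsum (sketch)")
    parser.add_argument("-e", "--equation", default="a,abi,bcij,adj,deij->")
    parser.add_argument("-p", "--plates", default="ij")
    parser.add_argument("-d", "--dim-size", default=32, type=int)
    parser.add_argument("-s", "--max-plate-size", default=32, type=int)
    # Assumed default: the parameter table does not state one for --method.
    parser.add_argument("-m", "--method", default="all", choices=METHODS)
    parser.add_argument("-n", "--iters", default=10, type=int)
    # Inferred from the GPU usage example rather than from the table.
    parser.add_argument("--cuda", action="store_true")
    return parser


args = make_parser().parse_args(
    ["-e", "ab,bc,cd->", "-p", "", "-d", "64", "-s", "128", "-m", "logprob"]
)
print(args.equation, repr(args.plates), args.dim_size, args.max_plate_size, args.method)
```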
Output:
- Time per iteration (in ms) for each plate size
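Measuring milliseconds per iteration for each plate size might look like the sketch below. `dummy_workload` is a hypothetical stand-in for one einsum call, and the plate-size sweep is an arbitrary choice. The untimed warm-up call is a design assumption: it keeps one-off costs such as JIT tracing out of the per-iteration average.

```python
import time


def profile(run_once, plate_sizes, iters=10):
    """Return {plate_size: milliseconds per iteration} for a callable run_once(plate_size)."""
    results = {}
    for plate_size in plate_sizes:
        run_once(plate_size)  # warm-up, e.g. to trigger tracing outside the timed loop
        start = time.perf_counter()
        for _ in range(iters):
            run_once(plate_size)
        elapsed = time.perf_counter() - start
        results[plate_size] = 1000.0 * elapsed / iters
    return results


def dummy_workload(plate_size):
    # Hypothetical stand-in for one einsum call whose cost grows with plate size.
    return sum(i * i for i in range(plate_size * 100))


timings = profile(dummy_workload, plate_sizes=[8, 16, 32], iters=10)
for size, ms in timings.items():
    print(f"plate size {size}: {ms:.3f} ms/iter")
```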
Usage Examples
```bash
# Profile all methods on the default equation
python einsum.py -m all

# Profile a specific equation with larger plates
python einsum.py -e "ab,bc,cd->" -p "" -d 64 -s 128 -m logprob

# Profile on GPU
python einsum.py --cuda -m gradient
```
Related Pages
- Pyro_ppl_Pyro_Funsor_HMM - Uses variable elimination internally via similar mechanisms