Implementation:Pyro ppl Pyro Infer Utilities
Overview
The util module (`pyro.infer.util`) provides utility functions and classes used throughout Pyro's inference subsystem. It includes validation control, tensor helper functions, plate stack management, the `MultiFrameTensor` container for aggregating costs across plate contexts, an implementation of the DiCE operator (`Dice`) for computing expectations with enumerated variables, and `CloneMixin` for cloning dataclasses.
Key categories of functionality:
- Validation control -- `enable_validation`, `is_validation_enabled`, and the `validation_enabled` context manager for toggling inference validation checks.
- Tensor helpers -- `torch_item`, `torch_backward`, `torch_exp`, and `torch_sum`, which tolerate plain numbers or missing dims, plus `zero_grads` for resetting gradients.
- Plate utilities -- `get_plate_stacks` and `get_dependent_plate_dims` for extracting plate structure from traces.
- MultiFrameTensor -- A dict-like container that accumulates tensors across different plate contexts and can sum them to a target plate stack.
- Dice -- An implementation of the DiCE (Infinitely Differentiable Monte Carlo Estimator) operator compatible with Pyro's plate and enumeration features.
- CloneMixin -- A mixin for cloning `dataclass` instances with tensor fields.
Code Reference
File: `pyro/infer/util.py`
Validation Functions
| Function | Description |
|---|---|
| `enable_validation(is_validate)` | Globally enable or disable inference validation. |
| `is_validation_enabled()` | Returns whether inference validation is currently enabled. |
| `validation_enabled(is_validate=True)` | Context manager that temporarily enables or disables validation and restores the previous setting on exit. |
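The three validation functions cooperate through a single global flag. The following is a minimal pure-Python sketch of that pattern, not Pyro's source: the `_sketch` names are hypothetical, and seeding the flag from `__debug__` is an assumption of this sketch.

```python
from contextlib import contextmanager

# Module-level switch shared by all three functions.
# Seeding it from __debug__ is an assumption for this sketch.
_VALIDATION_ENABLED = __debug__


def enable_validation_sketch(is_validate):
    """Globally turn validation on or off."""
    global _VALIDATION_ENABLED
    _VALIDATION_ENABLED = bool(is_validate)


def is_validation_enabled_sketch():
    """Report the current global setting."""
    return _VALIDATION_ENABLED


@contextmanager
def validation_enabled_sketch(is_validate=True):
    """Temporarily set the flag, restoring the previous value on exit."""
    old = is_validation_enabled_sketch()
    try:
        enable_validation_sketch(is_validate)
        yield
    finally:
        # Restore the old value even if the body raises.
        enable_validation_sketch(old)
```

The `try`/`finally` is the important design choice: an exception raised during inference inside the `with` block cannot leave validation permanently switched off.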
Tensor Helper Functions
| Function | Description |
|---|---|
| `torch_item(x)` | Like `x.item()` but also works with plain numbers. |
| `torch_backward(x, retain_graph=None)` | Like `x.backward()` but is a no-op for numbers or tensors without a `grad_fn`. |
| `torch_exp(x)` | Like `torch.exp(x)` but also works with numbers via `math.exp`. |
| `torch_sum(tensor, dims)` | Like `torch.sum` but only sums over dims that exist. |
| `zero_grads(tensors)` | Sets gradients of a list of tensors to zero in place. |
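These helpers share one idea: accept either a tensor or a plain Python number (or a dim list that may not fit the tensor) and do the sensible thing. A minimal sketch of three of them, assuming the semantics described in the table; the `_sketch` names are hypothetical, not Pyro's API, and dims are assumed negative (counted from the right), as plate dims are.

```python
import math
import numbers

import torch


def item_sketch(x):
    # Plain numbers pass through; tensors are converted with .item().
    return x if isinstance(x, numbers.Number) else x.item()


def exp_sketch(x):
    # Dispatch to torch.exp for tensors and math.exp for plain numbers.
    return torch.exp(x) if torch.is_tensor(x) else math.exp(x)


def sum_existing_dims_sketch(tensor, dims):
    # Dims are negative, so a dim exists only if it is >= -tensor.dim().
    assert all(d < 0 for d in dims)
    leftmost = -tensor.dim()
    dims = [d for d in dims if d >= leftmost]
    # Return the input unchanged when there is nothing to sum out.
    return tensor.sum(dims) if dims else tensor
```

For example, `sum_existing_dims_sketch(torch.ones(2, 3), [-3, -1])` silently ignores `-3` (the tensor has only two dims) and returns a tensor of shape `(2,)`.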
Plate Utility Functions
| Function | Description |
|---|---|
| `get_plate_stacks(trace)` | Builds a dict mapping site name to a list of vectorized plate frames. Used by `Trace_ELBO` and `TraceGraph_ELBO`. |
| `get_dependent_plate_dims(sites)` | Returns a sorted list of plate dims that are not shared by all sites. Used for identifying which dims to sum over. |
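"Not shared by all sites" amounts to the union of the sites' plate stacks minus their intersection. A pure-Python sketch of that set arithmetic, using a hypothetical `Frame` tuple in place of Pyro's `CondIndepStackFrame` and taking plate stacks directly rather than trace sites:

```python
from collections import namedtuple

# Hypothetical stand-in for a plate frame: only `name` and `dim` matter here.
Frame = namedtuple("Frame", ["name", "dim"])


def dependent_plate_dims_sketch(stacks):
    """Return sorted dims of plates that appear in some, but not all, stacks."""
    plate_sets = [set(stack) for stack in stacks]
    all_plates = set().union(*plate_sets)
    common = all_plates.intersection(*plate_sets)
    # Plates outside the common set are the ones that must be summed out.
    return sorted(f.dim for f in all_plates - common if f.dim is not None)
```

With one site in an `outer` plate (dim -2) and two sites in both `outer` and `inner` (dim -1), only `inner` is not shared, so the sketch returns `[-1]`.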
MultiFrameTensor
| Method | Description |
|---|---|
| `__init__(*items)` | Initializes with optional `(cond_indep_stack, tensor)` pairs. |
| `add(*items)` | Adds `(cond_indep_stack, tensor)` pairs. Tensors with the same plate frames are summed together. |
| `sum_to(target_frames)` | Sums all stored tensors down to the given target plate frames, summing out the dims of plates that are not in the target. |
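The `add`/`sum_to` behaviour can be sketched with a plain `dict` keyed by frozensets of frames. This is a simplified illustration, not Pyro's class: `Frame` and `MultiFrameSketch` are hypothetical names, and the final squeeze of leading singleton dims is an assumption of this sketch, chosen so that pieces from different plates broadcast together.

```python
from collections import namedtuple

import torch

# Hypothetical plate frame; `dim` is a negative batch dim, as with Pyro plates.
Frame = namedtuple("Frame", ["name", "dim"])


class MultiFrameSketch(dict):
    """Maps a frozenset of plate frames to the summed tensor for those frames."""

    def add(self, *items):
        for stack, value in items:
            key = frozenset(stack)
            # Tensors that live in the same plates are summed together.
            self[key] = self[key] + value if key in self else value

    def sum_to(self, target_frames):
        total = None
        for frames, value in self.items():
            # Sum out every plate dim that the target is not inside.
            for f in frames:
                if f not in target_frames and value.shape[f.dim] != 1:
                    value = value.sum(f.dim, keepdim=True)
            # Drop leading singleton dims (assumption of this sketch).
            while value.dim() and value.shape[0] == 1:
                value = value.squeeze(0)
            total = value if total is None else total + value
        return 0.0 if total is None else total
```

Summing to an empty target contracts every plate and yields a scalar; summing to a target inside the `data` plate keeps that dim and broadcasts the global term across it.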
Dice
| Method | Description |
|---|---|
| `__init__(guide_trace, ordering)` | Initializes from a guide trace and an ordering dict mapping site names to ordinal values. |
| `compute_expectation(costs)` | Computes a differentiable expected cost using sum-product contraction, summing over costs at given ordinals. |
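At the heart of DiCE is the "magic box" `exp(log_q - log_q.detach())`, which evaluates to exactly 1 while contributing a score-function gradient. A minimal single-site sketch in plain PyTorch, independent of Pyro's `Dice` class (which additionally handles plates and enumeration); the toy Bernoulli site and the cost `3 * x` are arbitrary choices for illustration:

```python
import torch


def magic_box(log_prob):
    # Forward value is exp(0) = 1; the gradient is
    # magic_box(log_prob) * d(log_prob), i.e. the score-function term.
    return torch.exp(log_prob - log_prob.detach())


# Toy site: x ~ Bernoulli(p), sampled without reparameterization.
p = torch.tensor(0.3, requires_grad=True)
bern = torch.distributions.Bernoulli(probs=p)
x = bern.sample()
cost = 3.0 * x  # arbitrary downstream cost

# Surrogate objective: same value as the cost, but differentiable in p.
surrogate = magic_box(bern.log_prob(x)) * cost
surrogate.backward()
# p.grad now holds cost * d/dp log Bernoulli(x | p).
```

Multiplying each cost by the magic boxes of its upstream sites is what lets a single `backward()` produce an unbiased gradient of the expected cost.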
Other
| Function/Class | Description |
|---|---|
| `compute_site_dice_factor(site)` | Computes DiCE log-factors for a single site, handling enumeration and Monte Carlo sampling. |
| `check_fully_reparametrized(guide_site)` | Validates that a guide site is fully reparameterized; raises `NotImplementedError` otherwise. |
| `plate_log_prob_sum(trace, plate_symbol)` | Computes the log-probability sum from a trace while preserving indexing over the specified plate. |
| `CloneMixin` | Mixin adding a `clone()` method to `dataclass` types with tensor fields. |
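A mixin of this kind can be sketched with `dataclasses.fields`. The behaviour shown (tensor fields copied with `Tensor.clone()`, other fields passed through unchanged) is an assumption, not Pyro's implementation; `CloneMixinSketch` and `State` are hypothetical names.

```python
import dataclasses

import torch


class CloneMixinSketch:
    """Adds clone() to a dataclass, copying tensor fields and reusing others."""

    def clone(self):
        # Assumes every field is an __init__ argument (no init=False fields).
        kwargs = {}
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            kwargs[field.name] = value.clone() if torch.is_tensor(value) else value
        return type(self)(**kwargs)


@dataclasses.dataclass
class State(CloneMixinSketch):
    loc: torch.Tensor
    step: int = 0
```

Because tensor fields get fresh storage, in-place updates on the clone (for example `c.loc += 1.0`) leave the original untouched.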
I/O Contract
MultiFrameTensor
Inputs (via `add`):
- Pairs of `(cond_indep_stack, value)` where `cond_indep_stack` is a sequence of `CondIndepStackFrame` objects and `value` is a `torch.Tensor`.
Output (via `sum_to`):
- `torch.Tensor` or `float` -- The accumulated cost summed to the target plate frames.
Dice.compute_expectation
Inputs:
- `costs` -- A dict mapping ordinals to lists of cost tensors with `_pyro_dims` attributes.
Output:
- `torch.Tensor` or `float` -- A differentiable scalar expected cost.
Usage Examples
Validation Control
```python
from pyro.infer.util import enable_validation, is_validation_enabled, validation_enabled

# Check the current state
print(is_validation_enabled())  # True by default unless Python is run with -O

# Temporarily disable validation for speed
with validation_enabled(False):
    pass  # run inference here without validation checks

# Re-enable globally
enable_validation(True)
```
Tensor Helpers
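```python
import torch
import pyro
from pyro.infer.util import torch_item, torch_backward, zero_grads

# torch_item works on both tensors and plain numbers
loss_tensor = torch.tensor(3.14)
loss_number = 3.14
print(torch_item(loss_tensor))  # approximately 3.14 (float32 rounding)
print(torch_item(loss_number))  # 3.14

# torch_backward is a safe no-op for values without a grad_fn
torch_backward(loss_number)  # no-op: plain number
torch_backward(loss_tensor)  # no-op: tensor without grad_fn

# Zero out gradients of the unconstrained parameter tensors
w = pyro.param("w", torch.ones(3))  # pyro.param needs an initial value on first use
torch_backward((w ** 2).sum())  # populates the gradient of w's unconstrained tensor
zero_grads([w.unconstrained()])
```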
Using MultiFrameTensor
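```python
from pyro.infer.util import MultiFrameTensor

# Assumes `downstream_nodes` (trace sites with computed "log_prob")
# and `target_site` were taken from an existing trace.
downstream_cost = MultiFrameTensor()
for site in downstream_nodes:
    downstream_cost.add((site["cond_indep_stack"], site["log_prob"]))
summed = downstream_cost.sum_to(target_site["cond_indep_stack"])
```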
Related Pages
- Pyro_ppl_Pyro_TraceGraph_ELBO -- Primary consumer of MultiFrameTensor and downstream cost computation
- Pyro_ppl_Pyro_TraceTMC_ELBO -- Uses `Dice` for DiCE-based gradient estimation
- Pyro_ppl_Pyro_TraceTailAdaptive_ELBO -- Uses `check_fully_reparametrized` for validation
- Pyro_ppl_Pyro_SVGD -- Uses `zero_grads` for parameter gradient management
- Pyro_ppl_Pyro_Abstract_Infer -- Base inference classes that use these utilities