Implementation:Pyro ppl Pyro Infer Utilities

From Leeroopedia
Revision as of 16:24, 16 February 2026 by Admin (auto-imported from implementations/Pyro_ppl_Pyro_Infer_Utilities.md)


Overview

The util module (Template:Code) provides a collection of utility functions and classes used throughout Pyro's inference subsystem. It includes validation control, tensor helper functions, plate stack management, the MultiFrameTensor container for aggregating costs across plate contexts, the Dice operator implementation for computing expectations with enumerated variables, and the CloneMixin for dataclass cloning.

Key categories of functionality:

  • Validation control -- Template:Code, Template:Code, and the Template:Code context manager for toggling inference validation checks.
  • Tensor helpers -- Template:Code, Template:Code, Template:Code, Template:Code, and Template:Code for safe operations that handle both tensors and plain numbers.
  • Plate utilities -- Template:Code and Template:Code for extracting plate structure from traces.
  • MultiFrameTensor -- A dict-like container that accumulates tensors across different plate contexts and can sum them to a target plate stack.
  • Dice -- An implementation of the DiCE (Infinitely Differentiable Monte Carlo Estimator) operator compatible with Pyro's plate and enumeration features.
  • CloneMixin -- A mixin for cloning Template:Code instances with tensor fields.
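The DiCE bullet above is easiest to follow with the estimator's core object written out. The block below uses standard DiCE notation, not Pyro's internal variable names. Pyro's Dice class adds plate- and enumeration-aware bookkeeping on top of this, which is not shown here. The symbol ⊥ is the stop-gradient operator, ↣ means "evaluates to in the forward pass", and W_c is the set of stochastic nodes that influence cost c.

$$
\boxdot(\mathcal{W}) = \exp\big(\tau - \bot(\tau)\big), \qquad \tau = \sum_{w \in \mathcal{W}} \log p(w; \theta)
$$

$$
\boxdot(\mathcal{W}) \rightarrowtail 1, \qquad \nabla_\theta \boxdot(\mathcal{W}) = \boxdot(\mathcal{W})\, \nabla_\theta \tau, \qquad \mathcal{L}_{\boxdot} = \sum_{c \in \mathcal{C}} \boxdot(\mathcal{W}_c)\, c
$$

Because the magic box evaluates to 1, the surrogate objective has the same value as the total cost. Differentiating it produces score-function terms for every stochastic node upstream of each cost, and this holds at any order of differentiation.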

Code Reference

File: Template:Code

Validation Functions

Function | Description
Template:Code | Globally enable or disable inference validation.
Template:Code | Returns whether inference validation is currently enabled.
Template:Code | Context manager for temporarily enabling or disabling validation.
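The context manager in this table saves the current setting and restores it afterwards. The following is a minimal sketch of that save/restore pattern using only the standard library. The names `_VALIDATION_ENABLED`, `set_validation`, `validation_is_on` and `validation` are hypothetical stand-ins and are not Pyro's API. The key point is that the previous value is restored in a `finally` clause, so it survives an exception raised inside the block.

```python
from contextlib import contextmanager

# Hypothetical module-level flag standing in for the library's internal state.
_VALIDATION_ENABLED = __debug__


def set_validation(flag):
    """Hypothetical global setter (analogous to a global enable/disable switch)."""
    global _VALIDATION_ENABLED
    _VALIDATION_ENABLED = flag


def validation_is_on():
    """Hypothetical getter for the current setting."""
    return _VALIDATION_ENABLED


@contextmanager
def validation(flag=True):
    """Temporarily set the flag, restoring the previous value on exit."""
    old = validation_is_on()
    try:
        set_validation(flag)
        yield
    finally:
        # Runs even if the body raised, so the old setting is never lost.
        set_validation(old)


set_validation(True)
with validation(False):
    print(validation_is_on())  # False inside the block
print(validation_is_on())      # True again after the block
```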

Tensor Helper Functions

Function | Description
Template:Code | Like Template:Code, but also works with plain numbers.
Template:Code | Like Template:Code, but a no-op for numbers and for tensors without a Template:Code.
Template:Code | Like Template:Code, but also works with numbers via Template:Code.
Template:Code | Like Template:Code, but only sums over those requested dims that actually exist in the tensor.
Template:Code | Sets the gradients of a list of tensors to zero in place.
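The helpers in this table all follow one pattern: pass plain numbers through, or use the `math` equivalent, and use the tensor operation otherwise. The sketch below shows that pattern. The function names `item_or_number`, `exp_any` and `sum_existing_dims` are hypothetical re-implementations written for illustration. They are not Pyro's own functions. The sketch also assumes that dims are given as negative indices, counting from the right.

```python
import math
import numbers

import torch


def item_or_number(x):
    """Hypothetical: numbers pass through; tensors are converted with .item()."""
    return x if isinstance(x, numbers.Number) else x.item()


def exp_any(x):
    """Hypothetical: math.exp for numbers, torch.exp for tensors."""
    return math.exp(x) if isinstance(x, numbers.Number) else torch.exp(x)


def sum_existing_dims(tensor, dims):
    """Hypothetical: sum over the negative dims in `dims` that exist in `tensor`."""
    assert all(d < 0 for d in dims)
    dims = [d for d in dims if -tensor.dim() <= d]
    # If none of the requested dims exist, return the tensor unchanged.
    return tensor.sum(tuple(dims)) if dims else tensor


t = torch.ones(2, 3)
print(sum_existing_dims(t, [-1, -3]).shape)  # dim -3 does not exist, so only -1 is summed
```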

Plate Utility Functions

Function | Description
Template:Code | Builds a dict mapping each site name to its list of vectorized plate frames. Used by Template:Code and Template:Code.
Template:Code | Returns a sorted list of the dims of plates that are not shared by all sites, i.e. the dims that must be summed over.
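The second function in this table returns the dims of plates that are "not shared by all sites". That amounts to a set computation: take the union of all plate sets and subtract their intersection. The stdlib-only sketch below shows this computation. `Frame` is a hypothetical namedtuple standing in for Pyro's plate frame objects, and `dependent_plate_dims` is a hypothetical name. The sketch takes plate stacks directly, not trace sites.

```python
from collections import namedtuple

# Hypothetical stand-in for a plate frame: only the name and dim matter here.
Frame = namedtuple("Frame", ["name", "dim"])


def dependent_plate_dims(plate_stacks):
    """Sorted dims of plates that appear at some, but not all, sites."""
    plate_sets = [set(stack) for stack in plate_stacks]
    all_plates = set().union(*plate_sets)
    common = all_plates.intersection(*plate_sets)
    # Plates present everywhere are excluded; the remaining dims must be summed out.
    return sorted(f.dim for f in all_plates - common if f.dim is not None)


outer = Frame("outer", -2)
inner = Frame("inner", -1)
print(dependent_plate_dims([(outer,), (outer, inner)]))  # [-1]
```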

MultiFrameTensor

Method | Description
Template:Code | Initializes with optional Template:Code pairs.
Template:Code | Adds Template:Code pairs; tensors with the same plate frames are summed together.
Template:Code | Sums all stored tensors down to the given target plate frames, contracting over any extra plate dimensions.
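This sketch illustrates the "sum down to target plate frames" step from the table. It is a simplification that identifies each plate only by its negative dim and keys the stored tensors by a frozenset of those dims. The function name `sum_to_plates` and this keying scheme are assumptions made for illustration. This is not Pyro's MultiFrameTensor code. Plates outside the target are contracted with `keepdim=True` so that dim positions stay aligned. Leading size-1 dims are then squeezed away before the tensors are added together.

```python
import torch


def sum_to_plates(costs, target_dims):
    """Sum a {frozenset of plate dims: tensor} dict down to the plates in target_dims.

    Hypothetical simplification: each plate is identified only by its negative dim.
    """
    total = None
    for dims, value in costs.items():
        for d in dims:
            # Contract plates that are not part of the target stack.
            if d not in target_dims and value.shape[d] != 1:
                value = value.sum(d, keepdim=True)
        # Drop leading size-1 dims left behind by the keepdim sums.
        while value.dim() > 0 and value.shape[0] == 1:
            value = value.squeeze(0)
        total = value if total is None else total + value
    return 0.0 if total is None else total


costs = {
    frozenset({-1}): torch.ones(3),  # cost inside a size-3 plate at dim -1
    frozenset(): torch.tensor(2.0),  # global cost outside any plate
}
print(sum_to_plates(costs, ()))      # plate summed out: 3 + 2
print(sum_to_plates(costs, (-1,)))   # plate kept: each element is 1 + 2
```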

Dice

Method | Description
Template:Code | Initializes from a guide trace and an ordering dict mapping site names to ordinals.
Template:Code | Returns a differentiable expected cost, computed by sum-product contraction and summed over the costs at the given ordinals.
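At its simplest, a "differentiable expected cost" uses the DiCE magic box, `exp(tau - stop_grad(tau))`. The following sketch implements the plain single-site DiCE estimator in PyTorch. It does not cover Pyro's plate or enumeration handling. The Bernoulli guide, the parameter `theta`, the sample `x` and the constant `cost` are hypothetical. The surrogate evaluates to the cost itself, and its gradient is the score-function term `cost * d/dtheta log p(x)`.

```python
import torch
from torch.distributions import Bernoulli


def magic_box(tau):
    # Forward value is exp(0) = 1; gradients still flow through tau.
    return torch.exp(tau - tau.detach())


theta = torch.tensor(0.3, requires_grad=True)  # hypothetical guide parameter (a logit)
x = torch.tensor(1.0)                          # a sampled value, treated as fixed
cost = torch.tensor(5.0)                       # hypothetical cost f(x), independent of theta

log_p = Bernoulli(logits=theta).log_prob(x)
surrogate = magic_box(log_p) * cost  # evaluates to cost
surrogate.backward()                 # theta.grad = cost * d/dtheta log p(x) = cost * (1 - sigmoid(theta))
```

Using `detach()` as the stop-gradient keeps the forward value unchanged. Because the magic box is built only from differentiable operations, the trick also works for higher-order derivatives.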

Other

Function/Class | Description
Template:Code | Computes DiCE log-factors for a single site, handling both enumerated and Monte Carlo-sampled variables.
Template:Code | Checks that a guide site is fully reparameterized; raises Template:Code otherwise.
Template:Code | Computes the log-probability sum of a trace while preserving indexing over a specified plate symbol.
Template:Code | Mixin that adds a Template:Code method to dataclass types with tensor fields.
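The table says only that the mixin adds a cloning method to dataclasses that have tensor fields. The sketch below shows one plausible version of that behavior: clone every tensor field and copy all other fields by reference. `CloneMixinSketch` and `GuideState` are hypothetical names, and this is not Pyro's class. The sketch also assumes that every dataclass field is an `__init__` argument.

```python
import dataclasses

import torch


class CloneMixinSketch:
    """Hypothetical stand-in: adds clone() to dataclasses with tensor fields."""

    def clone(self):
        # Assumes every field is an __init__ argument (no init=False fields).
        kwargs = {}
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            kwargs[field.name] = value.clone() if torch.is_tensor(value) else value
        return type(self)(**kwargs)


@dataclasses.dataclass
class GuideState(CloneMixinSketch):  # hypothetical example dataclass
    loc: torch.Tensor
    step: int


state = GuideState(loc=torch.zeros(2), step=0)
twin = state.clone()
twin.loc += 1.0  # mutating the clone's tensor leaves the original untouched
```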

I/O Contract

MultiFrameTensor

Inputs (via add):

Output (via sum_to):

Dice.compute_expectation

Inputs:

Output:

Usage Examples

Validation Control

from pyro.infer.util import enable_validation, is_validation_enabled, validation_enabled

# Check current state
print(is_validation_enabled())  # True by default (follows Python's __debug__, i.e. unless run with -O)

# Temporarily disable validation for speed
with validation_enabled(False):
    # run inference without validation checks
    pass

# Set the flag globally (the context manager above already restored the previous value on exit)
enable_validation(True)

Tensor Helpers

import torch
import pyro
from pyro.infer.util import torch_item, torch_backward, zero_grads

# torch_item works on both tensors and plain numbers
loss_tensor = torch.tensor(3.14)
loss_number = 3.14
print(torch_item(loss_tensor))   # ~3.14 (float32 rounding, so not exactly 3.14)
print(torch_item(loss_number))   # 3.14

# torch_backward is a safe no-op for non-differentiable values
torch_backward(loss_number)  # no-op
torch_backward(loss_tensor)  # no-op (no grad_fn)

# Zero out the gradients of the unconstrained parameter tensors
params = [pyro.param("w", torch.zeros(3))]  # "w" needs an initial value the first time it is created
zero_grads([p.unconstrained() for p in params])

Using MultiFrameTensor

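from pyro.infer.util import MultiFrameTensor

# downstream_nodes: trace sample sites whose "log_prob" has already been computed
# target_site: the trace site whose plate stack the costs are summed down to
downstream_cost = MultiFrameTensor()
for site in downstream_nodes:
    downstream_cost.add((site["cond_indep_stack"], site["log_prob"]))
summed = downstream_cost.sum_to(target_site["cond_indep_stack"])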
