Implementation:Pyro ppl Pyro Util
| Property | Value |
|---|---|
| Module | pyro.util |
| Source | pyro/util.py |
| Lines | 724 |
| Functions | set_rng_seed, get_rng_state, set_rng_state, torch_isnan, torch_isinf, warn_if_nan, warn_if_inf, save_visualization, check_traces_match, check_model_guide_match, check_site_shape, check_traceenum_requirements, check_if_enumerated, ignore_jit_warnings, jit_iter, optional, ignore_experimental_warning, torch_float |
| Classes | ExperimentalWarning, timed |
| Dependencies | torch, numpy, pyro.poutine |
Overview
This module provides a collection of utility functions and classes used throughout the Pyro codebase. The utilities fall into several categories:
- RNG management: Functions for deterministically setting and restoring random state across PyTorch, Python, and NumPy.
- Numerical diagnostics: Functions to detect and warn about NaN and Inf values in tensors (including their gradients).
- Trace validation: Functions for checking consistency between model and guide traces, including shape validation, plate checking, and enumeration compatibility.
- JIT helpers: A context manager and an iteration helper for suppressing JIT tracer warnings during tracing.
- Convenience utilities: A conditional context manager, an experimental-feature warning class, a timing context manager, and a float-conversion helper.
Code Reference
RNG Management
- `set_rng_seed(rng_seed)`: Seeds `torch`, `random`, and `numpy` simultaneously.
- `get_rng_state()`: Captures the RNG state from all three libraries as a dict.
- `set_rng_state(state)`: Restores a previously captured RNG state.
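As a rough sketch of what seeding and restoring "all three libraries" involves, the helpers below use only standard PyTorch, `random`, and NumPy calls. The names `seed_all`, `capture_rng_state`, and `restore_rng_state`, as well as the dict keys, are hypothetical stand-ins rather than Pyro's own code. Note that `torch.get_rng_state()` covers only the CPU generator.

```python
import random

import numpy as np
import torch


def seed_all(seed: int) -> None:
    # Hypothetical helper: seed the three global RNGs described for set_rng_seed.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def capture_rng_state() -> dict:
    # Snapshot each library's global RNG state (torch: CPU generator only).
    return {
        "torch": torch.get_rng_state(),
        "random": random.getstate(),
        "numpy": np.random.get_state(),
    }


def restore_rng_state(state: dict) -> None:
    # Put every generator back exactly where the snapshot was taken.
    torch.set_rng_state(state["torch"])
    random.setstate(state["random"])
    np.random.set_state(state["numpy"])
```

After `restore_rng_state`, all three generators replay the same draws they produced right after the snapshot.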
Numerical Diagnostics
- `torch_isnan(x)`: Checks whether a tensor (or number) contains any NaN.
- `torch_isinf(x)`: Checks whether a tensor (or number) contains any +/-inf.
- `warn_if_nan(value, msg)`: Issues a `UserWarning` if NaN is detected and returns `value` unchanged. If `value` is a tensor that requires grad, it also registers a backward hook that checks the gradient for NaN.
- `warn_if_inf(value, msg, allow_posinf, allow_neginf)`: Issues a `UserWarning` if an infinite value is detected; `allow_posinf` and `allow_neginf` suppress the warning for +inf and -inf respectively.
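The pass-through-plus-hook pattern described for `warn_if_nan` can be sketched as follows. `nan_guard` is a hypothetical name, and the warning text is an assumption, not Pyro's exact message.

```python
import warnings

import torch


def nan_guard(value, msg=""):
    # Hypothetical re-creation of the behavior described for warn_if_nan.
    if torch.is_tensor(value) and value.requires_grad:
        # The hook runs during backward() with the incoming gradient. It returns
        # the gradient unchanged, so it only inspects and never modifies it.
        value.register_hook(lambda grad: nan_guard(grad, "backward " + msg))
    # torch.as_tensor lets plain Python numbers go through the same check.
    if bool(torch.isnan(torch.as_tensor(value)).any()):
        warnings.warn(f"Encountered NaN: {msg}", UserWarning)
    return value  # pass-through, so the call can wrap an expression inline
```

Because the function returns its input, it can be dropped into an existing expression (`loss = nan_guard(compute_loss(), "loss")`, with `compute_loss` hypothetical) without restructuring the code.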
Trace Validation
- `check_traces_match(trace1, trace2)`: Verifies that there is a bijection between the sample sites of the two traces and that they agree on shapes.
- `check_model_guide_match(model_trace, guide_trace, max_plate_nesting)`: Performs a comprehensive validation, checking that:
  - Each model sample site appears in the guide (and vice versa, modulo auxiliary variables).
  - Each guide plate appears in the model.
  - Model and guide agree on sample shapes at shared sites.
  - Model-side sequential enumeration is not used.
  - Factor statements specify `has_rsample`.
- `check_site_shape(site, max_plate_nesting)`: Validates that a site's `log_prob` shape matches its `cond_indep_stack` (plate structure).
- `check_traceenum_requirements(model_trace, guide_trace)`: Warns about potentially invalid dependencies on enumerated variables.
- `check_if_enumerated(guide_trace)`: Warns if enumerated sites are found without `TraceEnum_ELBO`.
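To make the `check_site_shape` rule concrete, here is a simplified sketch that takes the plate structure as plain `(dim, size)` pairs instead of a Pyro message dict with a `cond_indep_stack`. The function name and input format are assumptions. It checks only the rightmost `max_plate_nesting` dimensions, treats dimensions without a plate as unconstrained, and raises `ValueError` on a mismatch.

```python
from itertools import zip_longest

import torch


def check_log_prob_shape(log_prob, plates, max_plate_nesting):
    # `plates` stands in for a site's cond_indep_stack: (dim, size) pairs with dim < 0.
    expected = [-1] * max_plate_nesting  # -1 marks a dim with no plate
    for dim, size in plates:
        if -dim > max_plate_nesting:
            raise ValueError(f"plate dim {dim} exceeds max_plate_nesting={max_plate_nesting}")
        expected[dim] = size
    # Only the rightmost max_plate_nesting dims of log_prob are constrained by plates.
    actual = list(log_prob.shape)[-max_plate_nesting:] if max_plate_nesting else []
    # Align shapes from the right; dims missing from log_prob count as size 1.
    for actual_size, expected_size in zip_longest(reversed(actual), reversed(expected), fillvalue=1):
        if expected_size != -1 and actual_size != expected_size:
            raise ValueError(
                f"log_prob shape {tuple(log_prob.shape)} does not match plates {plates}"
            )
```

For example, `check_log_prob_shape(torch.zeros(5, 3), [(-1, 3)], 1)` passes: the leading dimension of size 5 lies to the left of `max_plate_nesting` and is ignored.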
JIT Helpers
- `ignore_jit_warnings(filter)`: Context manager that filters JIT tracer warnings; it has an effect only while tracing is active.
- `jit_iter(tensor)`: Iterates over a tensor while suppressing "Iterating over a tensor" warnings.
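A context manager with the described behavior (a no-op outside tracing, a `TracerWarning` filter inside it) might look like the sketch below. `quiet_tracer` and `row_sums` are hypothetical names. The sketch relies on the public `torch.jit.is_tracing()` and `torch.jit.TracerWarning`, and it is exercised on a function that iterates over a tensor, the case `jit_iter` targets.

```python
import contextlib
import warnings

import torch


@contextlib.contextmanager
def quiet_tracer(messages=None):
    # Outside of torch.jit.trace this is a no-op, mirroring "only active during tracing".
    if not torch.jit.is_tracing():
        yield
        return
    with warnings.catch_warnings():
        if messages is None:
            # Default: silence every TracerWarning raised inside the block.
            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        else:
            # Otherwise silence only TracerWarnings whose text matches one of the regexes.
            for msg in messages:
                warnings.filterwarnings("ignore", message=msg, category=torch.jit.TracerWarning)
        yield


def row_sums(x):
    with quiet_tracer():
        rows = list(x)  # iterating a tensor while tracing normally emits a TracerWarning
    return torch.stack([row.sum() for row in rows])
```

Because `warnings.catch_warnings()` restores the previous filters on exit, warnings raised after the block are reported as usual.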
Convenience Utilities
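- `optional(context_manager, condition)`: Context manager that enters `context_manager` only when `condition` is true and does nothing otherwise.
- `ExperimentalWarning`: Custom warning class for experimental features.
- `ignore_experimental_warning()`: Context manager that suppresses `ExperimentalWarning`s.
- `timed`: Context manager that measures wall-clock time; the result is available as `.elapsed` after the block exits.
- `torch_float(x)`: Converts a number to a Python float, or a tensor to a float tensor.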
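Minimal sketches of the conditional and timing context managers are shown below, assuming they follow the usual `contextlib` and `__enter__`/`__exit__` patterns. `maybe` and `Stopwatch` are hypothetical names, not Pyro's classes.

```python
import contextlib
import timeit


@contextlib.contextmanager
def maybe(context_manager, condition):
    # Enter context_manager only when condition is true; otherwise do nothing.
    if condition:
        with context_manager:
            yield
    else:
        yield


class Stopwatch:
    # Measures the with-block's elapsed time using timeit.default_timer.
    def __enter__(self):
        self.start = timeit.default_timer()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.elapsed = timeit.default_timer() - self.start
        return False  # do not swallow exceptions raised inside the block
```

Returning `False` from `__exit__` lets exceptions raised inside the timed block propagate normally. In `maybe`, the wrapped context manager is still constructed when `condition` is false; it is simply never entered.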
I/O Contract
| Function | Input | Output |
|---|---|---|
| set_rng_seed(rng_seed) | int | None |
| get_rng_state() | (no arguments) | Dict[str, Any] |
| torch_isnan(x) | Tensor or Number | bool (for a number) or 0-dim bool Tensor (for a tensor) |
| warn_if_nan(value, msg) | Tensor or Number, str | Same as input (pass-through) |
| check_model_guide_match(model_trace, guide_trace, max_plate_nesting) | Two Trace objects; optional max_plate_nesting | None (warns or raises on mismatch) |
| check_site_shape(site, max_plate_nesting) | Message dict, int | None (raises on error) |
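The Tensor-or-Number contract in the `torch_isnan` row can be sketched like this: numbers yield a Python `bool`, and tensors yield a 0-dim boolean tensor. `any_nan` and `any_inf` are hypothetical helpers, not Pyro's implementations.

```python
import math
import numbers

import torch


def any_nan(x):
    # Numbers: NaN is the only value that is not equal to itself.
    if isinstance(x, numbers.Number):
        return x != x
    # Tensors: reduce to a single 0-dim bool tensor.
    return torch.isnan(x).any()


def any_inf(x):
    if isinstance(x, numbers.Number):
        return x in (math.inf, -math.inf)
    return torch.isinf(x).any()
```

Wrapping the result in `bool(...)` gives a plain Python bool in both cases.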
Usage Examples
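```python
import torch
from pyro.util import (
    get_rng_state,
    ignore_jit_warnings,
    optional,
    set_rng_seed,
    set_rng_state,
    timed,
    warn_if_nan,
)

# Reproducible experiments: seeds torch, random, and numpy together
set_rng_seed(42)
x = torch.randn(3)

# Save and restore RNG state
state = get_rng_state()
a = torch.randn(3)
set_rng_state(state)
b = torch.randn(3)
assert torch.equal(a, b)

# NaN detection; because x requires grad, a backward hook is registered too
x = torch.tensor([1.0, float("nan"), 3.0], requires_grad=True)
x = warn_if_nan(x, "forward pass")  # issues a UserWarning

# Timing a block of code
with timed() as timer:
    result = torch.randn(1000, 1000) @ torch.randn(1000, 1000)
print(f"Elapsed: {timer.elapsed:.4f} seconds")

# Conditional context manager: autocast is entered only on CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(10, 1).to(device)
data = torch.randn(4, 10, device=device)
with optional(torch.autocast(device_type=device), device == "cuda"):
    result = model(data)

# ignore_jit_warnings() only takes effect while tracing,
# so it belongs inside the function being traced
def my_fn(x):
    with ignore_jit_warnings():
        # A data-dependent branch would otherwise emit a TracerWarning
        if x.sum() > 0:
            return x * 2
        return x

example_input = torch.ones(3)
traced_fn = torch.jit.trace(my_fn, example_input)
```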
Related Pages
- Pyro_ppl_Pyro_Settings -- Global Pyro settings management
- Pyro_ppl_Pyro_LazyJIT -- JIT compilation with parameter management
- Pyro_ppl_Pyro_MiniPyro -- Minimal Pyro implementation