
Implementation:Pyro ppl Pyro Util

From Leeroopedia


Property      Value
Module        pyro.util
Source        pyro/util.py
Lines         724
Functions     set_rng_seed, get_rng_state, set_rng_state, torch_isnan, torch_isinf, warn_if_nan, warn_if_inf, save_visualization, check_traces_match, check_model_guide_match, check_site_shape, check_traceenum_requirements, check_if_enumerated, ignore_jit_warnings, jit_iter, torch_float, optional, ignore_experimental_warning
Classes       ExperimentalWarning, timed
Dependencies  torch, numpy, pyro.poutine

Overview

This module provides a collection of utility functions and classes used throughout the Pyro codebase. The utilities fall into several categories:

  • RNG management: Functions for deterministically setting and restoring random state across PyTorch, Python, and NumPy.
  • Numerical diagnostics: Functions to detect and warn about NaN and Inf values in tensors (including their gradients).
  • Trace validation: Functions for checking consistency between model and guide traces, including shape validation, plate checking, and enumeration compatibility.
  • JIT helpers: Context managers for suppressing JIT tracer warnings during tracing.
  • Convenience utilities: a conditional context manager (optional), an experimental-feature warning class, and a timing context manager.

Code Reference

RNG Management

  • set_rng_seed(rng_seed): Seeds torch, random, and numpy simultaneously.
  • get_rng_state(): Captures the RNG state from all three libraries as a dict.
  • set_rng_state(state): Restores previously captured RNG state.
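The three RNG functions amount to keeping the generators of Python's random module, NumPy and PyTorch in lockstep. The sketch below shows that pattern using only the public APIs of those three libraries. The helper names seed_all, capture_rng, restore_rng and draw are hypothetical, the sketch covers only torch's CPU generator, and it is not Pyro's source.

```python
import random

import numpy as np
import torch


def seed_all(seed):
    """Hypothetical helper: seed Python's random, NumPy's global RNG and torch."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def capture_rng():
    """Hypothetical helper: snapshot all three generators into one dict."""
    return {
        "random": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.get_rng_state(),  # CPU generator only
    }


def restore_rng(state):
    """Hypothetical helper: put all three generators back to a snapshot."""
    random.setstate(state["random"])
    np.random.set_state(state["numpy"])
    torch.set_rng_state(state["torch"])


def draw():
    # One draw from each library, so a restore can be checked across all three.
    return random.random(), float(np.random.rand()), float(torch.rand(1))


seed_all(0)
snapshot = capture_rng()
first = draw()
restore_rng(snapshot)
second = draw()  # identical to `first`, because every generator was rewound

seed_all(0)
reseeded_a = draw()
seed_all(0)
reseeded_b = draw()  # identical to `reseeded_a`, because the same seed was used
```

Bundling the per-library states into one dict is what lets a single capture/restore pair cover all three generators at once.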

Numerical Diagnostics

  • torch_isnan(x): Checks if a tensor (or number) contains any NaN.
  • torch_isinf(x): Checks if a tensor (or number) contains any +/-inf.
  • warn_if_nan(value, msg): Issues a UserWarning if NaN is detected and returns value unchanged. If value is a tensor that requires grad, it also registers a backward hook that checks the gradient for NaN.
  • warn_if_inf(value, msg, allow_posinf, allow_neginf): Issues a UserWarning if Inf is detected, with options to allow positive or negative infinity.

Trace Validation

  • check_traces_match(trace1, trace2): Verifies that there is a bijection between the sample sites of the two traces and that both traces agree on sample shape at each site.
  • check_model_guide_match(model_trace, guide_trace, max_plate_nesting): Comprehensive validation checking:
    1. Each model sample site appears in the guide (and vice versa, modulo auxiliary variables).
    2. Each guide plate appears in the model.
    3. The model and guide agree on sample shapes at shared sites.
    4. Model-side sequential enumeration is not used.
    5. Factor statements in the guide specify has_rsample.
  • check_site_shape(site, max_plate_nesting): Validates that a site's log_prob shape matches its cond_indep_stack (plate structure).
  • check_traceenum_requirements(model_trace, guide_trace): Warns about potential invalid dependencies on enumerated variables.
  • check_if_enumerated(guide_trace): Warns if enumerated sites are found without TraceEnum_ELBO.

JIT Helpers

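  • ignore_jit_warnings(filter=None): Context manager that suppresses JIT tracer warnings. It has an effect only while a JIT trace is in progress, so it must be entered inside the function being traced. By default it ignores every torch.jit.TracerWarning; filter can instead list message patterns or (message, category) pairs.
  • jit_iter(tensor): Returns list(tensor), i.e. the slices along the first dimension, while suppressing "Iterating over a tensor" warnings.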

Convenience Utilities

  • optional(context_manager, condition): Context manager that enters context_manager only if condition is true; otherwise it does nothing.
  • ExperimentalWarning: Custom warning class for experimental features.
  • ignore_experimental_warning(): Context manager to suppress ExperimentalWarnings.
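  • timed: Context manager that measures wall-clock time; after the block exits, the elapsed time in seconds is available as its elapsed attribute.
  • torch_float(x): Converts a tensor to a float tensor, or a number to a Python float.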
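A minimal, standard-library-only sketch of the conditional-wrapping and timing patterns described above. The names optional_sketch, TimedSketch and recorder are hypothetical; the code mirrors the described behavior and is not Pyro's source.

```python
import time
from contextlib import contextmanager


@contextmanager
def optional_sketch(context_manager, condition):
    """Hypothetical re-creation of the 'optional' pattern."""
    if condition:
        with context_manager:  # enter the wrapped manager only when asked
            yield
    else:
        yield  # plain no-op context


class TimedSketch:
    """Hypothetical timer: records elapsed seconds when the block exits."""

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.elapsed = time.perf_counter() - self.start
        return False  # let any exception propagate


events = []


@contextmanager
def recorder(name):
    # Hypothetical context manager that logs when it is entered and exited.
    events.append(f"enter {name}")
    yield
    events.append(f"exit {name}")


with optional_sketch(recorder("a"), True):
    pass
with optional_sketch(recorder("b"), False):
    pass  # recorder("b") is created but never entered

with TimedSketch() as t:
    time.sleep(0.01)
```

After these blocks, events holds only the entries for "a", and t.elapsed holds the measured duration in seconds.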

I/O Contract

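  • set_rng_seed(seed): takes an int; returns None.
  • get_rng_state(): takes no arguments; returns Dict[str, Any].
  • torch_isnan(x): takes a Tensor or Number; returns a bool for numbers or a 0-dim bool Tensor for tensors.
  • warn_if_nan(value, msg): takes a Tensor or Number and a str; returns the input unchanged (pass-through).
  • check_model_guide_match(model_trace, guide_trace): takes two Trace objects; returns None (warns or raises on mismatch).
  • check_site_shape(site, max_plate_nesting): takes a message dict and an int; returns None (raises on error).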

Usage Examples

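import torch
from pyro.util import (
    set_rng_seed,
    get_rng_state,
    set_rng_state,
    warn_if_nan,
    ignore_jit_warnings,
    timed,
    optional,
)

# Reproducible experiments: seeds torch, random, and numpy in one call
set_rng_seed(42)
x = torch.randn(3)

# Save and restore RNG state
state = get_rng_state()
a = torch.randn(3)
set_rng_state(state)
b = torch.randn(3)
assert torch.equal(a, b)

# NaN detection: issues a UserWarning; because x requires grad,
# a backward hook is also registered to check its gradient
x = torch.tensor([1.0, float('nan'), 3.0], requires_grad=True)
x = warn_if_nan(x, "forward pass")

# Timing a block of code
with timed() as timer:
    result = torch.randn(1000, 1000) @ torch.randn(1000, 1000)
print(f"Elapsed: {timer.elapsed:.4f} seconds")

# Conditional context manager: disable autograd only when evaluating
model = torch.nn.Linear(4, 2)
data = torch.randn(8, 4)
evaluating = True
with optional(torch.no_grad(), evaluating):
    result = model(data)

# JIT-safe code: ignore_jit_warnings() only filters warnings while a trace
# is in progress, so it belongs inside the function being traced
def my_fn(x):
    with ignore_jit_warnings():
        rows = [row * 2 for row in x]  # iterating a tensor may warn under tracing
    return torch.stack(rows)

example_input = torch.randn(3, 2)
traced_fn = torch.jit.trace(my_fn, example_input)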
