

Implementation:Pyro ppl Pyro Distribution Utilities

From Leeroopedia


Knowledge Sources
Domains Probability_Distributions, Tensor_Operations
Last Updated 2026-02-09 09:00 GMT

Overview

A collection of utility functions for shape broadcasting, tensor manipulation, validation control, and other common operations used throughout Pyro's distribution library.

Description

This module provides essential utility functions that support the implementation of distributions across the Pyro framework. The functions fall into several categories:

Shape Broadcasting and Manipulation:

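  • broadcast_shape(*shapes) -- Computes the shape that results from broadcasting several input shapes together. It is similar to np.broadcast but works on shapes rather than arrays. With strict=True, only new leading dimensions may be added: an existing size-1 dimension may not be resized to match a larger one. This is useful for shape validation.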
  • sum_rightmost(value, dim) -- Sums out the rightmost dim dimensions of a tensor. Handles special cases like dim=0 (no-op), dim=float('inf') (sum all), and negative values (sum all but leftmost dims). Critical for reducing event dimensions in log-prob computations.
  • sum_leftmost(value, dim) -- The mirror of sum_rightmost: it sums out the leftmost dim dimensions, and a negative dim keeps the rightmost abs(dim) dimensions.
  • gather(value, index, dim) -- A broadcasted version of torch.gather that aligns shapes before gathering.
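The difference between strict and non-strict broadcasting can be shown with a small pure-Python sketch. `broadcast_shape_sketch` is a hypothetical name, and this is not Pyro's source. It only mirrors the behavior described above: shapes are aligned from the right, non-strict mode resizes size-1 dimensions, and strict mode allows only new leading dimensions.

```python
def broadcast_shape_sketch(*shapes, strict=False):
    """Hypothetical re-implementation of right-aligned shape broadcasting."""
    reversed_shape = []  # result, built from the rightmost dim leftwards
    for shape in shapes:
        for i, size in enumerate(reversed(shape)):
            if i >= len(reversed_shape):
                # More dims than seen so far: extend on the left (allowed in both modes).
                reversed_shape.append(size)
            elif reversed_shape[i] != size:
                if not strict and size == 1:
                    continue  # the new size-1 dim broadcasts against the existing size
                if not strict and reversed_shape[i] == 1:
                    reversed_shape[i] = size  # resize the existing size-1 dim
                    continue
                # Strict mode, or two different sizes that are both > 1.
                raise ValueError(f"shape mismatch: {' vs '.join(map(str, shapes))}")
    return tuple(reversed(reversed_shape))


print(broadcast_shape_sketch((2, 1), (1, 3)))             # (2, 3)
print(broadcast_shape_sketch((5,), (3, 1), (1, 3, 1)))    # (1, 3, 5)
print(broadcast_shape_sketch((3,), (2, 3), strict=True))  # (2, 3): extending is allowed
try:
    broadcast_shape_sketch((2, 1), (2, 3), strict=True)   # resizing 1 -> 3 is not
except ValueError as err:
    print("rejected:", err)
```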

Scalar and Tensor Checks:

  • is_identically_zero(x) -- Checks whether a value is the literal number zero, as opposed to a tensor that happens to hold zeros. It is used to short-circuit unnecessary computations.
  • is_identically_one(x) -- Checks whether a value is the literal number one. It is used to skip identity multiplications.
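These checks pay off when an accumulator starts from the plain number 0 or a scale defaults to 1. The sketch below is a minimal illustration under that assumption. All function names are hypothetical, and Pyro's own checks may treat some tensor inputs differently in edge cases.

```python
import numbers


def is_identically_zero_sketch(x):
    # Hypothetical stand-in: True only for a plain Python number equal to 0.
    return isinstance(x, numbers.Number) and x == 0


def is_identically_one_sketch(x):
    # Hypothetical stand-in: True only for a plain Python number equal to 1.
    return isinstance(x, numbers.Number) and x == 1


def add_terms(total, term):
    # Hypothetical accumulator: starting from the literal 0 costs no addition.
    if is_identically_zero_sketch(total):
        return term
    if is_identically_zero_sketch(term):
        return total
    return total + term


def apply_scale(value, scale):
    # Hypothetical helper: skip the multiply when scale is the literal 1.
    if is_identically_one_sketch(scale):
        return value
    return value * scale


term = -1.25
print(add_terms(0, term))                 # -1.25, returned without an addition
print(apply_scale(term, 1))               # -1.25, identity multiply skipped
print(apply_scale(term, 2))               # -2.5
print(is_identically_zero_sketch("0"))    # False: only numbers qualify
```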

Scaling and Masking:

  • scale_and_mask(tensor, scale, mask) -- Applies scaling and boolean masking to a tensor while skipping unnecessary operations. A literal 0 input is returned as-is, the multiplication is skipped when scale is identically one, and entries are set to zero wherever mask is False.
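The control flow described above can be sketched in a few lines of PyTorch. `scale_and_mask_sketch` is hypothetical and is not Pyro's source. The order of the checks and the use of `torch.broadcast_tensors` plus `masked_fill` are assumptions chosen to match the described behavior.

```python
import torch


def scale_and_mask_sketch(tensor, scale=1.0, mask=None):
    """Hypothetical re-implementation; not Pyro's source."""
    # A literal 0 (e.g. an empty log-prob accumulator) needs no work at all.
    if isinstance(tensor, (int, float)) and tensor == 0:
        return tensor
    scale_is_one = isinstance(scale, (int, float)) and scale == 1
    if mask is None:
        # No mask: return the input itself, or a scaled copy.
        return tensor if scale_is_one else tensor * scale
    if not scale_is_one:
        tensor = tensor * scale
    mask = torch.as_tensor(mask, dtype=torch.bool)  # accepts a bool or a BoolTensor
    tensor, mask = torch.broadcast_tensors(tensor, mask)
    # Zero every entry whose mask value is False.
    return tensor.masked_fill(~mask, 0.0)


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
m = torch.tensor([True, True, False, True])
print(scale_and_mask_sketch(x, scale=2.0, mask=m))  # tensor([2., 4., 0., 8.])
print(scale_and_mask_sketch(x) is x)                # True: nothing to do
print(scale_and_mask_sketch(x, mask=False))         # tensor([0., 0., 0., 0.])
```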

Tensor Construction:

  • scalar_like(prototype, fill_value) -- Creates a scalar tensor matching the dtype and device of a prototype tensor.
  • eye_like(value, m, n) -- Creates an identity matrix matching the dtype and device of a given tensor, working around JIT limitations.
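Both constructors copy the dtype and device from an existing tensor. The sketch below illustrates this with hypothetical names. In eager mode, `torch.eye(m, n, dtype=..., device=...)` gives the same matrix. The strided fill used here is one way to avoid calling `torch.eye`, which is the kind of JIT workaround mentioned above. Whether Pyro uses exactly this trick is an assumption.

```python
import torch


def scalar_like_sketch(prototype, fill_value):
    # 0-dim tensor carrying the prototype's dtype and device.
    return torch.tensor(fill_value, dtype=prototype.dtype, device=prototype.device)


def eye_like_sketch(value, m, n=None):
    # m x n identity matrix with value's dtype and device, built without torch.eye.
    if n is None:
        n = m
    eye = torch.zeros(m, n, dtype=value.dtype, device=value.device)
    # In the row-major flat view, consecutive diagonal entries are n + 1 apart.
    eye.view(-1)[: min(m, n) * n : n + 1] = 1
    return eye


x = torch.randn(4, dtype=torch.float64)
print(scalar_like_sketch(x, 0.5))  # tensor(0.5000, dtype=torch.float64)
print(eye_like_sketch(x, 2, 3))
```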

Object Manipulation:

  • copy_docs_from(source_class) -- Decorator that copies docstrings from a source class to a destination class.
  • weakmethod(fn) -- Decorator that enforces weak binding of methods to avoid reference cycles, useful when passing bound methods as callbacks.
  • detach(obj) -- Creates a deep copy of any Python object, calling .detach() on every tensor so that the copies share storage with the originals instead of copying their data.
  • deep_to(obj, *args, **kwargs) -- Creates a deep copy of any Python object, calling .to() on all tensors and modules.
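The weak-binding idea behind weakmethod can be sketched in plain Python with a property that returns a `functools.partial` over a `weakref.ref`. `weakmethod_sketch` and `Model` are hypothetical illustrations, not Pyro's code. The immediate cleanup shown at the end relies on CPython's reference counting.

```python
import functools
import weakref


def weakmethod_sketch(fn):
    """Hypothetical weak-binding decorator, for illustration only."""

    def call_via_weakref(weakself, *args, **kwargs):
        self = weakself()  # None once the owner has been garbage collected
        if self is None:
            raise AttributeError(f"owner of {fn.__name__} was garbage collected")
        return fn(self, *args, **kwargs)

    @property
    def binder(self):
        # Bind through a weak reference, so storing the result does not keep self alive.
        return functools.partial(call_via_weakref, weakref.ref(self))

    return binder


class Model:
    def __init__(self, scale):
        self.scale = scale
        # With a plain method this would create the cycle self -> callback -> self.
        self.callback = self.rescale

    @weakmethod_sketch
    def rescale(self, x):
        return self.scale * x


model = Model(3.0)
print(model.callback(2.0))  # 6.0
probe = weakref.ref(model)
del model
print(probe() is None)      # True (CPython): the stored callback did not keep model alive
```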

Validation Control:

  • enable_validation(is_validate) -- Globally enables or disables distribution argument validation for both Pyro and PyTorch.
  • is_validation_enabled() -- Returns the current validation state.
  • validation_enabled(is_validate) -- Context manager for temporarily enabling or disabling validation.

The module also maintains a global _VALIDATION_ENABLED flag, initialized from __debug__ (True unless Python runs with -O). It also registers Pyro settings that expose both the Pyro and the PyTorch validation states.
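The context manager follows a save/set/restore pattern around a module-level flag. The pure-Python sketch below uses hypothetical names (`_FLAG`, `enable_flag`, `flag_enabled`) and leaves out the PyTorch side that the real enable_validation also updates.

```python
from contextlib import contextmanager

# Hypothetical module-level flag, starting from __debug__ like _VALIDATION_ENABLED.
_FLAG = __debug__


def enable_flag(value):
    global _FLAG
    _FLAG = value
    # The real enable_validation also updates PyTorch's default; omitted here.


def is_flag_enabled():
    return _FLAG


@contextmanager
def flag_enabled(value=True):
    previous = is_flag_enabled()  # remember the caller's setting
    try:
        enable_flag(value)
        yield
    finally:
        enable_flag(previous)     # restore even if the body raises


with flag_enabled(False):
    print(is_flag_enabled())             # False
print(is_flag_enabled() == __debug__)    # True
```

Restoring the flag in a `finally` clause means that an exception raised inside the `with` body cannot leave validation switched off globally.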

Usage

These utilities are used internally throughout Pyro's distribution implementations. The broadcast_shape and sum_rightmost functions are particularly central to distribution shape management. Use scale_and_mask when implementing custom distributions that need efficient masking. Use the validation functions during debugging and testing.

Code Reference

Source Location

Signature

def broadcast_shape(*shapes, **kwargs) -> tuple
def gather(value, index, dim) -> torch.Tensor
def sum_rightmost(value, dim) -> torch.Tensor
def sum_leftmost(value, dim) -> torch.Tensor
def scale_and_mask(tensor, scale=1.0, mask=None) -> torch.Tensor
def is_identically_zero(x) -> bool
def is_identically_one(x) -> bool
def scalar_like(prototype, fill_value) -> torch.Tensor
def eye_like(value, m, n=None) -> torch.Tensor
def enable_validation(is_validate) -> None
def is_validation_enabled() -> bool
def validation_enabled(is_validate=True)  # context manager
def copy_docs_from(source_class, full_text=False)  # decorator
def weakmethod(fn)  # decorator
def detach(obj)
def deep_to(obj, *args, **kwargs)

Import

from pyro.distributions.util import broadcast_shape
from pyro.distributions.util import sum_rightmost, sum_leftmost
from pyro.distributions.util import scale_and_mask
from pyro.distributions.util import is_identically_zero, is_identically_one
from pyro.distributions.util import enable_validation, is_validation_enabled, validation_enabled
from pyro.distributions.util import detach, deep_to
from pyro.distributions.util import gather, scalar_like, eye_like
from pyro.distributions.util import copy_docs_from, weakmethod

I/O Contract

broadcast_shape

Name | Type | Required | Description
*shapes | tuple | Yes | Variable number of shape tuples to broadcast together.
strict | bool | No | Keyword-only (read from **kwargs). If True, existing size-1 dimensions may not be resized; new leading dimensions may still be added. Defaults to False.

Name | Type | Description
return | tuple | The resulting broadcasted shape. Raises ValueError if the shapes are incompatible.

sum_rightmost

Name | Type | Required | Description
value | torch.Tensor or number | Yes | Input tensor or scalar. Plain numbers are returned unchanged.
dim | int or float('inf') | Yes | Number of rightmost dimensions to sum. Supports 0 (no-op), positive N (sum the rightmost N), negative -N (keep the leftmost N and sum the rest), and float('inf') (sum all).

Name | Type | Description
return | torch.Tensor or number | Reduced tensor with the rightmost dims summed out.
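The dim conventions of sum_rightmost and sum_leftmost are easiest to check at the level of shapes. The pure-Python helpers below are hypothetical. They only compute the shape that remains after the sum and do not touch tensor data.

```python
def _normalize(dim, ndim):
    # A negative dim -N means "keep N dims", i.e. sum ndim - N of them.
    return dim + ndim if dim < 0 else dim


def sum_rightmost_shape(shape, dim):
    """Hypothetical helper: shape left after summing `dim` rightmost dims."""
    dim = _normalize(dim, len(shape))
    if dim == 0:
        return tuple(shape)        # no-op
    if dim >= len(shape):
        return ()                  # everything summed, e.g. dim=float('inf')
    return tuple(shape[:-dim])


def sum_leftmost_shape(shape, dim):
    """Hypothetical mirror image: shape left after summing `dim` leftmost dims."""
    dim = _normalize(dim, len(shape))
    if dim == 0:
        return tuple(shape)
    if dim >= len(shape):
        return ()
    return tuple(shape[dim:])


shape = (2, 3, 4, 5)
print(sum_rightmost_shape(shape, 2))             # (2, 3)
print(sum_rightmost_shape(shape, -1))            # (2,)
print(sum_rightmost_shape(shape, float("inf")))  # ()
print(sum_leftmost_shape(shape, 1))              # (3, 4, 5)
print(sum_leftmost_shape(shape, -1))             # (5,)
```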

scale_and_mask

Name | Type | Required | Description
tensor | torch.Tensor or the number 0 | Yes | Input tensor, or the literal 0 (which is returned unchanged).
scale | torch.Tensor or number | No | Positive scale factor. Defaults to 1.0.
mask | torch.BoolTensor, bool, or None | No | Optional boolean mask, broadcast against tensor. None means no masking.

Name | Type | Description
return | torch.Tensor (or 0) | Scaled and masked tensor, zero wherever mask is False.

detach

Name | Type | Required | Description
obj | Any | Yes | Any Python object to deep-copy with detached tensors.

Name | Type | Description
return | Any | Deep copy of the input with all tensors detached (no data copied).

deep_to

Name | Type | Required | Description
obj | Any | Yes | Any Python object to deep-copy.
*args | Any | No | Positional arguments passed to .to() (e.g., device, dtype).
**kwargs | Any | No | Keyword arguments passed to .to().

Name | Type | Description
return | Any | Deep copy with all tensors and modules mapped via .to().

Usage Examples

Broadcasting Shapes

from pyro.distributions.util import broadcast_shape

shape = broadcast_shape((2, 1), (1, 3))
print(shape)  # (2, 3)

shape = broadcast_shape((5,), (3, 1), (1, 3, 1))
print(shape)  # (1, 3, 5)

# Strict mode forbids resizing a size-1 dim (here the trailing 1 -> 3)
try:
    broadcast_shape((2, 1), (2, 3), strict=True)
except ValueError:
    print("Strict mode rejects resizing a size-1 dimension")

Summing Rightmost Dimensions

import torch
from pyro.distributions.util import sum_rightmost

x = torch.ones(2, 3, 4, 5)

# Sum rightmost 2 dims
y = sum_rightmost(x, 2)
print(y.shape)  # torch.Size([2, 3])

# Keep leftmost 1 dim
y = sum_rightmost(x, -1)
print(y.shape)  # torch.Size([2])

Scale and Mask

import torch
from pyro.distributions.util import scale_and_mask

tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([True, True, False, True])

result = scale_and_mask(tensor, scale=2.0, mask=mask)
print(result)  # tensor([2., 4., 0., 8.])

Validation Context Manager

from pyro.distributions.util import validation_enabled, is_validation_enabled

print(is_validation_enabled())  # True by default, unless Python runs with -O

with validation_enabled(False):
    print(is_validation_enabled())  # False
    # Construct distributions here without argument validation
    pass

print(is_validation_enabled())  # previous setting restored

Moving Objects to Device

import torch
from pyro.distributions.util import deep_to

obj = {"weights": torch.randn(3, 4), "bias": torch.zeros(4)}
device = "cuda" if torch.cuda.is_available() else "cpu"
obj_moved = deep_to(obj, device=device)
# Every tensor in obj_moved is on `device`; obj itself is unchanged
