Implementation:Pyro ppl Pyro Distribution Utilities
| Knowledge Sources | |
|---|---|
| Domains | Probability_Distributions, Tensor_Operations |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
A collection of utility functions for shape broadcasting, tensor manipulation, validation control, and other common operations used throughout Pyro's distribution library.
Description
This module provides essential utility functions that support the implementation of distributions across the Pyro framework. The functions fall into several categories:
Shape Broadcasting and Manipulation:
- broadcast_shape(*shapes) -- Computes the broadcast-compatible shape from multiple input shapes, similar to np.broadcast. Supports a strict mode in which new leading dimensions may be added but existing dimensions must match exactly (size-1 dimensions are not expanded), which is useful for shape validation.
- sum_rightmost(value, dim) -- Sums out the rightmost dim dimensions of a tensor. Handles special cases such as dim=0 (no-op), dim=float('inf') (sum all), and negative values (sum all but the leftmost abs(dim) dimensions). Central to reducing event dimensions in log-prob computations.
- sum_leftmost(value, dim) -- The mirror of sum_rightmost, summing out leftmost dimensions.
- gather(value, index, dim) -- A broadcasted version of torch.gather that aligns the shapes of value and index before gathering.
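To make the broadcasting rule concrete, here is a pure-Python sketch that works on plain shape tuples. The name broadcast_shape_sketch is hypothetical and this is not Pyro's source. It assumes the right-aligned, NumPy-style rule described above, with strict mode permitting only new leading dimensions.

```python
def broadcast_shape_sketch(*shapes, strict=False):
    """Hypothetical illustration of the broadcasting rule (not Pyro's source).

    Shapes are aligned from the right. A size-1 entry may be resized to match
    another size unless strict=True, in which case only new leading
    dimensions may be added and all overlapping sizes must be equal.
    """
    reversed_shape = []  # result, built from the rightmost dimension leftward
    for shape in shapes:
        for i, size in enumerate(reversed(shape)):
            if i >= len(reversed_shape):
                # A new leading dimension is always allowed, even in strict mode.
                reversed_shape.append(size)
            elif reversed_shape[i] == size:
                continue
            elif not strict and reversed_shape[i] == 1:
                reversed_shape[i] = size  # resize an existing size-1 dim
            elif not strict and size == 1:
                continue  # this shape's size-1 dim broadcasts to the existing size
            else:
                raise ValueError(f"shapes cannot be broadcast: {shapes}")
    return tuple(reversed(reversed_shape))


print(broadcast_shape_sketch((2, 1), (1, 3)))             # (2, 3)
print(broadcast_shape_sketch((5,), (3, 1), (1, 3, 1)))    # (1, 3, 5)
print(broadcast_shape_sketch((3,), (2, 3), strict=True))  # (2, 3)
```

Because a helper like this only touches tuples, it can check shape compatibility before any tensors are allocated.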
Scalar and Tensor Checks:
- is_identically_zero(x) -- Checks whether a value is exactly the number zero (a tensor of zeros does not count), used to short-circuit unnecessary computations.
- is_identically_one(x) -- Checks whether a value is exactly the number one, used to skip identity-multiply operations.
Scaling and Masking:
- scale_and_mask(tensor, scale, mask) -- Applies scaling and boolean masking to a tensor, with optimizations to avoid unnecessary operations. Returns zero tensors where mask is False, and avoids multiplication when scale is identically one.
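The short-circuit logic of scale_and_mask can be sketched with plain Python lists standing in for tensors. The helpers _is_zero, _is_one and scale_and_mask_sketch are hypothetical. They assume the checks described above (true only for a plain number, never for a container) and are not Pyro's implementation.

```python
import numbers


def _is_zero(x):
    # Hypothetical stand-in: true only for a plain Python number equal to 0.
    return isinstance(x, numbers.Number) and x == 0


def _is_one(x):
    # Hypothetical stand-in: true only for a plain Python number equal to 1.
    return isinstance(x, numbers.Number) and x == 1


def scale_and_mask_sketch(values, scale=1.0, mask=None):
    """List-based illustration of scale-and-mask with short-circuits."""
    if _is_zero(values) or (mask is None and _is_one(scale)):
        return values  # nothing to do, so skip the multiply entirely
    if mask is None:
        return [v * scale for v in values]
    if isinstance(mask, bool):
        # A scalar mask either keeps everything or zeroes everything.
        return [v * scale for v in values] if mask else [0.0] * len(values)
    # Elementwise: keep the scaled value where mask is True, else zero.
    return [v * scale if m else 0.0 for v, m in zip(values, mask)]


print(scale_and_mask_sketch([1.0, 2.0, 3.0, 4.0], scale=2.0, mask=[True, True, False, True]))
# [2.0, 4.0, 0.0, 8.0]
```

Returning the input object untouched when there is nothing to do is the point of the is_identically_* checks: the common unscaled, unmasked case costs nothing.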
Tensor Construction:
- scalar_like(prototype, fill_value) -- Creates a scalar tensor matching the dtype and device of a prototype tensor.
- eye_like(value, m, n) -- Creates an identity matrix matching the dtype and device of a given tensor, working around JIT limitations.
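An identity-like matrix needs only zeros plus a write to the diagonal. One way to do that is sketched below in pure Python, with nested lists in place of tensors. In row-major order, entry (i, i) of an m x n matrix sits at flat index i * (n + 1), so a single strided slice reaches every diagonal entry. The name eye_like_sketch is hypothetical, the strided-write approach is an assumption rather than a claim about Pyro's source, and dtype/device handling is omitted.

```python
def eye_like_sketch(m, n=None):
    """Hypothetical m x n identity built with one strided write."""
    if n is None:
        n = m
    flat = [0.0] * (m * n)  # row-major storage
    k = min(m, n)           # number of diagonal entries
    # Diagonal entry (i, i) lives at flat index i * n + i == i * (n + 1).
    flat[: k * n : n + 1] = [1.0] * k
    return [flat[r * n:(r + 1) * n] for r in range(m)]


print(eye_like_sketch(2, 3))  # [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
```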
Object Manipulation:
- copy_docs_from(source_class) -- Decorator that copies docstrings from a source class to a destination class.
- weakmethod(fn) -- Decorator that enforces weak binding of methods to avoid reference cycles, useful when passing bound methods as callbacks.
- detach(obj) -- Creates a deep copy of any Python object, detaching all tensors without copying data.
- deep_to(obj, *args, **kwargs) -- Creates a deep copy of any Python object, calling .to() on all tensors and modules.
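A minimal sketch of weak binding using only the standard library is shown below. It assumes the decorator turns the method into a property that returns a wrapper holding only a weakref to the instance. The names weakmethod_sketch and Counter are hypothetical, and Pyro's decorator may differ in detail.

```python
import functools
import gc
import weakref


def weakmethod_sketch(fn):
    """Hypothetical weak-binding decorator (not Pyro's source).

    Accessing the attribute returns a callable that holds only a weak
    reference to the instance, so storing it as a callback does not keep
    the instance alive.
    """

    def bind(self):
        self_ref = weakref.ref(self)

        @functools.wraps(fn)
        def weakly_bound(*args, **kwargs):
            obj = self_ref()
            if obj is None:
                raise AttributeError(f"instance was garbage collected before calling {fn.__name__}")
            return fn(obj, *args, **kwargs)

        return weakly_bound

    return property(bind)


class Counter:
    def __init__(self):
        self.count = 0

    @weakmethod_sketch
    def increment(self, by=1):
        self.count += by
        return self.count


counter = Counter()
callback = counter.increment  # weakly bound: holds no strong reference to counter
print(callback(2))            # 2
del counter
gc.collect()
try:
    callback()
except AttributeError:
    print("instance already collected")
```

A normal bound method would keep counter alive for as long as callback exists. That is the reference cycle the decorator is meant to avoid.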
Validation Control:
- enable_validation(is_validate) -- Globally enables or disables distribution argument validation for both Pyro and PyTorch.
- is_validation_enabled() -- Returns the current validation state.
- validation_enabled(is_validate) -- Context manager for temporarily enabling or disabling validation.
The module also maintains a global _VALIDATION_ENABLED flag initialized from __debug__ and registers Pyro settings for both Pyro and PyTorch validation states.
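The global-flag-plus-context-manager pattern can be sketched with contextlib. All names with a _sketch suffix are hypothetical. The sketch tracks only a single module-level flag and omits the PyTorch-side setting that enable_validation also updates.

```python
from contextlib import contextmanager

# Hypothetical module-level flag, initialized from __debug__ as described above.
_VALIDATION_ENABLED = __debug__


def enable_validation_sketch(is_validate):
    global _VALIDATION_ENABLED
    _VALIDATION_ENABLED = bool(is_validate)


def is_validation_enabled_sketch():
    return _VALIDATION_ENABLED


@contextmanager
def validation_enabled_sketch(is_validate=True):
    previous = is_validation_enabled_sketch()
    try:
        enable_validation_sketch(is_validate)
        yield
    finally:
        # Restore the previous state even if the body raised.
        enable_validation_sketch(previous)


enable_validation_sketch(True)
with validation_enabled_sketch(False):
    print(is_validation_enabled_sketch())  # False
print(is_validation_enabled_sketch())      # True
```

Saving and restoring the previous value, rather than resetting to a fixed default, lets these context managers nest correctly.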
Usage
These utilities are used internally throughout Pyro's distribution implementations. The broadcast_shape and sum_rightmost functions are particularly central to distribution shape management. Use scale_and_mask when implementing custom distributions that need efficient masking. Use the validation functions during debugging and testing.
Code Reference
Source Location
- Repository: Pyro
- File: pyro/distributions/util.py
Signature
```python
def broadcast_shape(*shapes, **kwargs) -> tuple
def gather(value, index, dim) -> torch.Tensor
def sum_rightmost(value, dim) -> torch.Tensor
def sum_leftmost(value, dim) -> torch.Tensor
def scale_and_mask(tensor, scale=1.0, mask=None) -> torch.Tensor
def is_identically_zero(x) -> bool
def is_identically_one(x) -> bool
def scalar_like(prototype, fill_value) -> torch.Tensor
def eye_like(value, m, n=None) -> torch.Tensor
def enable_validation(is_validate) -> None
def is_validation_enabled() -> bool
def validation_enabled(is_validate=True)  # context manager
def copy_docs_from(source_class, full_text=False)  # decorator
def weakmethod(fn)  # decorator
def detach(obj)
def deep_to(obj, *args, **kwargs)
```
Import
```python
from pyro.distributions.util import broadcast_shape
from pyro.distributions.util import sum_rightmost, sum_leftmost
from pyro.distributions.util import scale_and_mask
from pyro.distributions.util import is_identically_zero, is_identically_one
from pyro.distributions.util import enable_validation, is_validation_enabled, validation_enabled
from pyro.distributions.util import detach, deep_to
from pyro.distributions.util import gather, scalar_like, eye_like
from pyro.distributions.util import copy_docs_from, weakmethod
```
I/O Contract
broadcast_shape
| Name | Type | Required | Description |
|---|---|---|---|
| *shapes | tuple | Yes | Variable number of shape tuples to broadcast together. |
| strict | bool | No | If True, new leading dimensions may be added but existing dimensions, including size-1 ones, must match exactly. Defaults to False. |
| Name | Type | Description |
|---|---|---|
| return | tuple | The resulting broadcasted shape. Raises ValueError if shapes are incompatible. |
sum_rightmost
| Name | Type | Required | Description |
|---|---|---|---|
| value | torch.Tensor or number | Yes | Input tensor or scalar. |
| dim | int or float | Yes | Number of rightmost dimensions to sum. Supports 0 (no-op), positive (sum rightmost N), negative (keep leftmost abs(dim)), and float('inf') (sum all). |
| Name | Type | Description |
|---|---|---|
| return | torch.Tensor or number | Reduced tensor with rightmost dims summed out. |
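The dim conventions can be checked with a shape-only sketch in pure Python. The helpers sum_rightmost_shape and sum_leftmost_shape are hypothetical and compute only the resulting shape. They assume a negative dim is offset by the number of dimensions, so -1 keeps one leftmost dimension for sum_rightmost, and that sum_leftmost mirrors this from the other side.

```python
def sum_rightmost_shape(shape, dim):
    """Hypothetical: shape left after summing out `dim` rightmost dims."""
    if dim < 0:
        dim += len(shape)  # negative dim: keep the leftmost abs(dim) dims
    if dim == 0:
        return tuple(shape)  # no-op
    if dim >= len(shape):
        return ()  # everything summed out (covers float('inf'))
    return tuple(shape[:-dim])


def sum_leftmost_shape(shape, dim):
    """Hypothetical mirror: shape left after summing out `dim` leftmost dims."""
    if dim < 0:
        dim += len(shape)  # negative dim: keep the rightmost abs(dim) dims
    if dim == 0:
        return tuple(shape)
    if dim >= len(shape):
        return ()
    return tuple(shape[dim:])


print(sum_rightmost_shape((2, 3, 4, 5), 2))   # (2, 3)
print(sum_rightmost_shape((2, 3, 4, 5), -1))  # (2,)
print(sum_leftmost_shape((2, 3, 4, 5), 1))    # (3, 4, 5)
```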
scale_and_mask
| Name | Type | Required | Description |
|---|---|---|---|
| tensor | torch.Tensor or 0 | Yes | Input tensor or zero. |
| scale | torch.Tensor or number | No | Positive scale factor. Defaults to 1.0. |
| mask | torch.BoolTensor, bool, or None | No | Optional boolean mask. None means no masking. |
| Name | Type | Description |
|---|---|---|
| return | torch.Tensor | Scaled and masked tensor. Zero where mask is False. |
detach
| Name | Type | Required | Description |
|---|---|---|---|
| obj | Any | Yes | Any Python object to deep-copy with detached tensors. |
| Name | Type | Description |
|---|---|---|
| return | Any | Deep copy of the input with all tensors detached (no data copied). |
deep_to
| Name | Type | Required | Description |
|---|---|---|---|
| obj | Any | Yes | Any Python object to deep-copy. |
| *args | Any | No | Arguments passed to tensor.to() (e.g., device, dtype). |
| **kwargs | Any | No | Keyword arguments passed to tensor.to(). |
| Name | Type | Description |
|---|---|---|
| return | Any | Deep copy with all tensors and modules mapped via .to(). |
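The contract of detach and deep_to (new containers, mapped leaves, original left untouched) can be illustrated with a hypothetical helper, deep_map, and a stand-in FakeTensor class. This sketch rebuilds only dicts, lists and tuples and returns other objects as-is. It is therefore weaker than a true deep copy and should not be read as Pyro's implementation.

```python
def deep_map(obj, is_leaf, fn):
    """Hypothetical: rebuild nested dict/list/tuple containers, applying fn to leaves.

    Objects that are neither containers nor matching leaves are returned
    as-is (not copied), unlike a true deep copy.
    """
    if is_leaf(obj):
        return fn(obj)
    if isinstance(obj, dict):
        return {k: deep_map(v, is_leaf, fn) for k, v in obj.items()}
    if isinstance(obj, list):
        return [deep_map(v, is_leaf, fn) for v in obj]
    if isinstance(obj, tuple):
        return tuple(deep_map(v, is_leaf, fn) for v in obj)
    return obj


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor with a .to(device) method."""

    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device):
        return FakeTensor(list(self.data), device)


params = {"weights": FakeTensor([1.0, 2.0]), "layers": [FakeTensor([3.0]), "relu"]}
moved = deep_map(params, lambda x: isinstance(x, FakeTensor), lambda t: t.to("cuda"))
print(moved["weights"].device, params["weights"].device)  # cuda cpu
```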
Usage Examples
Broadcasting Shapes
```python
from pyro.distributions.util import broadcast_shape

shape = broadcast_shape((2, 1), (1, 3))
print(shape)  # (2, 3)

shape = broadcast_shape((5,), (3, 1), (1, 3, 1))
print(shape)  # (1, 3, 5)

# Strict mode only allows adding leading dims; it refuses to resize a size-1 dim
try:
    broadcast_shape((2, 1), (2, 3), strict=True)
except ValueError:
    print("Strict mode rejects size-1 extension")
```
Summing Rightmost Dimensions
```python
import torch
from pyro.distributions.util import sum_rightmost

x = torch.ones(2, 3, 4, 5)

# Sum rightmost 2 dims
y = sum_rightmost(x, 2)
print(y.shape)  # torch.Size([2, 3])

# Keep leftmost 1 dim (sums out the rightmost 3 dims)
y = sum_rightmost(x, -1)
print(y.shape)  # torch.Size([2])
```
Scale and Mask
```python
import torch
from pyro.distributions.util import scale_and_mask

tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([True, True, False, True])
result = scale_and_mask(tensor, scale=2.0, mask=mask)
print(result)  # tensor([2., 4., 0., 8.])
```
Validation Context Manager
```python
from pyro.distributions.util import validation_enabled, is_validation_enabled

print(is_validation_enabled())  # True by default when __debug__ is set (not running with python -O)
with validation_enabled(False):
    print(is_validation_enabled())  # False
    # Construct distributions without validation here
    pass
print(is_validation_enabled())  # True (restored)
```
Moving Objects to Device
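```python
import torch
from pyro.distributions.util import deep_to

obj = {"weights": torch.randn(3, 4), "bias": torch.zeros(4)}
# Fall back to CPU so the example also runs on machines without a GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
obj_moved = deep_to(obj, device=device)
# All tensors in obj_moved are on `device`; the original dict is not modified
```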