Principle:Pyro ppl Pyro Custom Distribution Framework
| Knowledge Sources | |
|---|---|
| Domains | Probabilistic Programming, Probability Theory, Deep Learning |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Pyro's custom distribution framework extends PyTorch's distribution library with additional capabilities required for probabilistic programming, including enumeration support, reparameterization metadata, and compatibility with Pyro's effect handler system.
Description
PyTorch provides a foundational torch.distributions module implementing common probability distributions. Pyro builds an abstraction layer on top of this foundation to support the richer requirements of a probabilistic programming language.
The framework addresses several concerns:
Bridging PyTorch and Pyro semantics: PyTorch distributions provide sample() and log_prob() methods. Pyro's layer standardizes the additional interface its inference algorithms rely on: enumerate_support() for exact enumeration of discrete variables, the .has_rsample attribute for checking reparameterization support, and batch/event shape manipulation via .expand() and .to_event().
The TorchDistributionMixin: This mixin class injects Pyro-specific behavior into any PyTorch distribution. It provides:
- Proper integration with Pyro's sample statement
- Shape handling that distinguishes batch dimensions (independent repetitions) from event dimensions (dimensions that jointly define a single draw)
- Support for pyro.plate contexts for vectorized conditional independence
Distribution wrappers: Pyro provides wrapper classes that adapt PyTorch distributions (or arbitrary log-density functions) into Pyro-compatible distribution objects. This enables users to define custom likelihoods while retaining full compatibility with all inference algorithms.
Constraints system: Each distribution declares the support (the set of valid values) and parameter constraints via constraint objects. These constraints are used for:
- Automatic parameter transforms (e.g., mapping unconstrained parameters to positive reals via exp)
- Validation of sample values
- Initialization of variational parameters
Transform modules: Distributions can be composed with bijective transforms to create transformed distributions (e.g., a LogNormal is a Normal passed through an exp transform). Pyro extends this with TransformModule, which makes transform parameters learnable.
Usage
Use the custom distribution framework when:
- Defining new probability distributions not available in PyTorch.
- Wrapping an arbitrary log-density function as a distribution for use in Pyro models.
- Composing distributions with learnable transforms for normalizing flows.
- Controlling batch vs. event shape semantics in complex models.
- Ensuring custom distributions work correctly with all of Pyro's inference algorithms.
Theoretical Basis
A probability distribution in this framework is characterized by:
    # Distribution interface (schematic)
    import torch

    class Distribution:
        batch_shape: torch.Size  # shape of independent draws
        event_shape: torch.Size  # shape of a single event

        def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
            # Returns a tensor of shape sample_shape + batch_shape + event_shape
            ...

        def log_prob(self, value: torch.Tensor) -> torch.Tensor:
            # value: shape sample_shape + batch_shape + event_shape
            # Returns: shape sample_shape + batch_shape
            # (event dims are reduced into one joint log-density)
            ...

        # True if rsample() provides reparameterized (pathwise) gradients,
        # i.e., a draw can be written as g(epsilon, params), epsilon ~ base_dist
        has_rsample: bool = False
The shape semantics are critical:
    # Shape contract
    # Given:
    #   sample_shape = (S,)   # number of MC samples
    #   batch_shape  = (B,)   # independent repetitions
    #   event_shape  = (E,)   # dimensions of one draw
    # Then:
    #   x = dist.sample((S,))   has shape (S, B, E)
    #   dist.log_prob(x)        has shape (S, B)
    #   (log_prob sums over event_shape, preserves batch_shape)
Reparameterization enables pathwise gradient estimation:
    # Reparameterization trick
    # Instead of:  z ~ q(z | params)
    # Write:       epsilon ~ base_distribution
    #              z = transform(epsilon, params)
    # Then:        grad_params E[f(z)] = E[grad_params f(transform(epsilon, params))]
Constraint transforms map between constrained and unconstrained spaces:
    # Constraint:          x > 0 (positive reals)
    # Transform:           x = exp(y), where y is unconstrained
    # Jacobian correction: log_prob_x(x) = log_prob_y(log(x)) - log(x)
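To see why the -log(x) term is needed, the following standard-library sketch takes y to be standard normal (an assumption made only for this example) and integrates the density of x = exp(y) numerically, with and without the correction. With the correction the density integrates to about 1. Without it, the integral is E[exp(y)] = exp(0.5), so the result is not a normalized density.

```python
import math

def log_prob_y(y):
    """Standard normal log-density on the unconstrained space."""
    return -0.5 * y * y - 0.5 * math.log(2.0 * math.pi)

def log_prob_x(x):
    """Log-density of x = exp(y), including the -log(x) Jacobian term."""
    return log_prob_y(math.log(x)) - math.log(x)

def integrate(log_density, xs):
    """Trapezoidal rule for exp(log_density) on a non-uniform grid."""
    vals = [math.exp(log_density(x)) for x in xs]
    return sum(0.5 * (vals[i] + vals[i + 1]) * (xs[i + 1] - xs[i])
               for i in range(len(xs) - 1))

# Log-spaced grid covering y in [-10, 10].
xs = [math.exp(-10.0 + 0.001 * i) for i in range(20001)]

with_correction = integrate(log_prob_x, xs)
without_correction = integrate(lambda x: log_prob_y(math.log(x)), xs)
print(with_correction, without_correction)
```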
Related Pages
- Implementation:Pyro_ppl_Pyro_Distribution_Base
- Implementation:Pyro_ppl_Pyro_TorchDistributionMixin
- Implementation:Pyro_ppl_Pyro_TorchDistribution_Wrappers
- Implementation:Pyro_ppl_Pyro_Distributions_Init
- Implementation:Pyro_ppl_Pyro_Constraints
- Implementation:Pyro_ppl_Pyro_Distribution_Utilities
- Implementation:Pyro_ppl_Pyro_TransformModule