
Principle:Pyro ppl Pyro Custom Distribution Framework

From Leeroopedia


Knowledge Sources
Domains: Probabilistic Programming, Probability Theory, Deep Learning
Last Updated: 2026-02-09 09:00 GMT

Overview

Pyro's custom distribution framework extends PyTorch's distribution library with the capabilities that probabilistic programming requires: its inference algorithms rely on enumeration support and reparameterization metadata, it adds batch/event shape helpers, and it makes distributions compatible with Pyro's effect handler system.

Description

PyTorch provides a foundational torch.distributions module implementing common probability distributions. Pyro builds an abstraction layer on top of this foundation to support the richer requirements of a probabilistic programming language.

The framework addresses several concerns:

Bridging PyTorch and Pyro semantics: PyTorch distributions already provide sample(), rsample(), log_prob(), enumerate_support(), the .has_rsample flag, and .expand(). Pyro's inference algorithms rely on these, using enumerate_support() for discrete enumeration and .has_rsample to decide whether reparameterized gradients are available, and Pyro's layer adds further batch/event shape helpers such as .expand_by(), .to_event(), and .mask().

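The TorchDistributionMixin: This mixin class adds Pyro-specific behavior to PyTorch distribution classes; for example, pyro.distributions.Normal subclasses both torch.distributions.Normal and TorchDistributionMixin. It provides: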

  • Proper integration with Pyro's sample statement
  • Shape handling that distinguishes batch dimensions (independent repetitions) from event dimensions (dimensions that jointly define a single draw)
  • Support for pyro.plate contexts for vectorized conditional independence

Distribution wrappers: Pyro provides wrapper classes that adapt PyTorch distributions (or arbitrary log-density functions) into Pyro-compatible distribution objects. This lets users define custom likelihoods that remain compatible with Pyro's inference algorithms.

Constraints system: Each distribution declares its support (the set of valid values, exposed as .support) and its parameter constraints (the .arg_constraints dictionary) as constraint objects. These constraints are used for:

  • Automatic parameter transforms (e.g., mapping unconstrained parameters to positive reals via exp)
  • Validation of sample values
  • Initialization of variational parameters

Transform modules: Distributions can be composed with bijective transforms to create transformed distributions (e.g., a LogNormal is a Normal passed through an exp transform). Pyro extends this with TransformModule, a transform that is also a torch.nn.Module, so its parameters can be registered with and learned by an optimizer.

Usage

Use the custom distribution framework when:

  • Defining new probability distributions not available in PyTorch.
  • Wrapping an arbitrary log-density function as a distribution for use in Pyro models.
  • Composing distributions with learnable transforms for normalizing flows.
  • Controlling batch vs. event shape semantics in complex models.
  • Ensuring custom distributions work correctly with Pyro's inference algorithms.

Theoretical Basis

A probability distribution in this framework is characterized by:

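# Distribution interface (a simplified sketch of torch.distributions.Distribution)
import torch

class Distribution:
    batch_shape: torch.Size   # shape of independent (not necessarily identical) draws
    event_shape: torch.Size   # shape of a single event
    has_rsample: bool         # True if rsample() supports reparameterization gradients,
                              # i.e., sample = g(epsilon, params) where epsilon ~ base_dist

    def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
        # Returns a tensor of shape sample_shape + batch_shape + event_shape
        # (no gradients flow back to the parameters)
        ...

    def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
        # Same shape as sample(), but differentiable w.r.t. the parameters
        ...

    def log_prob(self, value) -> torch.Tensor:
        # value: shape sample_shape + batch_shape + event_shape
        # Returns: shape sample_shape + batch_shape (log-density summed over event dims)
        ...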

The shape semantics are critical:

# Shape contract
# Given:
#   sample_shape = (S,)    # number of MC samples
#   batch_shape  = (B,)    # independent repetitions
#   event_shape  = (E,)    # dimensions of one draw

# Then:
#   x = dist.sample((S,))  has shape (S, B, E)
#   dist.log_prob(x)       has shape (S, B)
#   (log_prob sums over event_shape; sample_shape and batch_shape are preserved)

Reparameterization enables pathwise gradient estimation:

# Reparameterization trick
# Instead of: z ~ q(z|params)
# Write:      epsilon ~ base_distribution
#             z = transform(epsilon, params)
# Then: grad_params E_{z ~ q}[f(z)] = E_{epsilon ~ base}[grad_params f(transform(epsilon, params))]

Constraint transforms map between constrained and unconstrained spaces:

# Constraint: x > 0  (positive reals)
# Transform: x = exp(y) where y is unconstrained
# Jacobian correction: log_prob_x(x) = log_prob_y(log(x)) - log(x)
