
Implementation:Pyro ppl Pyro Distribution Base

From Leeroopedia


Knowledge Sources
Domains: Probability_Distributions
Last Updated: 2026-02-09 09:00 GMT

Overview

Description

The distribution module defines the foundational base class Distribution and its metaclass DistributionMeta for all parameterized probability distributions in Pyro. This module establishes the core interface contract that every Pyro distribution must satisfy.

DistributionMeta is a custom metaclass (extending ABCMeta) that performs two key functions:

  1. It introspects each distribution class at definition time and stores the signature of its __init__ as the class attribute __signature__. As a result, inspect.signature() and documentation tools show the constructor's parameters rather than the metaclass's generic (*args, **kwargs).
  2. It intercepts instance construction (__call__) and runs each function in COERCIONS, a module-level list of coercion callables, in order. Each coercion receives the class, the tuple of positional arguments, and the dict of keyword arguments, and may return a substitute object. The first non-None result is returned in place of a newly constructed instance. If no coercion applies, standard construction proceeds. This mechanism lets Pyro's inference machinery transparently substitute or wrap distributions.
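
Both metaclass duties can be illustrated with a small, self-contained sketch. This is a toy re-implementation written for this page, not Pyro's source. ToyMeta, ToyDist, ToyNormal, ToyDelta, TOY_COERCIONS and the zero_scale_to_delta coercion are all hypothetical names. Turning a zero-scale Normal into a point mass is only an example of what a coercion might do.

```python
import inspect
from abc import ABCMeta

# Hypothetical stand-in for Pyro's module-level COERCIONS list.
TOY_COERCIONS = []


class ToyMeta(ABCMeta):
    def __init__(cls, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Store the __init__ signature (without `self`) on the class, so that
        # inspect.signature(cls) shows the constructor's real parameters.
        sig = inspect.signature(cls.__init__)
        params = list(sig.parameters.values())[1:]
        cls.__signature__ = sig.replace(parameters=params)

    def __call__(cls, *args, **kwargs):
        # Offer construction to each coercion in order; the first non-None
        # result replaces the instance that would otherwise be built.
        for coerce in TOY_COERCIONS:
            result = coerce(cls, args, kwargs)
            if result is not None:
                return result
        # No coercion applied: fall back to ordinary construction.
        return super().__call__(*args, **kwargs)


class ToyDist(metaclass=ToyMeta):
    def __init__(self):
        pass


class ToyNormal(ToyDist):
    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale


class ToyDelta(ToyDist):
    def __init__(self, v):
        self.v = v


def zero_scale_to_delta(cls, args, kwargs):
    """Hypothetical coercion: a ToyNormal with scale 0 becomes a point mass."""
    if cls is ToyNormal:
        bound = cls.__signature__.bind(*args, **kwargs)
        if bound.arguments["scale"] == 0:
            return ToyDelta(bound.arguments["loc"])
    return None  # decline, so normal construction proceeds


TOY_COERCIONS.append(zero_scale_to_delta)

print(inspect.signature(ToyNormal))            # (loc, scale)
print(type(ToyNormal(0.0, 1.0)).__name__)      # ToyNormal
print(type(ToyNormal(2.0, scale=0)).__name__)  # ToyDelta
```

Without the __signature__ assignment, inspect.signature(ToyNormal) would typically report the metaclass's (*args, **kwargs). This happens because inspect consults a custom metaclass __call__ before the class's own __init__.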

Distribution is the abstract base class for all Pyro distributions. It defines:

  • has_rsample (class attribute, default False) -- Indicates whether the distribution supports reparameterized sampling. Inference engines like SVI use this flag to choose between the reparameterization trick and the score function estimator.
  • has_enumerate_support (class attribute, default False) -- Indicates whether the distribution's support can be enumerated (for discrete distributions).
  • __call__ -- An alias for sample() that allows distribution instances to be called directly as stochastic functions.
  • sample (abstract) -- Draws a random sample from the distribution.
  • log_prob (abstract) -- Evaluates the log probability density (or mass) at a given value.
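  • score_parts -- Computes ingredients for stochastic gradient estimators of the ELBO. It returns a ScoreParts whose log_prob field is always log_prob(x). For reparameterized distributions (has_rsample is True), entropy_term is also log_prob(x) and score_function is 0. For non-reparameterized distributions, score_function is log_prob(x) and entropy_term is 0.
  • enumerate_support -- Returns a tensor of the distribution's discrete support values. The base implementation raises NotImplementedError; discrete distributions with has_enumerate_support = True override it.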
  • conjugate_update -- An experimental method for fusing information from a compatible conjugate distribution, returning an updated distribution and a log normalizer.
  • has_rsample_ -- An in-place method to force reparameterized or detached sampling on a single instance.
  • rv (property) -- An experimental property that wraps the distribution in a RandomVariable DSL for applying transformations via method chaining or operator overloading.
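
The interface contract in this list can be mirrored in plain Python, with no PyTorch tensors or autograd. This is a sketch under stated assumptions, not Pyro's code. ToyDistribution, ToyExponential and the local ScoreParts named tuple are hypothetical stand-ins. The score_parts branching follows the field assignment described above. The subclass implements only sample and log_prob, which is the minimum a custom distribution must provide.

```python
import math
import random
from abc import ABC, abstractmethod
from collections import namedtuple

# Local stand-in with the same three fields as Pyro's ScoreParts.
ScoreParts = namedtuple("ScoreParts", ["log_prob", "score_function", "entropy_term"])


class ToyDistribution(ABC):
    has_rsample = False
    has_enumerate_support = False

    def __call__(self, *args, **kwargs):
        # Calling an instance is an alias for sample().
        return self.sample(*args, **kwargs)

    @abstractmethod
    def sample(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def log_prob(self, x, *args, **kwargs):
        raise NotImplementedError

    def score_parts(self, x, *args, **kwargs):
        lp = self.log_prob(x, *args, **kwargs)
        if self.has_rsample:
            # Pathwise (reparameterized) gradient: no score-function term.
            return ScoreParts(log_prob=lp, score_function=0, entropy_term=lp)
        # Non-reparameterized: the score-function term carries the gradient.
        return ScoreParts(log_prob=lp, score_function=lp, entropy_term=0)

    def has_rsample_(self, value):
        if value is not True and value is not False:
            raise ValueError(f"expected True or False, got {value!r}")
        self.has_rsample = value  # instance attribute shadows the class default
        return self


class ToyExponential(ToyDistribution):
    # Inverse-CDF sampling is a deterministic transform of uniform noise,
    # which is what makes a sampler reparameterizable.
    has_rsample = True

    def __init__(self, rate):
        self.rate = rate

    def sample(self):
        u = random.random()                  # u in [0, 1)
        return -math.log1p(-u) / self.rate   # inverse CDF of Exponential(rate)

    def log_prob(self, x):
        return math.log(self.rate) - self.rate * x


e = ToyExponential(2.0)
print(e())                                                  # a random draw, same as e.sample()
print(e.score_parts(1.0).score_function)                    # 0
print(e.has_rsample_(False).score_parts(1.0).entropy_term)  # 0
```

Because sample and log_prob are abstract, instantiating ToyDistribution directly raises TypeError. Calling has_rsample_ changes only the single instance it is called on, not the class.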

Usage

This class serves as the root of Pyro's distribution hierarchy. All Pyro distributions (including those wrapping PyTorch distributions via TorchDistribution) ultimately inherit from or are compatible with this interface. Users extending Pyro with custom distributions should implement at minimum the sample and log_prob methods. The COERCIONS mechanism is used internally by Pyro's inference system and should not typically be modified by end users.

Code Reference

Source Location

pyro/distributions/distribution.py

Signature

from abc import ABCMeta, abstractmethod

COERCIONS = []

class DistributionMeta(ABCMeta):
    def __init__(cls, *args, **kwargs): ...
    def __call__(cls, *args, **kwargs): ...

class Distribution(metaclass=DistributionMeta):
    has_rsample = False
    has_enumerate_support = False

    def __call__(self, *args, **kwargs): ...
    @abstractmethod
    def sample(self, *args, **kwargs): ...
    @abstractmethod
    def log_prob(self, x, *args, **kwargs): ...
    def score_parts(self, x, *args, **kwargs): ...
    def enumerate_support(self, expand=True): ...
    def conjugate_update(self, other): ...
    def has_rsample_(self, value): ...
    @property
    def rv(self): ...

Import

from pyro.distributions.distribution import Distribution, COERCIONS

I/O Contract

Inputs

Method | Parameter | Type | Description
sample | sample_shape | torch.Size | Shape of the iid batch to draw from the distribution.
log_prob | x | torch.Tensor | A single value or batch of values at which to evaluate the log probability.
score_parts | x | torch.Tensor | A single value or batch of values for computing ELBO gradient ingredients.
enumerate_support | expand | bool | Whether to expand the result to the full batch shape. Defaults to True.
conjugate_update | other | Distribution | A distribution representing the likelihood of data given latent, but normalized over latent rather than over data.
has_rsample_ | value | bool | Whether to enable (True) or disable (False) reparameterized sampling.

Outputs

Method | Return Type | Description
sample(sample_shape) | torch.Tensor | A random sample or batch of samples from the distribution.
log_prob(x) | torch.Tensor | Log probability density or mass evaluated at x.
score_parts(x) | ScoreParts | A named tuple with log_prob, score_function, and entropy_term fields.
enumerate_support(expand) | torch.Tensor | Support values with shape (n,) + batch_shape + event_shape when expand=True; with expand=False the batch dimensions have size 1 and broadcast to the full shape.
conjugate_update(other) | tuple(Distribution, torch.Tensor) | A pair (updated distribution, log normalizer).
has_rsample_(value) | Distribution | Returns self after setting has_rsample in place.
rv | RandomVariable | A RandomVariable wrapper for the distribution.
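
The (updated distribution, log normalizer) pair returned by conjugate_update is meant to preserve the summed log density. For every x, f.log_prob(x) + g.log_prob(x) should equal fg.log_prob(x) + log_normalizer. The sketch below checks that identity for Beta densities, a standard conjugate family, in plain Python with math.lgamma. It is not Pyro's implementation: the helper names log_beta_fn, beta_log_prob and beta_conjugate_update are hypothetical, and representing each distribution as an (a, b) tuple is a simplification.

```python
import math


def log_beta_fn(a, b):
    # log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b)
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def beta_log_prob(a, b, x):
    # Log density of Beta(a, b) at x in the open interval (0, 1).
    return (a - 1) * math.log(x) + (b - 1) * math.log1p(-x) - log_beta_fn(a, b)


def beta_conjugate_update(f, g):
    """Fuse Beta(a1, b1) and Beta(a2, b2); return (updated_params, log_normalizer).

    Multiplying the two densities gives x**(a1+a2-2) * (1-x)**(b1+b2-2), which is
    an unnormalized Beta(a1+a2-1, b1+b2-1); both new parameters must be > 0.
    """
    (a1, b1), (a2, b2) = f, g
    updated = (a1 + a2 - 1, b1 + b2 - 1)
    if min(updated) <= 0:
        raise ValueError("fused parameters must be positive")
    log_normalizer = log_beta_fn(*updated) - log_beta_fn(a1, b1) - log_beta_fn(a2, b2)
    return updated, log_normalizer


f, g = (2.0, 3.0), (4.0, 1.5)
fg, log_norm = beta_conjugate_update(f, g)
for x in (0.1, 0.5, 0.9):
    lhs = beta_log_prob(*f, x) + beta_log_prob(*g, x)
    rhs = beta_log_prob(*fg, x) + log_norm
    print(x, math.isclose(lhs, rhs, abs_tol=1e-9))  # True for each x
```

The log normalizer is independent of x. That is why a single scalar (or batch tensor, in Pyro) is enough to restore the lost normalization.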

Usage Examples

import torch
import pyro.distributions as dist

# Basic usage: distributions as stochastic functions
d = dist.Bernoulli(torch.tensor(0.7))
x = d()                      # equivalent to d.sample()
log_p = d.log_prob(x)        # evaluate log probability

# Score parts for gradient estimation
parts = d.score_parts(x)
print("log_prob:", parts.log_prob)
print("score_function:", parts.score_function)

# Toggling reparameterized sampling
normal = dist.Normal(0.0, 1.0)
print("has_rsample:", normal.has_rsample)  # True
normal.has_rsample_(False)
print("has_rsample:", normal.has_rsample)  # False

# Enumerate support for discrete distributions
bernoulli = dist.Bernoulli(0.3)
support = bernoulli.enumerate_support()
print("Support:", support)  # tensor([0., 1.])

# Random variable DSL (experimental)
rv = dist.Uniform(0, 1).rv
exp_rv = rv.log().neg()   # -log(U) is distributed as Exponential(1)
exp_dist = exp_rv.dist    # the underlying (transformed) distribution object
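
Why is -log(U) an Exponential? For t >= 0, P(-log U > t) = P(U < e^(-t)) = e^(-t), which is the Exponential(1) survival function. The plain-Python check below uses only the standard library, not Pyro. It compares a seeded Monte Carlo estimate against that distribution; the seed and sample size are arbitrary choices.

```python
import math
import random

rng = random.Random(0)  # fixed seed so the run is reproducible
n = 200_000

# 1 - rng.random() lies in (0, 1], is itself Uniform, and avoids log(0).
samples = [-math.log(1.0 - rng.random()) for _ in range(n)]

mean = sum(samples) / n
frac_le_1 = sum(s <= 1.0 for s in samples) / n

print(f"sample mean    {mean:.3f}  (Exponential(1) mean: 1)")
print(f"P(X <= 1) est. {frac_le_1:.3f}  (exact: {1 - math.exp(-1):.3f})")
```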
