Implementation:Pyro ppl Pyro Distribution Base
| Knowledge Sources | |
|---|---|
| Domains | Probability_Distributions |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Description
The distribution module defines the foundational base class Distribution and its metaclass DistributionMeta for all parameterized probability distributions in Pyro. This module establishes the core interface contract that every Pyro distribution must satisfy.
DistributionMeta is a custom metaclass (extending ABCMeta) that performs two key functions:
- It introspects each distribution class at definition time to extract and store its `__init__` signature as `__signature__`, enabling cleaner inspection and documentation.
- It intercepts instance construction (`__call__`) and passes the arguments through COERCIONS, a module-level list of callable coercion functions. Each coercion function can inspect the class, positional arguments, and keyword arguments, and can optionally return a modified distribution instance. If no coercion applies, standard construction proceeds. This mechanism lets Pyro's inference machinery transparently substitute or wrap distributions.
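Both metaclass duties above can be illustrated with a small dependency-free sketch. It is a simplified re-implementation written for this page, not Pyro's actual source. `SketchMeta`, `SketchDistribution`, `PointMass`, and the coercion `clamp_negative` are hypothetical names. The call convention `coerce_(cls, args, kwargs)` (positional arguments as a tuple, keyword arguments as a dict) is an assumption based on the description above.

```python
import inspect
from abc import ABCMeta, abstractmethod

# Module-level registry of coercion functions (a local copy mirroring the name used by Pyro).
COERCIONS = []


class SketchMeta(ABCMeta):
    """Simplified, hypothetical stand-in for DistributionMeta."""

    def __init__(cls, name, bases, namespace):
        super().__init__(name, bases, namespace)
        # Duty 1: store the constructor signature (minus `self`) as __signature__,
        # so inspect.signature(cls) reports it instead of the metaclass's (*args, **kwargs).
        if "__init__" in namespace:
            sig = inspect.signature(namespace["__init__"])
            params = list(sig.parameters.values())[1:]
            cls.__signature__ = sig.replace(parameters=params)

    def __call__(cls, *args, **kwargs):
        # Duty 2: each coercion sees (cls, args, kwargs) and may return a
        # substitute instance; the first non-None result wins.
        for coerce_ in COERCIONS:
            result = coerce_(cls, args, kwargs)
            if result is not None:
                return result
        # No coercion applied: fall back to standard construction.
        return super().__call__(*args, **kwargs)


class SketchDistribution(metaclass=SketchMeta):
    @abstractmethod
    def sample(self): ...

    @abstractmethod
    def log_prob(self, x): ...


class PointMass(SketchDistribution):  # hypothetical example distribution
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def log_prob(self, x):
        return 0.0 if x == self.value else float("-inf")


def clamp_negative(cls, args, kwargs):  # hypothetical coercion
    """Substitute PointMass(0.0) whenever PointMass is built with a negative value."""
    if cls is PointMass and args and args[0] < 0:
        return cls(0.0)  # re-enters __call__; 0.0 is not negative, so this terminates
    return None


COERCIONS.append(clamp_negative)
print(PointMass(-3.0).value)         # 0.0
print(PointMass(2.0).value)          # 2.0
print(inspect.signature(PointMass))  # (value)
```

The coercion returns `None` for every case it does not handle, which is what lets unrelated classes fall through to normal construction.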
Distribution is the abstract base class for all Pyro distributions. It defines:
- has_rsample (class attribute, default False) -- Indicates whether the distribution supports reparameterized sampling. Inference engines like SVI use this flag to choose between the reparameterization trick and the score function estimator.
- has_enumerate_support (class attribute, default False) -- Indicates whether the distribution's support can be enumerated (for discrete distributions).
- `__call__` -- An alias for `sample()` that allows distribution instances to be called directly as stochastic functions.
- sample (abstract) -- Draws a random sample from the distribution.
- log_prob (abstract) -- Evaluates the log probability density (or mass) at a given value.
- score_parts -- Computes ingredients for stochastic gradient estimators of the ELBO. For reparameterized distributions, the returned ScoreParts populates the entropy term; for non-reparameterized distributions, it populates the score function term.
- enumerate_support -- Returns a tensor of the values in the distribution's discrete support.
- conjugate_update -- An experimental method for fusing information from a compatible conjugate distribution, returning an updated distribution and a log normalizer.
- has_rsample_ -- An in-place method that forces reparameterized or detached sampling on a single instance and returns that instance.
- rv (property) -- An experimental property that wraps the distribution in a RandomVariable DSL for applying transformations via method chaining or operator overloading.
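The score_parts branching described above can be sketched in plain Python. This mirrors our understanding of the default behaviour: `log_prob` is always filled in, and exactly one of the other two fields carries it while the remaining field is zero. Plain floats stand in for tensors. `score_parts_sketch` and the local `ScoreParts` namedtuple are illustrative stand-ins, not imports from Pyro.

```python
from collections import namedtuple

# Hypothetical local stand-in with the same field names as Pyro's ScoreParts.
ScoreParts = namedtuple("ScoreParts", ["log_prob", "score_function", "entropy_term"])


def score_parts_sketch(log_prob, has_rsample):
    """Choose which ELBO gradient ingredient carries log_prob."""
    if has_rsample:
        # Reparameterized: gradients flow through the sample path, so the
        # score-function term is zero and log_prob feeds the entropy term.
        return ScoreParts(log_prob=log_prob, score_function=0, entropy_term=log_prob)
    # Non-reparameterized: log_prob feeds the score-function (REINFORCE-style)
    # term, and there is no pathwise entropy term.
    return ScoreParts(log_prob=log_prob, score_function=log_prob, entropy_term=0)


print(score_parts_sketch(-0.36, has_rsample=True))
print(score_parts_sketch(-0.36, has_rsample=False))
```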
Usage
This class serves as the root of Pyro's distribution hierarchy. All Pyro distributions (including those wrapping PyTorch distributions via TorchDistribution) ultimately inherit from or are compatible with this interface. Users extending Pyro with custom distributions should implement at minimum the sample and log_prob methods. The COERCIONS mechanism is used internally by Pyro's inference system and should not typically be modified by end users.
Code Reference
Source Location
pyro/distributions/distribution.py
Signature
```python
COERCIONS = []

class DistributionMeta(ABCMeta):
    def __init__(cls, *args, **kwargs): ...
    def __call__(cls, *args, **kwargs): ...

class Distribution(metaclass=DistributionMeta):
    has_rsample = False
    has_enumerate_support = False

    def __call__(self, *args, **kwargs): ...
    def sample(self, *args, **kwargs): ...  # abstract
    def log_prob(self, x, *args, **kwargs): ...  # abstract
    def score_parts(self, x, *args, **kwargs): ...
    def enumerate_support(self, expand=True): ...
    def conjugate_update(self, other): ...
    def has_rsample_(self, value): ...

    @property
    def rv(self): ...
```
Import
```python
from pyro.distributions.distribution import Distribution, COERCIONS
```
I/O Contract
Inputs
| Method | Parameter | Type | Description |
|---|---|---|---|
| sample | sample_shape | torch.Size | Shape of the iid batch to draw from the distribution. |
| log_prob | x | torch.Tensor | A single value or batch of values at which to evaluate log probability. |
| score_parts | x | torch.Tensor | A single value or batch of values for computing ELBO gradient ingredients. |
| enumerate_support | expand | bool | Whether to expand the result to full batch shape. Defaults to True. |
| conjugate_update | other | Distribution | A distribution representing p(data \| latent), normalized over latent. |
| has_rsample_ | value | bool | Whether to enable (True) or disable (False) reparameterized sampling. |
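To make the conjugate_update contract concrete, the worked example below fuses two Gaussian factors over the same scalar. It relies on our understanding that the returned pair satisfies `f.log_prob(x) + g.log_prob(x) == updated.log_prob(x) + log_normalizer` for every `x`. `gaussian_conjugate_update` and `normal_log_prob` are hypothetical plain-Python helpers, not Pyro functions, and the closed form used is the standard product-of-Gaussians identity.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def gaussian_conjugate_update(loc1, scale1, loc2, scale2):
    """Fuse N(x; loc1, scale1) and N(x; loc2, scale2) into a single Gaussian in x.

    Returns ((loc, scale), log_normalizer), mirroring the
    (updated distribution, log normalizer) pair.
    """
    var1, var2 = scale1 ** 2, scale2 ** 2
    var = 1.0 / (1.0 / var1 + 1.0 / var2)    # precisions add
    loc = var * (loc1 / var1 + loc2 / var2)  # precision-weighted mean
    # Normalizer of the product: N(loc1; loc2, sqrt(var1 + var2))
    log_normalizer = normal_log_prob(loc1, loc2, math.sqrt(var1 + var2))
    return (loc, math.sqrt(var)), log_normalizer


# Check the identity at a few points.
(loc, scale), log_z = gaussian_conjugate_update(1.0, 2.0, -1.0, 0.5)
for x in (-2.0, 0.0, 0.5, 3.0):
    lhs = normal_log_prob(x, 1.0, 2.0) + normal_log_prob(x, -1.0, 0.5)
    rhs = normal_log_prob(x, loc, scale) + log_z
    assert abs(lhs - rhs) < 1e-9
```

The log normalizer does not depend on `x`, which is what allows the fused factor to be handled as an ordinary normalized distribution plus a constant.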
Outputs
| Method | Return Type | Description |
|---|---|---|
| sample(sample_shape) | torch.Tensor | A random sample or batch of samples from the distribution. |
| log_prob(x) | torch.Tensor | Log probability density or mass evaluated at x. |
| score_parts(x) | ScoreParts | A named tuple with log_prob, score_function, and entropy_term fields. |
| enumerate_support(expand) | torch.Tensor | Tensor of support values with shape (n,) + batch_shape + event_shape; with expand=False, each batch dimension has size 1. |
| conjugate_update(other) | tuple(Distribution, torch.Tensor) | A pair of (updated distribution, log normalizer). |
| has_rsample_(value) | Distribution | Returns self after setting has_rsample in-place. |
| rv | RandomVariable | A RandomVariable wrapper for the distribution. |
Usage Examples
```python
import torch
import pyro.distributions as dist

# Basic usage: distributions as stochastic functions
d = dist.Bernoulli(torch.tensor(0.7))
x = d()  # equivalent to d.sample()
log_p = d.log_prob(x)  # evaluate log probability

# Score parts for gradient estimation
parts = d.score_parts(x)
print("log_prob:", parts.log_prob)
print("score_function:", parts.score_function)

# Toggling reparameterized sampling
normal = dist.Normal(0.0, 1.0)
print("has_rsample:", normal.has_rsample)  # True
normal.has_rsample_(False)
print("has_rsample:", normal.has_rsample)  # False

# Enumerate support for discrete distributions
bernoulli = dist.Bernoulli(0.3)
support = bernoulli.enumerate_support()
print("Support:", support)  # tensor([0., 1.])

# Random variable DSL (experimental)
rv = dist.Uniform(0, 1).rv
transformed = rv.log().neg()  # -log(U) for U ~ Uniform(0, 1)
exponential_like = transformed.dist  # unwrap to a distribution equivalent to Exponential(1)
```
Related Pages
- Pyro_ppl_Pyro_Constraints -- Constraint objects used for parameter validation in distributions
- Pyro_ppl_Pyro_ConditionalDistribution -- Conditional distribution framework built on the base distribution interface
- Pyro_ppl_Pyro_ConjugateDistributions -- Conjugate distributions that implement the conjugate_update protocol
- Pyro_ppl_Pyro_FoldedDistribution -- A transformed distribution that extends the base interface