Implementation: Pyro Distributions (pyro.distributions)
| Knowledge Sources | |
|---|---|
| Domains | Bayesian_Inference, Statistics |
| Type | Wrapper Doc (wraps torch.distributions) |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Pyro's distribution library providing wrappers over PyTorch's torch.distributions classes, augmented with Pyro-specific functionality for use in probabilistic programs.
Description
The pyro.distributions module is the primary interface for specifying probability distributions in Pyro models and guides. These are not independent distribution implementations -- they are thin wrappers around torch.distributions classes that add Pyro-specific methods and behavior via the TorchDistributionMixin class.
The wrapping adds several capabilities needed for probabilistic programming:
- The has_enumerate_support property and enumerate_support method for discrete variable enumeration
- Compatibility with Pyro's effect handler (messenger) system
- Integration with Pyro's pyro.sample primitive for automatic trace recording
- Support for the infer dictionary for inference-time configuration
Key distribution classes available through this module:
| Distribution | Parameters | Support | Typical Use |
|---|---|---|---|
| dist.Normal | loc, scale | Real line | Regression coefficients, general unconstrained parameters |
| dist.HalfCauchy | scale | Positive reals | Scale parameters with heavy tails |
| dist.HalfNormal | scale | Positive reals | Scale parameters with lighter tails |
| dist.StudentT | df, loc, scale | Real line | Robust regression, outlier-resistant models |
| dist.Dirichlet | concentration | Simplex | Mixture weights, probability vectors |
| dist.Bernoulli | probs / logits | {0, 1} | Binary outcomes, coin flips |
| dist.Categorical | probs / logits | {0, ..., K-1} | Discrete classification, mixture assignments |
| dist.Beta | concentration1, concentration0 | [0, 1] | Probabilities, proportions |
| dist.LogNormal | loc, scale | Positive reals | Multiplicative processes, income distributions |
| dist.Gamma | concentration, rate | Positive reals | Waiting times, precision parameters |
Usage
Use pyro.distributions to specify prior distributions and likelihoods in Pyro model functions. Import the module as dist and construct distribution objects by passing the required parameters. These distribution objects are then passed to pyro.sample to declare random variables. For multivariate parameters, use .expand() to broadcast to the desired shape and .to_event(n) to declare the rightmost n dimensions as event dimensions (not batch dimensions).
Code Reference
Source Location
- Repository: pyro
- File: pyro/distributions/torch.py (Pyro wrappers over torch.distributions)
- Wraps: torch.distributions (PyTorch built-in distribution library)
Import
import pyro.distributions as dist
# Individual distributions are accessed as attributes:
# dist.Normal, dist.HalfCauchy, dist.StudentT, etc.
Key Class: TorchDistributionMixin
class TorchDistributionMixin(Distribution):
"""
Mixin class added to all torch.distributions classes when accessed
through pyro.distributions. Adds Pyro-specific methods including
enumerate_support integration and effect handler compatibility.
"""
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| (distribution-specific) | varies | Yes | Each distribution has its own parameters (e.g., loc/scale for Normal, concentration for Dirichlet) |
| validate_args | bool | No | Whether to validate distribution parameters; defaults to the global validation setting (see pyro.enable_validation) |
Outputs
| Name | Type | Description |
|---|---|---|
| return | Distribution object | A Pyro-compatible distribution that can be passed to pyro.sample |
| .sample() | torch.Tensor | Draw a random sample from the distribution |
| .log_prob(value) | torch.Tensor | Compute the log probability density/mass at the given value |
| .expand(batch_shape) | Distribution | Create a batch of independent distributions with the given shape |
| .to_event(n) | Independent | Reinterpret the rightmost n batch dimensions as event dimensions |
Usage Examples
Common Priors for Bayesian Regression
import pyro
import pyro.distributions as dist
def bayesian_regression(x, y=None):
# Normal prior on coefficients (weakly informative)
beta = pyro.sample("beta", dist.Normal(0., 10.))
# HalfCauchy prior on noise scale (heavy tails, positive)
sigma = pyro.sample("sigma", dist.HalfCauchy(2.))
# Likelihood
mu = beta * x
with pyro.plate("data", len(x)):
pyro.sample("obs", dist.Normal(mu, sigma), obs=y)
Robust Regression with StudentT
import pyro
import pyro.distributions as dist
def robust_regression(x, y=None):
beta = pyro.sample("beta", dist.Normal(0., 10.))
sigma = pyro.sample("sigma", dist.HalfNormal(5.))
# Low degrees of freedom for heavy tails (robust to outliers)
nu = pyro.sample("nu", dist.Gamma(2., 0.1))
mu = beta * x
with pyro.plate("data", len(x)):
pyro.sample("obs", dist.StudentT(nu, mu, sigma), obs=y)
Mixture Model with Dirichlet Prior
import torch
import pyro
import pyro.distributions as dist
def mixture_model(data, K=3):
# Dirichlet prior over mixture weights
weights = pyro.sample("weights", dist.Dirichlet(torch.ones(K)))
# Priors for component parameters
with pyro.plate("components", K):
locs = pyro.sample("locs", dist.Normal(0., 10.))
scales = pyro.sample("scales", dist.HalfNormal(5.))
with pyro.plate("data", len(data)):
assignment = pyro.sample("assignment", dist.Categorical(weights),
infer={"enumerate": "parallel"})
pyro.sample("obs", dist.Normal(locs[assignment], scales[assignment]),
obs=data)
Multivariate Prior with .expand() and .to_event()
import pyro
import pyro.distributions as dist
def multivariate_model(x, y=None):
n_features = x.shape[1]
# Expand a scalar prior to a vector and mark as single event
beta = pyro.sample("beta",
dist.Normal(0., 1.).expand([n_features]).to_event(1))
sigma = pyro.sample("sigma", dist.HalfCauchy(1.))
mu = x @ beta
with pyro.plate("data", x.shape[0]):
pyro.sample("obs", dist.Normal(mu, sigma), obs=y)
External References
- PyTorch distributions documentation: torch.distributions
- Pyro distributions documentation: pyro.distributions