Implementation:Pyro ppl Pyro Pyro Distributions

From Leeroopedia


Knowledge Sources
Domains: Bayesian_Inference, Statistics
Type: Wrapper Doc (wraps torch.distributions)
Last Updated: 2026-02-09 00:00 GMT

Overview

Pyro's distribution library provides wrappers over PyTorch's torch.distributions classes, augmented with Pyro-specific functionality for use in probabilistic programs.

Description

The pyro.distributions module is the primary interface for specifying probability distributions in Pyro models and guides. The classes documented here are not independent Pyro implementations; they are thin wrappers around torch.distributions classes that add Pyro-specific methods and behavior via the TorchDistributionMixin mixin class.

The wrapped classes provide several capabilities needed for probabilistic programming:

  • has_enumerate_support (an attribute) and enumerate_support (a method), inherited from torch.distributions and used by Pyro for discrete variable enumeration
  • Compatibility with Pyro's effect handler (messenger) system
  • Integration with Pyro's pyro.sample primitive for automatic trace recording
  • Support for the infer dictionary passed to pyro.sample for inference-time configuration

Key distribution classes available through this module:

Distribution | Parameters | Support | Typical Use
--- | --- | --- | ---
dist.Normal | loc, scale | Real line | Regression coefficients, general unconstrained parameters
dist.HalfCauchy | scale | Positive reals | Scale parameters with heavy tails
dist.HalfNormal | scale | Positive reals | Scale parameters with lighter tails
dist.StudentT | df, loc, scale | Real line | Robust regression, outlier-resistant models
dist.Dirichlet | concentration | Simplex | Mixture weights, probability vectors
dist.Bernoulli | probs or logits | {0, 1} | Binary outcomes, coin flips
dist.Categorical | probs or logits | {0, ..., K-1} | Discrete classification, mixture assignments
dist.Beta | concentration1, concentration0 | [0, 1] | Probabilities, proportions
dist.LogNormal | loc, scale | Positive reals | Multiplicative processes, income distributions
dist.Gamma | concentration, rate | Positive reals | Waiting times, precision parameters

Usage

Use pyro.distributions to specify prior distributions and likelihoods in Pyro model functions. Import the module as dist and construct distribution objects by passing the required parameters. These distribution objects are then passed to pyro.sample to declare random variables. For multivariate parameters, use .expand() to broadcast to the desired batch shape and .to_event(n) to reinterpret the rightmost n batch dimensions as event dimensions, so that log_prob sums over them and returns one joint value per event.

Code Reference

Source Location

  • Repository: pyro
  • File: pyro/distributions/torch.py (Pyro wrappers over torch.distributions); the TorchDistributionMixin class itself is defined in pyro/distributions/torch_distribution.py
  • Wraps: torch.distributions (PyTorch built-in distribution library)

Import

import pyro.distributions as dist

# Individual distributions are accessed as attributes:
# dist.Normal, dist.HalfCauchy, dist.StudentT, etc.

Key Class: TorchDistributionMixin

class TorchDistributionMixin(Distribution):
    """
    Mixin class added to all torch.distributions classes when accessed
    through pyro.distributions. Adds Pyro-specific methods including
    enumerate_support integration and effect handler compatibility.
    """

I/O Contract

Inputs

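Name | Type | Required | Description
--- | --- | --- | ---
(distribution-specific) | varies | Yes | Each distribution has its own parameters (e.g., loc/scale for Normal, concentration for Dirichlet)
validate_args | bool | No | Whether to validate input arguments (default: None, which defers to the global validation setting; Pyro enables validation by default unless Python runs in optimized mode)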

Outputs

Name | Type | Description
--- | --- | ---
return | Distribution object | A Pyro-compatible distribution that can be passed to pyro.sample
.sample() | torch.Tensor | Draw a random sample from the distribution
.log_prob(value) | torch.Tensor | Compute the log probability density/mass at the given value
.expand(batch_shape) | Distribution | Create a batch of independent distributions with the given shape
.to_event(n) | Independent | Reinterpret the rightmost n batch dimensions as event dimensions

Usage Examples

Common Priors for Bayesian Regression

import pyro
import pyro.distributions as dist

def bayesian_regression(x, y=None):
    # Normal prior on coefficients (weakly informative)
    beta = pyro.sample("beta", dist.Normal(0., 10.))

    # HalfCauchy prior on noise scale (heavy tails, positive)
    sigma = pyro.sample("sigma", dist.HalfCauchy(2.))

    # Likelihood
    mu = beta * x
    with pyro.plate("data", len(x)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=y)

Robust Regression with StudentT

import pyro
import pyro.distributions as dist

def robust_regression(x, y=None):
    beta = pyro.sample("beta", dist.Normal(0., 10.))
    sigma = pyro.sample("sigma", dist.HalfNormal(5.))
    # Prior on the degrees of freedom: Gamma(2, 0.1) has mean 20, and
    # smaller sampled values of nu give heavier tails (more robust to outliers)
    nu = pyro.sample("nu", dist.Gamma(2., 0.1))

    mu = beta * x
    with pyro.plate("data", len(x)):
        pyro.sample("obs", dist.StudentT(nu, mu, sigma), obs=y)

Mixture Model with Dirichlet Prior

import torch
import pyro
import pyro.distributions as dist

def mixture_model(data, K=3):
    # Dirichlet prior over mixture weights
    weights = pyro.sample("weights", dist.Dirichlet(torch.ones(K)))

    # Priors for component parameters
    with pyro.plate("components", K):
        locs = pyro.sample("locs", dist.Normal(0., 10.))
        scales = pyro.sample("scales", dist.HalfNormal(5.))

    with pyro.plate("data", len(data)):
        assignment = pyro.sample("assignment", dist.Categorical(weights),
                                 infer={"enumerate": "parallel"})
        pyro.sample("obs", dist.Normal(locs[assignment], scales[assignment]),
                    obs=data)

Multivariate Prior with .expand() and .to_event()

import pyro
import pyro.distributions as dist

def multivariate_model(x, y=None):
    n_features = x.shape[1]
    # Expand a scalar prior to a vector and mark as single event
    beta = pyro.sample("beta",
                       dist.Normal(0., 1.).expand([n_features]).to_event(1))
    sigma = pyro.sample("sigma", dist.HalfCauchy(1.))

    mu = x @ beta
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=y)
