Implementation:Pyro ppl Pyro TorchDistribution Wrappers

From Leeroopedia


Knowledge Sources
Domains Probability_Distributions, Deep_Learning
Last Updated 2026-02-09 09:00 GMT

Overview

Pyro-compatible wrappers around standard PyTorch distributions, mixing in TorchDistributionMixin for seamless integration with the Pyro probabilistic programming framework.

Description

This module provides Pyro-enhanced versions of all standard PyTorch distributions by combining them with TorchDistributionMixin. Each wrapper class inherits from both the corresponding torch.distributions class and TorchDistributionMixin, enabling Pyro-specific features such as score_parts, to_event, expand_by, mask, and conjugate_update.

The module defines two categories of distributions:

Explicitly defined wrappers with custom enhancements:

  • Beta -- Adds conjugate_update for conjugate Bayesian updating.
  • Binomial -- Adds approximate sampling via a clamped Poisson for large total_count values (controlled by approx_sample_thresh) and an approximate log-probability based on Stirling's approximation (controlled by approx_log_prob_tol). Computes log-prob from logits for numerical stability.
  • Categorical -- Optimizes log_prob for enumerated support by reshaping logits instead of using torch.gather, which is critical for Pyro's enumeration-based inference.
  • Dirichlet -- Adds conjugate_update and custom infer_shapes.
  • Gamma -- Adds conjugate_update.
  • Geometric -- Overrides log_prob with a numerically stable softplus-based implementation.
  • LogNormal -- Uses Pyro's Normal as the base distribution instead of PyTorch's Normal.
  • LowRankMultivariateNormal and MultivariateNormal -- Add custom infer_shapes methods.
  • Multinomial and OneHotCategorical -- Add custom infer_shapes.
  • Normal -- Thin wrapper with no additional methods.
  • Poisson -- Adds an is_sparse flag that optimizes log_prob computation for sparse count data.
  • Independent -- Overrides _validate_args to delegate to the base distribution, and adds conjugate_update.
  • Uniform -- Stores unbroadcasted bounds for proper support constraints.
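
The softplus-based log-probability mentioned for Geometric can be illustrated in plain Python. This is a minimal sketch of the underlying identity, not Pyro's source: the function names are hypothetical, and it assumes the usual parameterization where the value counts failures before the first success.

```python
import math


def softplus(x: float) -> float:
    """Compute log(1 + exp(x)) without overflow for large |x|."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def geometric_log_prob_naive(k: int, p: float) -> float:
    # Direct formula: k failures followed by one success.
    return k * math.log(1.0 - p) + math.log(p)


def geometric_log_prob_logits(k: int, logits: float) -> float:
    # Uses log(p) = -softplus(-logits) and log(1 - p) = -softplus(logits).
    return -k * softplus(logits) - softplus(-logits)


p = 0.3
logits = math.log(p / (1.0 - p))
naive = geometric_log_prob_naive(4, p)
stable = geometric_log_prob_logits(4, logits)

# At logits = 40 the probability rounds to 1.0 in double precision, so the
# naive formula would evaluate log(0); the logits form stays finite.
extreme = geometric_log_prob_logits(2, 40.0)
```

Both forms agree for moderate probabilities; the logits form is the one that keeps working when the success probability saturates.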

Auto-generated wrappers: All remaining PyTorch distributions not explicitly listed are automatically wrapped at module load time using type(name, (Dist, TorchDistributionMixin), {}).
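
The type(name, (Dist, TorchDistributionMixin), {}) pattern can be sketched without PyTorch installed. In the sketch below every class is a hypothetical stand-in (MixinSketch for TorchDistributionMixin, Laplace and Gumbel for torch.distributions classes), and the loop is an assumption about the general shape of the module-load logic, not a copy of it.

```python
class MixinSketch:
    """Hypothetical stand-in for TorchDistributionMixin."""

    def describe(self) -> str:
        return f"{type(self).__name__} with mixin methods"


class Laplace:
    """Hypothetical stand-in for a torch.distributions class."""

    def __init__(self, loc: float, scale: float) -> None:
        self.loc, self.scale = loc, scale


class Gumbel:
    """Hypothetical stand-in for a torch.distributions class."""

    def __init__(self, loc: float, scale: float) -> None:
        self.loc, self.scale = loc, scale


# Names that already have a hand-written wrapper are skipped.
explicitly_wrapped = {"Normal"}
base_classes = {"Laplace": Laplace, "Gumbel": Gumbel}

wrappers = {}
for name, Dist in base_classes.items():
    if name in explicitly_wrapped:
        continue
    # Three-argument type() builds a new class with the given name and bases.
    wrappers[name] = type(name, (Dist, MixinSketch), {})

wrapped = wrappers["Laplace"](0.0, 1.0)
```

Because the base distribution comes first in the bases tuple, its constructor and methods take precedence and the mixin only contributes the extra methods.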

The module also registers Pyro settings for binomial_approx_sample_thresh and binomial_approx_log_prob_tol.

Usage

These wrappers are the standard way to use PyTorch distributions within Pyro models. Import them from pyro.distributions instead of torch.distributions to get full Pyro compatibility including enumeration support, masking, conjugate updates, and proper shape inference.

Code Reference

Source Location

Signature

class Beta(torch.distributions.Beta, TorchDistributionMixin): ...
class Binomial(torch.distributions.Binomial, TorchDistributionMixin): ...
class Categorical(torch.distributions.Categorical, TorchDistributionMixin): ...
class Dirichlet(torch.distributions.Dirichlet, TorchDistributionMixin): ...
class Gamma(torch.distributions.Gamma, TorchDistributionMixin): ...
class Geometric(torch.distributions.Geometric, TorchDistributionMixin): ...
class LogNormal(torch.distributions.LogNormal, TorchDistributionMixin): ...
class LowRankMultivariateNormal(torch.distributions.LowRankMultivariateNormal, TorchDistributionMixin): ...
class MultivariateNormal(torch.distributions.MultivariateNormal, TorchDistributionMixin): ...
class Multinomial(torch.distributions.Multinomial, TorchDistributionMixin): ...
class Normal(torch.distributions.Normal, TorchDistributionMixin): ...
class OneHotCategorical(torch.distributions.OneHotCategorical, TorchDistributionMixin): ...
class Poisson(torch.distributions.Poisson, TorchDistributionMixin): ...
class Independent(torch.distributions.Independent, TorchDistributionMixin): ...
class Uniform(torch.distributions.Uniform, TorchDistributionMixin): ...

Import

from pyro.distributions import Normal, Beta, Categorical, Poisson
# All standard PyTorch distributions are available
from pyro.distributions import (
    Bernoulli, Binomial, Cauchy, Chi2, ContinuousBernoulli, Dirichlet,
    Exponential, FisherSnedecor, Gamma, Geometric, Gumbel, HalfCauchy,
    HalfNormal, Independent, Kumaraswamy, Laplace, LKJCholesky, LogNormal,
    LogisticNormal, LowRankMultivariateNormal, MixtureSameFamily,
    Multinomial, MultivariateNormal, NegativeBinomial, OneHotCategorical,
    Pareto, RelaxedBernoulli, RelaxedOneHotCategorical, StudentT,
    TransformedDistribution, Uniform, VonMises, Weibull, Wishart,
)

I/O Contract

Binomial Special Inputs

Name | Type | Required | Description
total_count | int or torch.Tensor | Yes | Number of Bernoulli trials.
probs | torch.Tensor | No | Probability of success (mutually exclusive with logits).
logits | torch.Tensor | No | Log-odds of success (mutually exclusive with probs).
approx_sample_thresh | float | No | Class attribute (default math.inf). Threshold on total_count above which sampling uses a clamped Poisson approximation.
approx_log_prob_tol | float | No | Class attribute (default 0.0). Positive values trigger Stirling's approximation for log-prob.
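
To give a feel for what Stirling's approximation trades away, the sketch below compares an exact log binomial coefficient (via lgamma) with one built from the classic Stirling formula. It is a plain-Python illustration under that assumption; the exact variant Pyro uses and how approx_log_prob_tol selects it are not reproduced here, and all function names are hypothetical.

```python
import math


def log_factorial_exact(n: float) -> float:
    # log(n!) via the log-gamma function.
    return math.lgamma(n + 1.0)


def log_factorial_stirling(n: float) -> float:
    # Stirling: log(n!) ~ n*log(n) - n + 0.5*log(2*pi*n), valid for n > 0.
    return n * math.log(n) - n + 0.5 * math.log(2.0 * math.pi * n)


def log_binomial_coefficient(n: float, k: float, log_factorial) -> float:
    # log C(n, k) = log n! - log k! - log (n - k)!
    return log_factorial(n) - log_factorial(k) - log_factorial(n - k)


n, k = 1_000_000.0, 1_000.0
exact = log_binomial_coefficient(n, k, log_factorial_exact)
approx = log_binomial_coefficient(n, k, log_factorial_stirling)
error = abs(exact - approx)
```

The per-factorial error of this formula shrinks roughly like 1/(12n), so the approximation is most useful when all of n, k, and n - k are large.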

Categorical Special Outputs

Name | Type | Description
log_prob | torch.Tensor | Optimized log probability that avoids torch.gather when value is from enumerate_support.
enumerate_support | torch.Tensor | Support values tagged with _pyro_categorical_support for optimization.
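
The reason a reshape can replace a gather for enumerated values is easiest to see on a small case. The sketch below uses nested Python lists instead of tensors and is only an illustration of the idea: when the value is the full support 0..K-1 broadcast over the batch, the table of log-probabilities is just the logits with the category axis moved to the front.

```python
import math


def log_softmax(row):
    # Normalize a row of scores into log-probabilities.
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


# A batch of 2 distributions over 3 categories, indexed as logits[b][k].
logits = [log_softmax(r) for r in [[0.1, 1.2, -0.5], [2.0, 0.0, 0.3]]]
num_batch, num_categories = 2, 3

# Gather-style: look up logits[b][k] for every enumerated value k and batch b.
gathered = [
    [logits[b][k] for b in range(num_batch)] for k in range(num_categories)
]

# Reshape-style: moving the category axis to the front gives the same table
# without any per-element index lookup.
transposed = [list(column) for column in zip(*logits)]
```

The shortcut only applies when the value is known to be the enumerated support, which is why the support tensor is tagged.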

Poisson Special Inputs

Name | Type | Required | Description
rate | torch.Tensor | Yes | Rate parameter (lambda).
is_sparse | bool | No | If True, optimizes log_prob for sparse count data.
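
A plain-Python sketch of why sparsity helps: for a zero count the Poisson log-probability reduces to -rate, so the logarithm and log-gamma terms only need to be evaluated at nonzero entries. The function names are hypothetical and the code illustrates the idea rather than Pyro's tensor implementation.

```python
import math


def poisson_log_prob_dense(counts, rate):
    # log p(k) = k*log(rate) - rate - log(k!), evaluated for every entry.
    return [k * math.log(rate) - rate - math.lgamma(k + 1.0) for k in counts]


def poisson_log_prob_sparse(counts, rate):
    # Start from the k == 0 value, then correct only the nonzero entries.
    out = [-rate] * len(counts)
    log_rate = math.log(rate)
    for i, k in enumerate(counts):
        if k > 0:
            out[i] += k * log_rate - math.lgamma(k + 1.0)
    return out


counts = [0.0] * 1000
counts[42] = 3.0
dense = poisson_log_prob_dense(counts, 0.01)
sparse = poisson_log_prob_sparse(counts, 0.01)
```

With one nonzero entry out of a thousand, the sparse version evaluates the expensive terms once instead of a thousand times.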

Conjugate Update (Beta, Dirichlet, Gamma) Outputs

Name | Type | Description
updated | Distribution | Posterior distribution after conjugate update.
log_normalizer | torch.Tensor | Log normalizing constant of the update.
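
The two outputs can be checked numerically for the Beta case. The sketch below assumes the update multiplies two Beta densities over the same variable, so that the sum of the two log-densities equals the updated log-density plus log_normalizer; it is written in plain Python with hypothetical helper names and is not Pyro's implementation.

```python
import math


def beta_log_normalizer(a: float, b: float) -> float:
    # Log of the Beta density's normalizing factor, 1 / B(a, b).
    return math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)


def beta_log_prob(x: float, a: float, b: float) -> float:
    return (
        (a - 1.0) * math.log(x)
        + (b - 1.0) * math.log1p(-x)
        + beta_log_normalizer(a, b)
    )


def beta_conjugate_update(prior, other):
    # Multiplying Beta(a1, b1) by Beta(a2, b2) gives an unnormalized
    # Beta(a1 + a2 - 1, b1 + b2 - 1); log_normalizer records the scale.
    (a1, b1), (a2, b2) = prior, other
    updated = (a1 + a2 - 1.0, b1 + b2 - 1.0)
    log_normalizer = (
        beta_log_normalizer(a1, b1)
        + beta_log_normalizer(a2, b2)
        - beta_log_normalizer(*updated)
    )
    return updated, log_normalizer


updated, log_norm = beta_conjugate_update((2.0, 5.0), (10.0, 3.0))

# Check the identity at an arbitrary point in (0, 1).
x = 0.4
lhs = beta_log_prob(x, 2.0, 5.0) + beta_log_prob(x, 10.0, 3.0)
rhs = beta_log_prob(x, *updated) + log_norm
```

For these parameters the updated concentrations are (11.0, 7.0), and the identity holds at any x strictly between 0 and 1.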

Usage Examples

Basic Distribution Usage

import torch
import pyro
import pyro.distributions as dist

# Normal distribution with Pyro features
normal = dist.Normal(torch.tensor(0.0), torch.tensor(1.0))
x = normal()  # calls rsample (reparameterized)

# Convert to event dimensions
mvn_like = dist.Normal(torch.zeros(3), torch.ones(3)).to_event(1)
print(mvn_like.batch_shape, mvn_like.event_shape)  # torch.Size([]) torch.Size([3])

Conjugate Bayesian Update

import pyro.distributions as dist

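# Conjugate update: combine a Beta prior with a Beta-shaped likelihood term
prior = dist.Beta(2.0, 5.0)
likelihood = dist.Beta(10.0, 3.0)  # pseudo-observations expressed as a Beta
# Returns the updated distribution and the log normalizing constant
posterior, log_norm = prior.conjugate_update(likelihood)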

Efficient Binomial for Large Populations

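import torch
import pyro.distributions as dist

# Use the clamped Poisson approximation when total_count exceeds the threshold.
# This sets a class attribute, so it affects every Binomial instance.
dist.Binomial.approx_sample_thresh = 1e6
binomial = dist.Binomial(total_count=1e8, probs=0.001)
samples = binomial.sample(torch.Size([1000]))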

Sparse Poisson Log-Prob

import torch
import pyro.distributions as dist

# Efficient log_prob for sparse count data
poisson = dist.Poisson(rate=torch.tensor(0.01), is_sparse=True)
counts = torch.zeros(1000)
counts[42] = 3.0
log_p = poisson.log_prob(counts)
