Implementation:Pyro ppl Pyro TorchDistribution Wrappers
| Knowledge Sources | |
|---|---|
| Domains | Probability_Distributions, Deep_Learning |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Pyro-compatible wrappers around standard PyTorch distributions, mixing in TorchDistributionMixin for seamless integration with the Pyro probabilistic programming framework.
Description
This module provides Pyro-enhanced versions of all standard PyTorch distributions by combining them with TorchDistributionMixin. Each wrapper class inherits from both the corresponding torch.distributions class and TorchDistributionMixin, enabling Pyro-specific features such as score_parts, to_event, expand_by, mask, and conjugate_update.
The module defines two categories of distributions:
Explicitly defined wrappers with custom enhancements:
- Beta -- Adds conjugate_update for conjugate Bayesian updating.
- Binomial -- Adds approximate sampling via a clamped Poisson for large total_count values (controlled by approx_sample_thresh) and approximate log-probability using Stirling's approximation (controlled by approx_log_prob_tol). Uses a numerically stable log-prob computed from logits.
- Categorical -- Optimizes log_prob for enumerated support by reshaping logits instead of using torch.gather, which is critical for Pyro's enumeration-based inference.
- Dirichlet -- Adds conjugate_update and a custom infer_shapes.
- Gamma -- Adds conjugate_update.
- Geometric -- Overrides log_prob with a numerically stable softplus-based implementation.
- LogNormal -- Uses Pyro's Normal as the base distribution instead of PyTorch's Normal.
- LowRankMultivariateNormal and MultivariateNormal -- Add custom infer_shapes methods.
- Multinomial and OneHotCategorical -- Add custom infer_shapes.
- Normal -- Thin wrapper with no additional methods.
- Poisson -- Adds an is_sparse flag that optimizes log_prob computation for sparse count data.
- Independent -- Overrides _validate_args to delegate to the base distribution, and adds conjugate_update.
- Uniform -- Stores unbroadcasted bounds for proper support constraints.
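As an illustration of what a softplus-based Geometric log-prob can look like, here is a minimal pure-Python sketch. It assumes the usual convention that the Geometric distribution counts failures k = 0, 1, 2, ... before the first success; the function names are hypothetical and the code is not copied from Pyro.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def geometric_log_prob(k: int, logits: float) -> float:
    """log P(K = k) written in terms of logits = log(p / (1 - p)).

    Uses log(1 - p) = -softplus(logits) and log(p) = logits - softplus(logits),
    so no log of a probability near 0 or 1 is ever taken.
    """
    return (-k - 1) * softplus(logits) + logits


def naive_geometric_log_prob(k: int, p: float) -> float:
    """Textbook form; breaks down when p rounds to 0.0 or 1.0."""
    return k * math.log(1.0 - p) + math.log(p)


p = 0.3
logits = math.log(p / (1.0 - p))
stable = geometric_log_prob(4, logits)
naive = naive_geometric_log_prob(4, p)
```

For moderate probabilities the two forms agree; for extreme logits (for example 800.0) the softplus form still returns a finite value, while the naive form would need log(0).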
Auto-generated wrappers: All remaining PyTorch distributions not explicitly listed are automatically wrapped at module load time using type(name, (Dist, TorchDistributionMixin), {}).
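The type(name, bases, namespace) pattern mentioned above can be sketched with stand-in classes. Everything in this block (MixinSketch, the two base classes, the registry dictionaries) is hypothetical and only mimics the shape of the real module; it does not import PyTorch or Pyro.

```python
class MixinSketch:
    """Hypothetical stand-in for TorchDistributionMixin."""

    def describe(self) -> str:
        return f"{type(self).__name__} with mixin features"


class Laplace:
    """Stand-in for a base distribution that already has a hand-written wrapper."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale


class Gumbel:
    """Stand-in for a base distribution without a hand-written wrapper."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale


base_classes = {"Laplace": Laplace, "Gumbel": Gumbel}
explicitly_wrapped = {"Laplace"}  # pretend this one was defined by hand

wrappers = {}
for name, base in base_classes.items():
    if name in explicitly_wrapped:
        continue  # keep the hand-written wrapper
    # Build a new class named `name` that inherits from both the base and the mixin.
    wrappers[name] = type(name, (base, MixinSketch), {})

g = wrappers["Gumbel"](0.0, 1.0)
```

The generated class keeps the base constructor and gains the mixin's methods, which is why an empty namespace dictionary is enough.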
The module also registers Pyro settings for binomial_approx_sample_thresh and binomial_approx_log_prob_tol.
Usage
These wrappers are the standard way to use PyTorch distributions within Pyro models. Import them from pyro.distributions instead of torch.distributions to get full Pyro compatibility including enumeration support, masking, conjugate updates, and proper shape inference.
Code Reference
Source Location
- Repository: Pyro
- File: pyro/distributions/torch.py
Signature
class Beta(torch.distributions.Beta, TorchDistributionMixin): ...
class Binomial(torch.distributions.Binomial, TorchDistributionMixin): ...
class Categorical(torch.distributions.Categorical, TorchDistributionMixin): ...
class Dirichlet(torch.distributions.Dirichlet, TorchDistributionMixin): ...
class Gamma(torch.distributions.Gamma, TorchDistributionMixin): ...
class Geometric(torch.distributions.Geometric, TorchDistributionMixin): ...
class LogNormal(torch.distributions.LogNormal, TorchDistributionMixin): ...
class LowRankMultivariateNormal(torch.distributions.LowRankMultivariateNormal, TorchDistributionMixin): ...
class MultivariateNormal(torch.distributions.MultivariateNormal, TorchDistributionMixin): ...
class Multinomial(torch.distributions.Multinomial, TorchDistributionMixin): ...
class Normal(torch.distributions.Normal, TorchDistributionMixin): ...
class OneHotCategorical(torch.distributions.OneHotCategorical, TorchDistributionMixin): ...
class Poisson(torch.distributions.Poisson, TorchDistributionMixin): ...
class Independent(torch.distributions.Independent, TorchDistributionMixin): ...
class Uniform(torch.distributions.Uniform, TorchDistributionMixin): ...
Import
from pyro.distributions import Normal, Beta, Categorical, Poisson
# All standard PyTorch distributions are available
from pyro.distributions import (
Bernoulli, Binomial, Cauchy, Chi2, ContinuousBernoulli, Dirichlet,
Exponential, FisherSnedecor, Gamma, Geometric, Gumbel, HalfCauchy,
HalfNormal, Independent, Kumaraswamy, Laplace, LKJCholesky, LogNormal,
LogisticNormal, LowRankMultivariateNormal, MixtureSameFamily,
Multinomial, MultivariateNormal, NegativeBinomial, OneHotCategorical,
Pareto, RelaxedBernoulli, RelaxedOneHotCategorical, StudentT,
TransformedDistribution, Uniform, VonMises, Weibull, Wishart,
)
I/O Contract
Binomial Special Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| total_count | int or torch.Tensor | Yes | Number of Bernoulli trials. |
| probs | torch.Tensor | No | Probability of success (mutually exclusive with logits). |
| logits | torch.Tensor | No | Log-odds of success (mutually exclusive with probs). |
| approx_sample_thresh | float | No | Class attribute (default math.inf). Threshold on total_count above which sampling uses a clamped Poisson approximation. |
| approx_log_prob_tol | float | No | Class attribute (default 0.0). Positive values trigger Stirling's approximation for log-prob. |
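To give a feel for why Stirling's approximation is attractive for large counts, the following sketch compares an exact binomial log-prob (via math.lgamma) with one that uses a truncated Stirling series. It is a generic illustration: the helper names are hypothetical, and it does not reproduce how Pyro decides when the approximation meets approx_log_prob_tol.

```python
import math


def stirling_lgamma(x: float) -> float:
    """Stirling series for log Gamma(x), truncated after the 1/(12x) term."""
    return (x - 0.5) * math.log(x) - x + 0.5 * math.log(2.0 * math.pi) + 1.0 / (12.0 * x)


def binomial_log_prob(k: int, n: int, p: float, lgamma=math.lgamma) -> float:
    """log P(K = k) for Binomial(n, p), with a pluggable log-gamma function."""
    log_coeff = lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1)
    return log_coeff + k * math.log(p) + (n - k) * math.log1p(-p)


exact = binomial_log_prob(300, 1000, 0.3)
approx = binomial_log_prob(300, 1000, 0.3, lgamma=stirling_lgamma)
error = abs(exact - approx)
```

The truncation error shrinks roughly like 1/x^3, so the approximation is most accurate exactly where counts are large.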
Categorical Special Outputs
| Name | Type | Description |
|---|---|---|
| log_prob | torch.Tensor | Optimized log probability that avoids torch.gather when value is from enumerate_support. |
| enumerate_support | torch.Tensor | Support values tagged with _pyro_categorical_support for optimization. |
Poisson Special Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| rate | torch.Tensor | Yes | Rate parameter (lambda). |
| is_sparse | bool | No | If True, optimizes log_prob for sparse count data. |
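The idea behind a sparse-aware Poisson log-prob can be sketched in plain Python: a zero count contributes exactly -rate, so the logarithm and log-gamma terms only need to be evaluated at nonzero positions. The function names are hypothetical, and a tensor implementation would select the nonzero entries with an indexing operation instead of a Python loop.

```python
import math


def poisson_log_prob_dense(counts, rate):
    """Evaluate k*log(rate) - rate - lgamma(k + 1) at every position."""
    return [k * math.log(rate) - rate - math.lgamma(k + 1) for k in counts]


def poisson_log_prob_sparse(counts, rate):
    """Start from -rate everywhere and correct only the nonzero counts."""
    result = [-rate] * len(counts)
    log_rate = math.log(rate)
    for i, k in enumerate(counts):
        if k != 0:
            result[i] += k * log_rate - math.lgamma(k + 1)
    return result


counts = [0.0] * 1000
counts[42] = 3.0
dense = poisson_log_prob_dense(counts, 0.01)
sparse = poisson_log_prob_sparse(counts, 0.01)
```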
Conjugate Update (Beta, Dirichlet, Gamma) Outputs
| Name | Type | Description |
|---|---|---|
| updated | Distribution | Posterior distribution after conjugate update. |
| log_normalizer | torch.Tensor | Log normalizing constant of the update. |
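For the Beta case, the two outputs can be understood through the standard identity for a product of two Beta densities: the product equals exp(log_normalizer) times a single Beta density with summed concentrations minus one. The sketch below works on plain (concentration1, concentration0) tuples; the helper names are hypothetical, and it is assumed rather than verified here that the wrappers follow exactly this convention.

```python
import math


def log_inverse_beta_fn(a: float, b: float) -> float:
    """log(1 / B(a, b)), the log normalizing constant of a Beta density."""
    return math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)


def beta_log_pdf(x: float, a: float, b: float) -> float:
    return log_inverse_beta_fn(a, b) + (a - 1) * math.log(x) + (b - 1) * math.log1p(-x)


def beta_conjugate_update(prior, other):
    """Fuse two Beta densities; returns ((a, b), log_normalizer)."""
    a = prior[0] + other[0] - 1
    b = prior[1] + other[1] - 1
    log_normalizer = (
        log_inverse_beta_fn(*prior)
        + log_inverse_beta_fn(*other)
        - log_inverse_beta_fn(a, b)
    )
    return (a, b), log_normalizer


updated, log_norm = beta_conjugate_update((2.0, 5.0), (10.0, 3.0))
# Check the identity prior(x) * other(x) == exp(log_norm) * updated(x) at one point.
x = 0.4
lhs = beta_log_pdf(x, 2.0, 5.0) + beta_log_pdf(x, 10.0, 3.0)
rhs = beta_log_pdf(x, *updated) + log_norm
```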
Usage Examples
Basic Distribution Usage
import torch
import pyro
import pyro.distributions as dist
# Normal distribution with Pyro features
normal = dist.Normal(torch.tensor(0.0), torch.tensor(1.0))
x = normal() # calls rsample (reparameterized)
# Convert to event dimensions
mvn_like = dist.Normal(torch.zeros(3), torch.ones(3)).to_event(1)
print(mvn_like.batch_shape, mvn_like.event_shape) # torch.Size([]) torch.Size([3])
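The shape bookkeeping behind to_event can be sketched without tensors: the rightmost batch dimensions are reinterpreted as event dimensions. The helper name to_event_shapes is hypothetical and operates on plain tuples.

```python
def to_event_shapes(batch_shape, event_shape, reinterpreted_batch_ndims):
    """Move the rightmost batch dims into the event shape."""
    n = reinterpreted_batch_ndims
    if n < 0 or n > len(batch_shape):
        raise ValueError("cannot reinterpret more dims than the batch shape has")
    split = len(batch_shape) - n
    return batch_shape[:split], batch_shape[split:] + event_shape


# A batch of 3 independent Normals becomes one 3-dimensional event.
single = to_event_shapes((3,), (), 1)
# With an extra leading batch dim, only the rightmost dim is moved.
batched = to_event_shapes((5, 3), (), 1)
```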
Conjugate Bayesian Update
import pyro.distributions as dist
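# Conjugate update of a Beta prior with a Beta-shaped likelihood term
prior = dist.Beta(2.0, 5.0)
likelihood = dist.Beta(10.0, 3.0)  # pseudo-observations expressed as a Beta
# Returns the updated distribution and the log normalizing constant
posterior, log_norm = prior.conjugate_update(likelihood)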
Efficient Binomial for Large Populations
import torch
import pyro.distributions as dist
# Use Poisson approximation for large total_count
dist.Binomial.approx_sample_thresh = 1e6
binomial = dist.Binomial(total_count=1e8, probs=0.001)
samples = binomial.sample(torch.Size([1000]))
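One way to build a clamped Poisson approximation is moment matching: draw from a Poisson whose rate equals the binomial variance, add an integer shift that restores the mean, and clamp at total_count. The sketch below shows only the parameter arithmetic for the example above; the helper names are hypothetical, and the actual Pyro implementation may differ in its details.

```python
def clamped_poisson_params(total_count: float, probs: float):
    """Return (rate, shift) for a moment-matched shifted Poisson.

    Works on the rarer of the two outcomes so the Poisson rate stays
    close to the binomial variance.
    """
    p, q = probs, 1.0 - probs
    mean = min(p, q) * total_count      # mean count of the rarer outcome
    variance = p * q * total_count      # binomial variance
    shift = round(mean - variance)      # integer offset that restores the mean
    return variance, shift


def finish_sample(poisson_draw, total_count, probs, shift):
    """Shift, clamp at total_count, and flip back if successes are the common outcome."""
    result = min(poisson_draw + shift, total_count)
    return result if probs < 0.5 else total_count - result


rate, shift = clamped_poisson_params(1e8, 0.001)
typical = finish_sample(99900, 1e8, 0.001, shift)
```

With total_count=1e8 and probs=0.001 the Poisson rate is about 99900 and the shift is 100, so a typical draw lands near the binomial mean of 100000.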
Sparse Poisson Log-Prob
import torch
import pyro.distributions as dist
# Efficient log_prob for sparse count data
poisson = dist.Poisson(rate=torch.tensor(0.01), is_sparse=True)
counts = torch.zeros(1000)
counts[42] = 3.0
log_p = poisson.log_prob(counts)