Implementation:Pyro ppl Pyro ConjugateDistributions
| Knowledge Sources | |
|---|---|
| Domains | Probability_Distributions |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Description
The conjugate module implements compound (conjugate) probability distributions that arise from integrating out a latent parameter from a prior-likelihood pair. Three compound distributions are provided:
- BetaBinomial -- A compound distribution formed by pairing a Beta prior with a Binomial likelihood. The unknown success probability is drawn from a Beta(concentration1, concentration0) distribution before total_count Bernoulli trials are performed. The log-probability is computed analytically using the Beta function: log_binomial(n, k) + log_beta(k + a, n - k + b) - log_beta(a, b), where n = total_count, k is the observed number of successes, a = concentration1, and b = concentration0. This distribution supports enumeration of its discrete support and an optional approximate log-probability computation via a shifted Stirling's approximation, controlled by the approx_log_prob_tol class attribute.
- DirichletMultinomial -- A compound distribution formed by pairing a Dirichlet prior with a Multinomial likelihood. The unknown class probabilities are drawn from a Dirichlet(concentration) distribution before total_count categorical trials are performed. It supports a sparse mode (is_sparse=True) for efficient computation when observed counts are mostly zero.
- GammaPoisson -- A compound distribution formed by pairing a Gamma prior with a Poisson likelihood, also known as a Gamma-Poisson mixture. The unknown Poisson rate is drawn from a Gamma(concentration, rate) distribution. This is an alternative parameterization of the NegativeBinomial(total_count, probs) distribution, with concentration = total_count and rate = (1 - probs) / probs.
All three distributions extend TorchDistribution and implement the sample, log_prob, and expand methods along with the mean and variance properties.
Usage
These distributions are used in Bayesian modeling scenarios where conjugate prior-likelihood pairs naturally arise. They are commonly employed in topic modeling (DirichletMultinomial), overdispersed count data modeling (GammaPoisson / Negative Binomial), and binary outcome modeling with uncertainty over success probability (BetaBinomial).
Code Reference
Source Location
pyro/distributions/conjugate.py
Signature
```python
class BetaBinomial(TorchDistribution):
    def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
        ...

class DirichletMultinomial(TorchDistribution):
    def __init__(self, concentration, total_count=1, is_sparse=False, validate_args=None):
        ...

class GammaPoisson(TorchDistribution):
    def __init__(self, concentration, rate, validate_args=None):
        ...
```
Import
```python
from pyro.distributions import BetaBinomial, DirichletMultinomial, GammaPoisson
```
I/O Contract
Inputs
| Class | Parameter | Type | Description |
|---|---|---|---|
| BetaBinomial | concentration1 | float or torch.Tensor | First concentration parameter (alpha) for the Beta prior. |
| BetaBinomial | concentration0 | float or torch.Tensor | Second concentration parameter (beta) for the Beta prior. |
| BetaBinomial | total_count | float or torch.Tensor | Number of Bernoulli trials. Defaults to 1. |
| DirichletMultinomial | concentration | float or torch.Tensor | Concentration parameter (alpha) for the Dirichlet prior. The last dimension is the event dimension. |
| DirichletMultinomial | total_count | int or torch.Tensor | Number of categorical trials. Defaults to 1. |
| DirichletMultinomial | is_sparse | bool | If True, assumes the observed values are mostly zero, which speeds up log_prob computation. Defaults to False. |
| GammaPoisson | concentration | float or torch.Tensor | Shape parameter (alpha) of the Gamma prior. |
| GammaPoisson | rate | float or torch.Tensor | Rate parameter (beta) of the Gamma prior. |
Outputs
| Method or property | Return Type | Description |
|---|---|---|
| sample(sample_shape) | torch.Tensor | Draws samples by first sampling the latent parameter from the prior and then sampling from the likelihood. |
| log_prob(value) | torch.Tensor | Computes the analytically marginalized log-probability of the observed value. |
| mean (property) | torch.Tensor | Mean of the compound distribution. |
| variance (property) | torch.Tensor | Variance of the compound distribution. |
| enumerate_support(expand=True) | torch.Tensor | (BetaBinomial only) Returns the full support {0, 1, ..., total_count}. |
Usage Examples
```python
import torch
from pyro.distributions import BetaBinomial, DirichletMultinomial, GammaPoisson

# BetaBinomial example
bb = BetaBinomial(concentration1=2.0, concentration0=5.0, total_count=10)
sample = bb.sample((100,))  # shape (100,)
log_p = bb.log_prob(torch.tensor(3.0))  # log P(3 successes out of 10 trials)
print("BetaBinomial mean:", bb.mean)

# DirichletMultinomial example
concentration = torch.tensor([1.0, 2.0, 3.0])
dm = DirichletMultinomial(concentration, total_count=10)
sample = dm.sample((50,))  # shape (50, 3); each row sums to 10
log_p = dm.log_prob(torch.tensor([2.0, 3.0, 5.0]))  # counts sum to total_count
print("DirichletMultinomial mean:", dm.mean)

# GammaPoisson example
gp = GammaPoisson(concentration=5.0, rate=1.0)
sample = gp.sample((100,))  # shape (100,)
log_p = gp.log_prob(torch.tensor(3.0))
print("GammaPoisson mean:", gp.mean)
```
Related Pages
- Pyro_ppl_Pyro_Distribution_Base -- Base distribution class for all Pyro distributions
- Pyro_ppl_Pyro_Constraints -- Constraint definitions used by distribution parameters
- Pyro_ppl_Pyro_FoldedDistribution -- Another specialized distribution implementation