Implementation:Pyro ppl Pyro GammaGaussian
| Property | Value |
|---|---|
| Module | pyro.ops.gamma_gaussian |
| Source | pyro/ops/gamma_gaussian.py |
| Lines | 468 |
| Classes | Gamma, GammaGaussian |
| Functions | gamma_and_mvn_to_gamma_gaussian, scale_mvn, matrix_and_mvn_to_gamma_gaussian, gamma_gaussian_tensordot |
| Dependencies | torch, pyro.distributions, pyro.ops.tensor_utils |
Overview
This module provides non-normalized Gamma-Gaussian distribution representations and operations used internally in Pyro for conjugate inference with scale mixtures. The central abstraction is the GammaGaussian class, which models a joint distribution over a continuous vector x and a positive scalar mixing variable s, where the Gaussian component's precision is scaled by s.
The Gamma class represents a non-normalized Gamma distribution in log-density space:
Gamma(concentration, rate) ~ (concentration - 1) * log(s) - rate * s
The GammaGaussian class uses an information (natural) parameterization with info_vec and precision rather than mean and covariance. This design choice enables stable computation with rank-deficient precision matrices (which may have zero eigenvalues). The log-density is:
alpha * log(s) + s * (-0.5 * x.T @ precision @ x + x.T @ info_vec - beta) + log_normalizer
where alpha and beta are reparameterized versions of the Gamma concentration and rate parameters. Conditioning on s yields a Gaussian p(x|s) ~ Gaussian(s * info_vec, s * precision).
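The log-density formula above can be written out directly in PyTorch. The following is a minimal sketch rather than the module's own code: the function name gamma_gaussian_log_density and all numeric values are assumptions chosen for illustration, and only a single unbatched point is handled.

```python
import torch


def gamma_gaussian_log_density(x, s, log_normalizer, info_vec, precision, alpha, beta):
    """Evaluate the stated log-density formula at one unbatched point (x, s)."""
    quad = -0.5 * x @ precision @ x  # -0.5 * x.T @ precision @ x
    lin = x @ info_vec               # x.T @ info_vec
    return alpha * torch.log(s) + s * (quad + lin - beta) + log_normalizer


# Illustrative values (hypothetical, not taken from Pyro).
precision = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
info_vec = torch.tensor([1.0, -1.0])
alpha, beta = torch.tensor(1.5), torch.tensor(2.0)
log_normalizer = torch.tensor(0.0)
x = torch.tensor([0.3, -0.2])
s = torch.tensor(2.0)

full = gamma_gaussian_log_density(x, s, log_normalizer, info_vec, precision, alpha, beta)

# For fixed s, the x-dependent terms form a Gaussian in information form
# with information vector s * info_vec and precision s * precision.
gaussian_part = -0.5 * x @ (s * precision) @ x + x @ (s * info_vec)
s_only_part = alpha * torch.log(s) - s * beta + log_normalizer
assert torch.allclose(full, gaussian_part + s_only_part)
```

Splitting the density this way makes the conditioning statement concrete: every term that involves x is multiplied by s, so s acts as a common scale on both the information vector and the precision.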
Code Reference
Class: Gamma
A non-normalized Gamma distribution storing log_normalizer, concentration, and rate. Provides:
- log_density(s): Evaluates the unnormalized log probability at value s.
- logsumexp(): Integrates out the latent variable s by computing log_normalizer + lgamma(concentration) - concentration * log(rate).
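The closed form used by logsumexp() can be checked numerically with the standard library alone. The sketch below is illustrative and does not call Pyro: the helper names are hypothetical, log_normalizer is assumed to enter the density as an additive constant, and the integration bounds and step count are arbitrary choices.

```python
import math


def gamma_log_density(s, concentration, rate, log_normalizer=0.0):
    """Unnormalized Gamma log-density: (concentration - 1) * log(s) - rate * s."""
    return log_normalizer + (concentration - 1.0) * math.log(s) - rate * s


def gamma_logsumexp(concentration, rate, log_normalizer=0.0):
    """Closed form for the log of the integral of the density over s > 0."""
    return log_normalizer + math.lgamma(concentration) - concentration * math.log(rate)


def numeric_log_integral(concentration, rate, upper=60.0, steps=200_000):
    """Midpoint-rule approximation of the same integral on (0, upper)."""
    h = upper / steps
    total = sum(
        math.exp(gamma_log_density((i + 0.5) * h, concentration, rate))
        for i in range(steps)
    )
    return math.log(total * h)


closed_form = gamma_logsumexp(concentration=3.0, rate=2.0)
numeric = numeric_log_integral(concentration=3.0, rate=2.0)
assert abs(closed_form - numeric) < 1e-6
```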
Class: GammaGaussian
A non-normalized Gamma-Gaussian distribution parameterized by:
- log_normalizer -- scalar normalization constant
- info_vec -- information vector (precision @ mean)
- precision -- precision matrix (may be rank-deficient)
- alpha -- reparameterized Gamma shape: concentration + 0.5 * dim - 1
- beta -- reparameterized Gamma rate: rate + 0.5 * info_vec.T @ inv(precision) @ info_vec
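The parameter definitions above can be traced by hand for a small case. This is a sketch under stated assumptions: the function reparameterize is hypothetical, it starts from a mean and a full-rank covariance (so inv(precision) is simply the covariance), and it ignores log_normalizer and batching.

```python
import torch


def reparameterize(concentration, rate, loc, covariance):
    """Map (concentration, rate, loc, covariance) to (info_vec, precision, alpha, beta)."""
    dim = loc.size(-1)
    precision = torch.linalg.inv(covariance)
    info_vec = precision @ loc  # precision @ mean
    alpha = concentration + 0.5 * dim - 1
    # info_vec.T @ inv(precision) @ info_vec, using inv(precision) == covariance
    beta = rate + 0.5 * info_vec @ covariance @ info_vec
    return info_vec, precision, alpha, beta


loc = torch.tensor([1.0, -2.0, 0.5])
covariance = torch.diag(torch.tensor([1.0, 4.0, 0.25]))
info_vec, precision, alpha, beta = reparameterize(
    torch.tensor(2.0), torch.tensor(1.0), loc, covariance
)
# With these inputs: alpha = 2 + 1.5 - 1 and beta = 1 + 0.5 * 3.
```

When the mean is zero, info_vec vanishes and beta reduces to the original Gamma rate.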
Key methods:
- expand(batch_shape): Expands to the given batch shape.
- reshape(batch_shape): Reshapes to the given batch shape.
- cat(parts, dim): Static method to concatenate GammaGaussians along a batch dimension.
- event_pad(left, right): Zero-pads the event dimension (info_vec and precision).
- event_permute(perm): Permutes the event dimension using a tensor of indices.
- __add__(other): Adds two GammaGaussians in log-density space (element-wise addition of all parameters).
- log_density(value, s): Evaluates the full log density at a given point.
- condition(value): Conditions on a trailing subset of state, reducing dimensionality.
- marginalize(left, right): Marginalizes out variables from either side of the event dimension using Cholesky factorization.
- compound(): Integrates out the latent multiplier s, yielding a MultivariateStudentT distribution.
- event_logsumexp(): Integrates out all Gaussian latent state, returning a Gamma object.
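To illustrate what Cholesky-based marginalization looks like in information form, the sketch below applies the standard Schur-complement identity to the Gaussian parameters only. It is not Pyro's implementation: marginalize_right is a hypothetical helper, the accompanying updates to alpha, beta, and log_normalizer are omitted, and the block being removed is assumed to be positive definite.

```python
import torch


def marginalize_right(info_vec, precision, n_b):
    """Marginalize the last n_b event dims of an information-form Gaussian."""
    n_a = info_vec.size(-1) - n_b
    P_aa = precision[:n_a, :n_a]
    P_ab = precision[:n_a, n_a:]
    P_ba = precision[n_a:, :n_a]
    P_bb = precision[n_a:, n_a:]
    L_bb = torch.linalg.cholesky(P_bb)  # P_bb = L_bb @ L_bb.T
    # Solve P_bb @ X = P_ba and P_bb @ y = info_b through the Cholesky factor.
    X = torch.cholesky_solve(P_ba, L_bb)
    y = torch.cholesky_solve(info_vec[n_a:].unsqueeze(-1), L_bb).squeeze(-1)
    new_precision = P_aa - P_ab @ X  # Schur complement of P_bb
    new_info_vec = info_vec[:n_a] - P_ab @ y
    return new_info_vec, new_precision


torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.double)
covariance = A @ A.T + 4 * torch.eye(4, dtype=torch.double)
mean = torch.randn(4, dtype=torch.double)
precision = torch.linalg.inv(covariance)
info_vec = precision @ mean

new_info_vec, new_precision = marginalize_right(info_vec, precision, n_b=2)

# Marginalizing a Gaussian keeps the matching block of its mean and covariance.
assert torch.allclose(torch.linalg.inv(new_precision), covariance[:2, :2])
assert torch.allclose(torch.linalg.solve(new_precision, new_info_vec), mean[:2])
```

Working with the Cholesky factor of the removed block avoids forming an explicit inverse, which is the usual motivation for this approach.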
Conversion Functions
- gamma_and_mvn_to_gamma_gaussian(gamma, mvn): Converts a Gamma and MultivariateNormal pair into a GammaGaussian.
- scale_mvn(mvn, s): Transforms an MVN by scaling its precision by s.
- matrix_and_mvn_to_gamma_gaussian(matrix, mvn): Converts a noisy affine function y = x @ matrix + noise to a GammaGaussian.
- gamma_gaussian_tensordot(x, y, dims): Computes the integral (tensor contraction) over two GammaGaussians, marginalizing out shared variables.
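Scaling a precision matrix by s is equivalent to dividing the Cholesky factor of the covariance by sqrt(s). The sketch below shows that equivalence with plain torch.distributions; scale_mvn_sketch is a hypothetical stand-in written for this page, and whether it matches the library function line for line is an assumption.

```python
import torch
from torch.distributions import MultivariateNormal


def scale_mvn_sketch(mvn, s):
    """Return an MVN with the same mean and with precision multiplied by s."""
    # precision * s  <=>  covariance / s  <=>  scale_tril / sqrt(s)
    scale_tril = mvn.scale_tril / s.sqrt().unsqueeze(-1).unsqueeze(-1)
    return MultivariateNormal(mvn.loc, scale_tril=scale_tril)


mvn = MultivariateNormal(
    torch.tensor([1.0, -1.0]),
    covariance_matrix=torch.tensor([[2.0, 0.3], [0.3, 1.0]]),
)
s = torch.tensor(4.0)
scaled = scale_mvn_sketch(mvn, s)

assert torch.allclose(scaled.precision_matrix, s * mvn.precision_matrix, rtol=1e-4, atol=1e-5)
assert torch.allclose(scaled.loc, mvn.loc)
```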
I/O Contract
| Function/Method | Input | Output |
|---|---|---|
| GammaGaussian.__init__ | log_normalizer: Tensor, info_vec: Tensor(..., D), precision: Tensor(..., D, D), alpha: Tensor, beta: Tensor | GammaGaussian instance |
| condition(value) | value: Tensor(..., K) where K <= D | GammaGaussian with dim = D - K |
| marginalize(left, right) | left: int, right: int | GammaGaussian with reduced event dim |
| compound() | (none) | MultivariateStudentT distribution |
| event_logsumexp() | (none) | Gamma object |
| gamma_gaussian_tensordot(x, y, dims) | Two GammaGaussian instances, dims: int | GammaGaussian with dim = x.dim() + y.dim() - 2*dims |
Usage Examples
```python
import torch
from pyro.ops.gamma_gaussian import (
    GammaGaussian,
    gamma_and_mvn_to_gamma_gaussian,
    gamma_gaussian_tensordot,
)

# Create a GammaGaussian from a Gamma and MVN pair
gamma = torch.distributions.Gamma(torch.tensor(2.0), torch.tensor(1.0))
mvn = torch.distributions.MultivariateNormal(
    torch.zeros(3), torch.eye(3)
)
gg = gamma_and_mvn_to_gamma_gaussian(gamma, mvn)
print(gg.dim())  # 3

# Condition on the last 2 dimensions
value = torch.randn(2)
gg_cond = gg.condition(value)
print(gg_cond.dim())  # 1

# Marginalize out the first dimension
gg_marg = gg.marginalize(left=1)
print(gg_marg.dim())  # 2

# Integrate out the Gaussian part to get a Gamma
gamma_result = gg.event_logsumexp()

# Integrate out s to get a Student-T distribution
student_t = gg.compound()
```
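The final step of the example rests on the Gamma scale-mixture identity: integrating a Gamma-distributed s out of a Gaussian whose precision is scaled by s gives a Student-t density. The following one-dimensional check uses only the standard library and does not call Pyro. All helper names are hypothetical, and the degrees of freedom (2 * concentration) and scale (sqrt(rate / (concentration * precision))) come from the textbook identity, not from how compound() parameterizes its output.

```python
import math


def gamma_pdf(s, concentration, rate):
    """Normalized Gamma density in the shape/rate parameterization."""
    return math.exp(concentration * math.log(rate) - math.lgamma(concentration)
                    + (concentration - 1.0) * math.log(s) - rate * s)


def normal_pdf(x, mean, precision):
    """Univariate normal density parameterized by precision."""
    return math.sqrt(precision / (2.0 * math.pi)) * math.exp(-0.5 * precision * (x - mean) ** 2)


def student_t_pdf(x, df, loc, scale):
    """Univariate Student-t density with location and scale."""
    z = (x - loc) / scale
    log_norm = (math.lgamma(0.5 * (df + 1.0)) - math.lgamma(0.5 * df)
                - 0.5 * math.log(df * math.pi) - math.log(scale))
    return math.exp(log_norm - 0.5 * (df + 1.0) * math.log1p(z * z / df))


def mixture_pdf(x, concentration, rate, mean, precision, upper=80.0, steps=200_000):
    """Midpoint-rule integral over s of Gamma(s) * Normal(x | mean, s * precision)."""
    h = upper / steps
    return h * sum(
        gamma_pdf(s, concentration, rate) * normal_pdf(x, mean, s * precision)
        for s in ((i + 0.5) * h for i in range(steps))
    )


concentration, rate, mean, precision = 2.0, 1.0, 0.5, 3.0
x = 1.2
mixed = mixture_pdf(x, concentration, rate, mean, precision)
closed = student_t_pdf(x, df=2.0 * concentration, loc=mean,
                       scale=math.sqrt(rate / (concentration * precision)))
assert abs(mixed - closed) < 1e-6
```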
Related Pages
- Pyro_ppl_Pyro_Gaussian -- Non-normalized Gaussian operations (no mixing variable)
- Pyro_ppl_Pyro_TensorUtils -- Utility functions including precision_to_scale_tril
- Pyro_ppl_Pyro_Stats -- Statistical computation utilities