

Implementation:Pyro ppl Pyro GammaGaussian

From Leeroopedia


Module: pyro.ops.gamma_gaussian
Source: pyro/ops/gamma_gaussian.py
Lines: 468
Classes: Gamma, GammaGaussian
Functions: gamma_and_mvn_to_gamma_gaussian, scale_mvn, matrix_and_mvn_to_gamma_gaussian, gamma_gaussian_tensordot
Dependencies: torch, pyro.distributions, pyro.ops.tensor_utils

Overview

This module provides non-normalized Gamma-Gaussian distribution representations and operations used internally in Pyro for conjugate inference with scale mixtures. The central abstraction is the GammaGaussian class, which models a joint distribution over a continuous vector x and a positive scalar mixing variable s, where the Gaussian component's precision is scaled by s.

The Gamma class represents a non-normalized Gamma distribution in log-density space:

Gamma(concentration, rate) ~ (concentration - 1) * log(s) - rate * s
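The formula above omits the additive log_normalizer that the Gamma class stores alongside its two parameters. The sketch below is a minimal plain-Python check, written without Pyro, and its function names are hypothetical rather than Pyro's API. It shows that setting log_normalizer = concentration * log(rate) - lgamma(concentration) recovers the ordinary normalized Gamma log-pdf.

```python
import math


def gamma_log_density(s, log_normalizer, concentration, rate):
    """Hypothetical helper: non-normalized Gamma log-density
    log_normalizer + (concentration - 1) * log(s) - rate * s."""
    return log_normalizer + (concentration - 1.0) * math.log(s) - rate * s


def gamma_log_pdf(s, concentration, rate):
    """Normalized Gamma(concentration, rate) log-pdf, written out for comparison."""
    return (
        concentration * math.log(rate)
        - math.lgamma(concentration)
        + (concentration - 1.0) * math.log(s)
        - rate * s
    )


concentration, rate = 2.5, 1.5
# This log normalizer turns the unnormalized form into a proper density.
log_z = concentration * math.log(rate) - math.lgamma(concentration)

for s in (0.1, 1.0, 4.0):
    print(s, gamma_log_density(s, log_z, concentration, rate), gamma_log_pdf(s, concentration, rate))
```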

The GammaGaussian class uses an information (natural) parameterization: it stores info_vec and precision rather than a mean and covariance. The precision matrix may be rank-deficient (have zero eigenvalues), in which case no covariance matrix exists. The information form keeps computations well defined in that case. The log-density is:

alpha * log(s) + s * (-0.5 * x.T @ precision @ x + x.T @ info_vec - beta) + log_normalizer

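where alpha and beta are reparameterized versions of the Gamma concentration and rate parameters. Conditioning on s yields a Gaussian p(x | s) ~ Gaussian(s * info_vec, s * precision) in information form. This Gaussian has mean inv(precision) @ info_vec, which does not depend on s, and precision s * precision.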
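For fixed s, the log-density is a quadratic in x, which is why conditioning on s gives that Gaussian. The one-dimensional sketch below is plain Python with hypothetical helper names, and scalar info_vec and precision stand in for the vector and matrix. It checks that the GammaGaussian log-density and the Normal(mean = info_vec / precision, precision = s * precision) log-pdf differ only by a constant in x.

```python
import math


def gg_log_density_1d(x, s, log_normalizer, info_vec, precision, alpha, beta):
    """Hypothetical 1-D helper:
    alpha*log(s) + s*(-0.5*precision*x**2 + info_vec*x - beta) + log_normalizer."""
    return alpha * math.log(s) + s * (-0.5 * precision * x * x + info_vec * x - beta) + log_normalizer


def normal_log_pdf(x, mean, prec):
    """Normal log-pdf parameterized by mean and precision."""
    return 0.5 * math.log(prec / (2.0 * math.pi)) - 0.5 * prec * (x - mean) ** 2


params = dict(log_normalizer=0.3, info_vec=1.2, precision=2.0, alpha=1.7, beta=0.9)
s = 2.5
mean = params["info_vec"] / params["precision"]  # inv(precision) @ info_vec, independent of s
prec = s * params["precision"]  # the Gaussian precision is scaled by s

# If p(x | s) is Normal(mean, 1 / prec), these differences are identical for every x.
diffs = [gg_log_density_1d(x, s, **params) - normal_log_pdf(x, mean, prec) for x in (-2.0, 0.0, 0.6, 3.0)]
print(diffs)
```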

Code Reference

Class: Gamma

A non-normalized Gamma distribution storing log_normalizer, concentration, and rate. Provides:

  • log_density(s): Evaluates the unnormalized log probability at value s.
  • logsumexp(): Integrates out the latent variable s by computing log_normalizer + lgamma(concentration) - concentration * log(rate).
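The closed form used by logsumexp() is the Gamma integral. Integrating exp(log_normalizer) * s^(concentration - 1) * exp(-rate * s) over s > 0 gives exp(log_normalizer) * Γ(concentration) / rate^concentration. The sketch below checks this against midpoint-rule quadrature in plain Python. The helper names, integration range, and grid size are illustrative assumptions, not part of Pyro.

```python
import math


def gamma_log_density(s, log_normalizer, concentration, rate):
    """Hypothetical helper: non-normalized Gamma log-density."""
    return log_normalizer + (concentration - 1.0) * math.log(s) - rate * s


def gamma_logsumexp(log_normalizer, concentration, rate):
    """Closed form: log of the integral of exp(gamma_log_density) over s > 0."""
    return log_normalizer + math.lgamma(concentration) - concentration * math.log(rate)


def numeric_logsumexp(log_normalizer, concentration, rate, upper=40.0, n=100_000):
    """Midpoint-rule approximation of the same integral on (0, upper)."""
    h = upper / n
    total = sum(
        math.exp(gamma_log_density((i + 0.5) * h, log_normalizer, concentration, rate))
        for i in range(n)
    )
    return math.log(total * h)


closed = gamma_logsumexp(0.5, 3.0, 2.0)
approx = numeric_logsumexp(0.5, 3.0, 2.0)
print(closed, approx)  # both close to 0.5 - 2 * log(2)
```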

Class: GammaGaussian

A non-normalized Gamma-Gaussian distribution parameterized by:

  • log_normalizer -- scalar normalization constant
  • info_vec -- information vector (precision @ mean)
  • precision -- precision matrix (may be rank-deficient)
  • alpha -- reparameterized Gamma shape: concentration + 0.5 * dim - 1
  • beta -- reparameterized Gamma rate: rate + 0.5 * info_vec.T @ inv(precision) @ info_vec
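These reparameterizations follow from multiplying a Gamma(concentration, rate) prior on s by a Gaussian p(x | s) with precision s * precision. The Gaussian normalizer contributes 0.5 * dim * log(s), which shifts the shape. Expanding the quadratic around the mean contributes the info_vec.T @ inv(precision) @ info_vec term to beta. The one-dimensional plain-Python sketch below uses hypothetical helper names and scalars in place of vectors and matrices, so dim = 1. It builds alpha, beta, info_vec, and log_normalizer this way and checks the result against the sum of the two normalized log-densities.

```python
import math


def joint_params_1d(concentration, rate, loc, precision):
    """Hypothetical 1-D helper (not Pyro's API) for
    p(s) = Gamma(concentration, rate), p(x | s) = Normal(loc, variance 1 / (s * precision))."""
    dim = 1
    info_vec = precision * loc  # info_vec = precision @ mean
    alpha = concentration + 0.5 * dim - 1.0
    beta = rate + 0.5 * info_vec * info_vec / precision  # rate + 0.5 * info_vec.T @ inv(precision) @ info_vec
    log_normalizer = (
        concentration * math.log(rate) - math.lgamma(concentration)  # Gamma normalizer
        + 0.5 * math.log(precision) - 0.5 * dim * math.log(2.0 * math.pi)  # Gaussian normalizer minus its s part
    )
    return log_normalizer, info_vec, precision, alpha, beta


def gg_log_density_1d(x, s, log_normalizer, info_vec, precision, alpha, beta):
    return alpha * math.log(s) + s * (-0.5 * precision * x * x + info_vec * x - beta) + log_normalizer


def reference_log_joint(x, s, concentration, rate, loc, precision):
    """log Gamma(s; concentration, rate) + log Normal(x; loc, 1 / (s * precision))."""
    log_gamma = (concentration * math.log(rate) - math.lgamma(concentration)
                 + (concentration - 1.0) * math.log(s) - rate * s)
    prec = s * precision
    log_normal = 0.5 * math.log(prec / (2.0 * math.pi)) - 0.5 * prec * (x - loc) ** 2
    return log_gamma + log_normal


params = joint_params_1d(concentration=2.0, rate=1.5, loc=0.7, precision=3.0)
for x, s in [(-1.0, 0.5), (0.7, 1.0), (2.0, 3.0)]:
    print(gg_log_density_1d(x, s, *params), reference_log_joint(x, s, 2.0, 1.5, 0.7, 3.0))
```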

Key methods:

  • expand(batch_shape): Expands to the given batch shape.
  • reshape(batch_shape): Reshapes to the given batch shape.
  • cat(parts, dim): Static method to concatenate GammaGaussians along a batch dimension.
  • event_pad(left, right): Zero-pads the event dimension (info_vec and precision).
  • event_permute(perm): Permutes the event dimension using a tensor of indices.
  • __add__(other): Adds two GammaGaussians in log-density space (element-wise addition of all parameters).
  • log_density(value, s): Evaluates the full log density at a given point.
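  • condition(value): Conditions on the trailing value.size(-1) coordinates of the state, reducing the event dimension by that amount. Because the distribution is non-normalized, the density of value is kept in the result.
  • marginalize(left, right): Marginalizes out the first left and the last right variables of the event dimension, using a Cholesky factorization.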
  • compound(): Integrates out the latent multiplier s, yielding a MultivariateStudentT distribution.
  • event_logsumexp(): Integrates out all Gaussian latent state, returning a Gamma object.
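Conditioning only moves terms around, because the log-density is linear in s with a quadratic coefficient in x. Write x = (a, b) and fix the trailing block b. Then info_vec becomes info_a - P_ab @ b and precision becomes P_aa. The terms that depend only on b are folded into beta, while alpha and log_normalizer are unchanged. The sketch below implements this for a two-dimensional state in plain Python, with hypothetical names rather than Pyro's. It checks the property that condition(b).log_density(a, s) equals log_density((a, b), s).

```python
import math


def gg_log_density(x, s, log_normalizer, info_vec, precision, alpha, beta):
    """Hypothetical helper:
    alpha*log(s) + s*(-0.5*x.T@precision@x + x.T@info_vec - beta) + log_normalizer for list-valued x."""
    n = len(x)
    quad = sum(x[i] * precision[i][j] * x[j] for i in range(n) for j in range(n))
    lin = sum(x[i] * info_vec[i] for i in range(n))
    return alpha * math.log(s) + s * (-0.5 * quad + lin - beta) + log_normalizer


def condition_last(b, log_normalizer, info_vec, precision, alpha, beta):
    """Hypothetical helper: condition a 2-D GammaGaussian on its last coordinate b."""
    info_a, info_b = info_vec
    p_aa, p_ab = precision[0]
    p_bb = precision[1][1]
    new_info = [info_a - p_ab * b]  # info_a - P_ab @ b
    new_precision = [[p_aa]]  # P_aa
    new_beta = beta + 0.5 * p_bb * b * b - info_b * b  # b-only terms move into beta
    return log_normalizer, new_info, new_precision, alpha, new_beta


gg = (0.2, [0.5, -1.0], [[2.0, 0.6], [0.6, 1.5]], 1.3, 0.8)  # symmetric precision
b = 0.4
gg_b = condition_last(b, *gg)
for a, s in [(-1.0, 0.7), (0.3, 1.0), (2.0, 2.5)]:
    print(gg_log_density([a, b], s, *gg), gg_log_density([a], s, *gg_b))
```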

Conversion Functions

  • gamma_and_mvn_to_gamma_gaussian(gamma, mvn): Converts a Gamma and MultivariateNormal pair into a GammaGaussian.
  • scale_mvn(mvn, s): Transforms a MVN by scaling its precision by s.
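  • matrix_and_mvn_to_gamma_gaussian(matrix, mvn): Converts a noisy affine function y = x @ matrix + noise into a GammaGaussian over the concatenated state (x, y). The noise is distributed as scale_mvn(mvn, s), so its precision is scaled by s.
  • gamma_gaussian_tensordot(x, y, dims): Computes the integral (tensor contraction) of two GammaGaussians over their shared variables. If x is over variables (a, b) and y is over (b, c), the result is over (a, c), where dims is the size of b.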
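Expanding the noise density for the affine case shows where rank-deficient precisions come from. Take a scalar x, a scalar y, a hypothetical slope m, and noise distributed as Normal(loc, variance 1 / (s * precision)). The joint precision over (x, y) is then [[precision*m*m, -precision*m], [-precision*m, precision]]. Its determinant is zero, because the factor constrains only the residual y - m * x and not x itself. The plain-Python sketch below derives these parameters by hand. It is not Pyro's matrix_and_mvn_to_gamma_gaussian, and its helper names are made up. It checks the parameters against the noise log-density.

```python
import math


def affine_params_1d(m, loc, precision):
    """Hypothetical 1-D helper: GammaGaussian parameters over z = (x, y) for
    y = m * x + noise, noise ~ Normal(loc, variance 1 / (s * precision))."""
    info_y = precision * loc
    info_vec = [-m * info_y, info_y]
    prec = [[precision * m * m, -precision * m], [-precision * m, precision]]
    alpha = 0.5  # 0.5 * y_dim, from the s-dependent Gaussian normalizer
    beta = 0.5 * precision * loc * loc
    log_normalizer = 0.5 * math.log(precision) - 0.5 * math.log(2.0 * math.pi)
    return log_normalizer, info_vec, prec, alpha, beta


def gg_log_density(z, s, log_normalizer, info_vec, precision, alpha, beta):
    n = len(z)
    quad = sum(z[i] * precision[i][j] * z[j] for i in range(n) for j in range(n))
    lin = sum(z[i] * info_vec[i] for i in range(n))
    return alpha * math.log(s) + s * (-0.5 * quad + lin - beta) + log_normalizer


def noise_log_density(x, y, s, m, loc, precision):
    """log Normal(y - m * x; loc, 1 / (s * precision))."""
    prec = s * precision
    r = y - m * x - loc
    return 0.5 * math.log(prec / (2.0 * math.pi)) - 0.5 * prec * r * r


params = affine_params_1d(m=1.7, loc=0.3, precision=2.0)
p = params[2]
print("det(precision) =", p[0][0] * p[1][1] - p[0][1] * p[1][0])  # zero up to rounding
for x, y, s in [(0.0, 0.0, 1.0), (1.0, -0.5, 0.4), (-2.0, 1.5, 3.0)]:
    print(gg_log_density([x, y], s, *params), noise_log_density(x, y, s, 1.7, 0.3, 2.0))
```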

I/O Contract

D denotes the event dimension of the input GammaGaussian (its dim()).

  • GammaGaussian.__init__ -- inputs: log_normalizer: Tensor, info_vec: Tensor(..., D), precision: Tensor(..., D, D), alpha: Tensor, beta: Tensor; output: a GammaGaussian instance
  • condition(value) -- input: value: Tensor(..., K) with K <= D; output: a GammaGaussian with dim() == D - K
  • marginalize(left, right) -- inputs: left: int, right: int (each defaults to 0); output: a GammaGaussian with dim() == D - left - right
  • compound() -- no inputs; output: a MultivariateStudentT distribution
  • event_logsumexp() -- no inputs; output: a Gamma object
  • gamma_gaussian_tensordot(x, y, dims) -- inputs: two GammaGaussian instances x and y, dims: int; output: a GammaGaussian with dim() == x.dim() + y.dim() - 2 * dims
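The Gamma returned by event_logsumexp() can be derived by completing the square. Integrating x out of the log-density multiplies it by (2π/s)^(dim/2) det(precision)^(-1/2) exp(0.5 * s * info_vec.T @ inv(precision) @ info_vec). The result is therefore a Gamma with concentration = alpha - 0.5 * dim + 1 and rate = beta - 0.5 * info_vec.T @ inv(precision) @ info_vec, which inverts the alpha and beta reparameterization. The plain-Python sketch below checks this in one dimension against midpoint quadrature over x. The helpers are hypothetical. The sketch also assumes a positive precision, since the integral over x diverges along any direction with a zero eigenvalue.

```python
import math


def gg_log_density_1d(x, s, log_normalizer, info_vec, precision, alpha, beta):
    return alpha * math.log(s) + s * (-0.5 * precision * x * x + info_vec * x - beta) + log_normalizer


def event_logsumexp_1d(log_normalizer, info_vec, precision, alpha, beta):
    """Hypothetical helper: integrate x out analytically and return the
    (log_normalizer, concentration, rate) of a non-normalized Gamma over s."""
    dim = 1
    concentration = alpha - 0.5 * dim + 1.0
    rate = beta - 0.5 * info_vec * info_vec / precision
    new_log_normalizer = log_normalizer + 0.5 * dim * math.log(2.0 * math.pi) - 0.5 * math.log(precision)
    return new_log_normalizer, concentration, rate


def numeric_log_integral_over_x(s, params, lo=-30.0, hi=30.0, n=60_000):
    """Midpoint-rule log of the integral of exp(gg_log_density_1d) over x at fixed s."""
    h = (hi - lo) / n
    total = sum(math.exp(gg_log_density_1d(lo + (i + 0.5) * h, s, *params)) for i in range(n))
    return math.log(total * h)


params = (0.1, 0.8, 2.0, 1.4, 1.0)  # log_normalizer, info_vec, precision, alpha, beta
log_z, conc, rate = event_logsumexp_1d(*params)
for s in (0.5, 1.0, 2.0):
    gamma_log_density = log_z + (conc - 1.0) * math.log(s) - rate * s
    print(s, gamma_log_density, numeric_log_integral_over_x(s, params))
```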

Usage Examples

import torch
from pyro.ops.gamma_gaussian import (
    GammaGaussian,
    gamma_and_mvn_to_gamma_gaussian,
    gamma_gaussian_tensordot,
)

# Create a GammaGaussian from a Gamma and MVN pair
gamma = torch.distributions.Gamma(torch.tensor(2.0), torch.tensor(1.0))
mvn = torch.distributions.MultivariateNormal(
    torch.zeros(3), torch.eye(3)
)
gg = gamma_and_mvn_to_gamma_gaussian(gamma, mvn)
print(gg.dim())  # 3

# Condition on the last 2 dimensions
value = torch.randn(2)
gg_cond = gg.condition(value)
print(gg_cond.dim())  # 1

# Marginalize out the first dimension
gg_marg = gg.marginalize(left=1)
print(gg_marg.dim())  # 2

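# Integrate out the Gaussian part to get a Gamma over s
gamma_result = gg.event_logsumexp()
print(gamma_result.logsumexp())  # approximately 0: the joint of two normalized distributions integrates to 1

# Integrate out s to get a multivariate Student-t distribution
student_t = gg.compound()
print(student_t.df)  # tensor(4.), i.e. 2 * gamma.concentration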
