Implementation:Pyro ppl Pyro Gaussian
| Property | Value |
|---|---|
| Module | pyro.ops.gaussian |
| Source | pyro/ops/gaussian.py |
| Lines | 715 |
| Classes | Gaussian, AffineNormal |
| Functions | mvn_to_gaussian, matrix_and_gaussian_to_gaussian, matrix_and_mvn_to_gaussian, gaussian_tensordot, sequential_gaussian_tensordot, sequential_gaussian_filter_sample |
| Dependencies | torch, pyro.distributions.util, pyro.ops.tensor_utils |
Overview
This module provides the core non-normalized Gaussian distribution class and related operations used extensively in Pyro for variable elimination, belief propagation, and Kalman filtering. The Gaussian class uses an information parameterization (natural parameters) with info_vec (precision-weighted mean) and precision matrix, which enables numerically stable operations even with rank-deficient matrices.
The key design principle is that the precision matrix may have zero eigenvalues, in which case it is not invertible and no covariance matrix exists. The information parameterization avoids this issue because it stores the precision directly rather than its inverse.
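To make the rank-deficiency point concrete, here is a minimal sketch in plain PyTorch. It is not Pyro's implementation; the 2-D state, the observed value 1.5, and all variable names are assumptions chosen for illustration. A factor that observes only the first coordinate has a singular precision, yet it combines cleanly with a proper prior by adding natural parameters.

```python
import torch

# Hypothetical factor that observes only the first coordinate of a 2-D state
# with noise variance 0.25. Its precision has a zero eigenvalue, so no
# covariance matrix exists for it.
obs_precision = torch.tensor([[4.0, 0.0],
                              [0.0, 0.0]])
obs_info_vec = torch.tensor([4.0 * 1.5, 0.0])  # precision times observed value 1.5

# A proper prior N(0, I) in the same parameterization.
prior_precision = torch.eye(2)
prior_info_vec = torch.zeros(2)

# Multiplying densities adds natural parameters; the product is full rank.
post_precision = obs_precision + prior_precision
post_info_vec = obs_info_vec + prior_info_vec

# Only now is it safe to recover a mean and covariance.
post_cov = torch.linalg.inv(post_precision)
post_mean = post_cov @ post_info_vec

obs_rank = int(torch.linalg.matrix_rank(obs_precision))
post_rank = int(torch.linalg.matrix_rank(post_precision))
print(obs_rank, post_rank)
print(post_mean)
```

The singular factor never needs a covariance of its own; inversion is deferred until the combined precision is full rank.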
The log density of a Gaussian is:
-0.5 * value.T @ precision @ value + value.T @ info_vec + log_normalizer
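The formula above can be checked numerically. The sketch below uses plain PyTorch rather than the Gaussian class, and the helper name log_density and the random test values are assumptions for illustration. It converts a normalized multivariate normal to information form and compares the result against torch.distributions.MultivariateNormal.log_prob.

```python
import math
import torch

torch.manual_seed(0)
dim = 3

# Build a random positive-definite covariance and a mean (illustrative values).
A = torch.randn(dim, dim, dtype=torch.float64)
cov = A @ A.T + torch.eye(dim, dtype=torch.float64)
loc = torch.randn(dim, dtype=torch.float64)

# Information parameterization of the normalized density N(loc, cov).
precision = torch.linalg.inv(cov)
info_vec = precision @ loc
log_normalizer = (
    -0.5 * dim * math.log(2 * math.pi)
    - 0.5 * torch.logdet(cov)
    - 0.5 * loc @ info_vec
)


def log_density(value):
    # -0.5 * value.T @ precision @ value + value.T @ info_vec + log_normalizer
    return -0.5 * value @ precision @ value + value @ info_vec + log_normalizer


value = torch.randn(dim, dtype=torch.float64)
expected = torch.distributions.MultivariateNormal(loc, cov).log_prob(value)
print(torch.allclose(log_density(value), expected))
```

For a normalized distribution the log normalizer absorbs the usual Gaussian constant together with the quadratic term in the mean; a non-normalized Gaussian simply carries a different scalar there.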
The module also provides AffineNormal, an efficient specialization for conditional diagonal normal distributions where the mean is an affine function of another variable.
Code Reference
Class: Gaussian
A non-normalized Gaussian parameterized by log_normalizer, info_vec, and precision.
Core methods:
- dim(): Returns the event dimension.
- batch_shape: Lazy property computing broadcasted batch shape.
- expand(batch_shape): Expands to a given batch shape.
- reshape(batch_shape): Reshapes to a given batch shape.
- cat(parts, dim): Concatenates Gaussians along a batch dimension.
- event_pad(left, right): Zero-pads the event dimension.
- event_permute(perm): Permutes event dimensions.
- __add__(other): Adds in log-density space. Supports Gaussian, int, float, and Tensor.
- __sub__(other): Subtracts a scalar from the log normalizer.
- log_density(value): Evaluates log density at a point.
- rsample(sample_shape, noise): Draws reparameterized samples via Cholesky solve.
- condition(value): Conditions on a trailing subset of state, preserving density normalization.
- left_condition(value): Conditions on a leading subset of state.
- marginalize(left, right): Marginalizes out variables on either side using Cholesky decomposition.
- event_logsumexp(): Integrates out all latent state, returning a scalar log normalizer.
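Marginalization in information form reduces to a Schur complement on the precision blocks. The following is a sketch in plain PyTorch of integrating out trailing variables with a Cholesky factor; it mirrors the semantics described for marginalize but is not taken from Pyro's source, and the block names (P_aa, P_ab, P_bb, i_a, i_b) are assumptions for illustration.

```python
import math
import torch

torch.manual_seed(0)
n, n_b = 4, 2  # total event dim, number of trailing dims to integrate out
n_a = n - n_b

A = torch.randn(n, n, dtype=torch.float64)
cov = A @ A.T + torch.eye(n, dtype=torch.float64)
loc = torch.randn(n, dtype=torch.float64)

# Information form of N(loc, cov).
precision = torch.linalg.inv(cov)
info_vec = precision @ loc
log_normalizer = (-0.5 * n * math.log(2 * math.pi)
                  - 0.5 * torch.logdet(cov) - 0.5 * loc @ info_vec)

# Split into the kept block a (leading) and the eliminated block b (trailing).
P_aa, P_ab = precision[:n_a, :n_a], precision[:n_a, n_a:]
P_bb = precision[n_a:, n_a:]
i_a, i_b = info_vec[:n_a], info_vec[n_a:]

# Cholesky factor of the eliminated block, reused for solves and log-determinant.
L = torch.linalg.cholesky(P_bb)
P_bb_inv_P_ba = torch.cholesky_solve(P_ab.T.contiguous(), L)
P_bb_inv_i_b = torch.cholesky_solve(i_b.unsqueeze(-1), L).squeeze(-1)

marg_precision = P_aa - P_ab @ P_bb_inv_P_ba
marg_info_vec = i_a - P_ab @ P_bb_inv_i_b
marg_log_normalizer = (log_normalizer
                       + 0.5 * n_b * math.log(2 * math.pi)
                       + 0.5 * i_b @ P_bb_inv_i_b
                       - L.diagonal().log().sum())  # equals -0.5 * logdet(P_bb)

# Compare with the marginal of the original multivariate normal.
value = torch.randn(n_a, dtype=torch.float64)
got = (-0.5 * value @ marg_precision @ value
       + value @ marg_info_vec + marg_log_normalizer)
expected = torch.distributions.MultivariateNormal(
    loc[:n_a], cov[:n_a, :n_a]).log_prob(value)
print(torch.allclose(got, expected))
```

Only the eliminated block needs to be positive definite for this to work, which is why a Gaussian whose full precision is rank-deficient can still be marginalized over a well-conditioned subset.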
Class: AffineNormal
An efficient representation for a conditional diagonal normal Y = X @ matrix + Normal(loc, scale).
Methods:
- condition(value): If conditioning on Y (full), computes precision and info_vec efficiently. Otherwise falls back to full Gaussian.
- left_condition(value): If conditioning on X (full), returns an AffineNormal with zero-dim matrix and shifted loc.
- rsample(sample_shape, noise): Reparameterized sampling (only when matrix has zero rows).
- to_gaussian(): Converts to a full Gaussian representation.
Module-level Functions
- mvn_to_gaussian(mvn): Converts a MultivariateNormal or Independent(Normal) to a Gaussian.
- matrix_and_gaussian_to_gaussian(matrix, y_gaussian): Constructs a conditional Gaussian for p(y|x) where y - x @ matrix ~ y_gaussian.
- matrix_and_mvn_to_gaussian(matrix, mvn): Converts a noisy affine function to a Gaussian. Returns AffineNormal for diagonal MVN.
- gaussian_tensordot(x, y, dims): Computes (x @ y)(a,c) = log(integral(exp(x(a,b) + y(b,c)), b)) using Cholesky-based marginalization.
- sequential_gaussian_tensordot(gaussian): Performs parallel-scan sequential tensor contraction along the rightmost batch dimension, computing x[0] @ x[1] @ ... @ x[T-1].
- sequential_gaussian_filter_sample(init, trans, sample_shape, noise): Draws reparameterized samples from a Markov product via parallel-scan forward-filter backward-sample.
I/O Contract
| Function/Method | Input | Output |
|---|---|---|
| Gaussian.__init__ | log_normalizer: Tensor, info_vec: Tensor(..., D), precision: Tensor(..., D, D) | Gaussian instance |
| condition(value) | value: Tensor(..., K) | Gaussian with dim = D - K |
| marginalize(left, right) | Integers specifying variables to remove | Gaussian with reduced dim |
| event_logsumexp() | (none) | Tensor (scalar log normalizer per batch) |
| rsample(sample_shape) | sample_shape: torch.Size | Tensor(sample_shape + batch_shape + (D,)) |
| gaussian_tensordot(x, y, dims) | x: Gaussian, y: Gaussian, dims: int | Gaussian with dim = x.dim() + y.dim() - 2*dims |
| sequential_gaussian_filter_sample | init: Gaussian, trans: Gaussian, sample_shape, noise | Tensor(sample_shape + batch_shape + (duration, state_dim)) |
Usage Examples
import torch
from pyro.ops.gaussian import Gaussian, mvn_to_gaussian, gaussian_tensordot
# Create a Gaussian from a MultivariateNormal
mvn = torch.distributions.MultivariateNormal(
    torch.zeros(3), torch.eye(3)
)
g = mvn_to_gaussian(mvn)
print(g.dim()) # 3
# Evaluate log density
value = torch.randn(3)
log_p = g.log_density(value)
# Condition on last 2 variables
g_cond = g.condition(torch.randn(2))
print(g_cond.dim()) # 1
# Marginalize out the first variable
g_marg = g.marginalize(left=1)
print(g_marg.dim()) # 2
# Draw reparameterized samples
samples = g.rsample(torch.Size([100]))
print(samples.shape) # torch.Size([100, 3])
# Tensor contraction of two Gaussians
g1 = mvn_to_gaussian(torch.distributions.MultivariateNormal(
    torch.zeros(4), torch.eye(4)
))
g2 = mvn_to_gaussian(torch.distributions.MultivariateNormal(
    torch.zeros(4), torch.eye(4)
))
g12 = gaussian_tensordot(g1, g2, dims=2)
print(g12.dim()) # 4 (= 4 + 4 - 2*2)
Related Pages
- Pyro_ppl_Pyro_GammaGaussian -- Gamma-Gaussian with mixing variable
- Pyro_ppl_Pyro_TensorUtils -- Low-level tensor utilities (Cholesky, triangular solve)
- Pyro_ppl_Pyro_MaternKernel -- State space GP models using Gaussian operations