Implementation:Pyro ppl Pyro Gaussian

From Leeroopedia


| Property | Value |
|---|---|
| Module | pyro.ops.gaussian |
| Source | pyro/ops/gaussian.py |
| Lines | 715 |
| Classes | Gaussian, AffineNormal |
| Functions | mvn_to_gaussian, matrix_and_gaussian_to_gaussian, matrix_and_mvn_to_gaussian, gaussian_tensordot, sequential_gaussian_tensordot, sequential_gaussian_filter_sample |
| Dependencies | torch, pyro.distributions.util, pyro.ops.tensor_utils |

Overview

This module provides the core non-normalized Gaussian distribution class and related operations used extensively in Pyro for variable elimination, belief propagation, and Kalman filtering. The Gaussian class uses an information parameterization (natural parameters) with info_vec (precision-weighted mean) and precision matrix, which enables numerically stable operations even with rank-deficient matrices.

The key property is that the precision matrix may have zero eigenvalues, in which case the covariance matrix (its inverse) does not exist and cannot be used directly. The information parameterization sidesteps this because it never requires the covariance.
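To illustrate why zero eigenvalues are harmless in this parameterization, here is a minimal sketch in plain PyTorch. It does not call pyro.ops.gaussian, and the variable names are illustrative assumptions. Each factor constrains only one coordinate, so its precision is singular, yet the two factors still add in log-density space to a full-rank result.

```python
import torch

# Factor 1 constrains only x[0] (pulls x[0] toward 1 with precision 1).
precision_1 = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
info_vec_1 = torch.tensor([1.0, 0.0])

# Factor 2 constrains only x[1] (pulls x[1] toward 2 with precision 4).
precision_2 = torch.tensor([[0.0, 0.0], [0.0, 4.0]])
info_vec_2 = torch.tensor([0.0, 8.0])

# Each precision has rank 1, so neither factor has a covariance matrix.
rank_1 = int(torch.linalg.matrix_rank(precision_1))

# Adding log densities adds the natural parameters elementwise.
precision = precision_1 + precision_2
info_vec = info_vec_1 + info_vec_2

# The sum is full rank, so the mean solves precision @ mean = info_vec.
mean = torch.linalg.solve(precision, info_vec)
print(rank_1, mean)
```

Only the combined factor admits a mean and covariance; the individual factors are still valid log-density terms.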

The log density of a Gaussian is:

-0.5 * value.T @ precision @ value + value.T @ info_vec + log_normalizer
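As a sanity check on this formula, the sketch below converts an ordinary normalized Gaussian into information form and compares the result against torch.distributions.MultivariateNormal. The helper names to_information_form and log_density are hypothetical and written for this illustration only; they are not part of the Pyro API.

```python
import math
import torch
from torch.distributions import MultivariateNormal


def to_information_form(loc, cov):
    """Convert N(loc, cov) to (log_normalizer, info_vec, precision)."""
    precision = torch.linalg.inv(cov)
    info_vec = precision @ loc
    n = loc.size(-1)
    log_normalizer = (
        -0.5 * n * math.log(2 * math.pi)
        - 0.5 * torch.dot(info_vec, loc)
        - 0.5 * torch.logdet(cov)
    )
    return log_normalizer, info_vec, precision


def log_density(value, log_normalizer, info_vec, precision):
    """Evaluate -0.5 * v.T @ P @ v + v.T @ info_vec + log_normalizer."""
    quad = torch.dot(value, precision @ value)
    return -0.5 * quad + torch.dot(value, info_vec) + log_normalizer


torch.manual_seed(0)
loc = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)
a = torch.randn(3, 3, dtype=torch.float64)
cov = a @ a.T + torch.eye(3, dtype=torch.float64)  # symmetric positive definite
value = torch.randn(3, dtype=torch.float64)

actual = log_density(value, *to_information_form(loc, cov))
expected = MultivariateNormal(loc, covariance_matrix=cov).log_prob(value)
print(torch.allclose(actual, expected))
```

For a normalized Gaussian the log_normalizer absorbs the usual constant, the log-determinant, and the quadratic term in the mean, which is why the remaining expression needs only the precision and info_vec.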

The module also provides AffineNormal, an efficient specialization for conditional diagonal normal distributions where the mean is an affine function of another variable.

Code Reference

Class: Gaussian

A non-normalized Gaussian parameterized by log_normalizer, info_vec, and precision.

Core methods:

  • dim(): Returns the event dimension.
  • batch_shape: Lazy property computing broadcasted batch shape.
  • expand(batch_shape): Expands to a given batch shape.
  • reshape(batch_shape): Reshapes to a given batch shape.
  • cat(parts, dim): Concatenates Gaussians along a batch dimension.
  • event_pad(left, right): Zero-pads the event dimension.
  • event_permute(perm): Permutes event dimensions.
  • __add__(other): Adds in log-density space. Supports Gaussian, int, float, and Tensor.
  • __sub__(other): Subtracts a scalar from the log normalizer.
  • log_density(value): Evaluates log density at a point.
  • rsample(sample_shape, noise): Draws reparameterized samples via Cholesky solve.
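  • condition(value): Conditions on a trailing subset of the state, so that g.condition(y).dim() == g.dim() - y.size(-1). Because the Gaussian is non-normalized, the density of value is kept in the result: for x split into left and right parts, g.log_density(x) == g.condition(right).log_density(left).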
  • left_condition(value): Conditions on a leading subset of state.
  • marginalize(left, right): Marginalizes out variables on either side using Cholesky decomposition.
  • event_logsumexp(): Integrates out all latent state over the event dimension, returning a tensor of log normalizers with one entry per batch element.
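The behavior of condition can be pictured as partially binding the trailing arguments of the log density. The following is a sketch of that algebra in plain PyTorch for the unbatched case, assuming a symmetric precision matrix. The function condition_on_trailing is a hypothetical helper for illustration and is not claimed to match Pyro's implementation line for line.

```python
import torch


def log_density(value, log_normalizer, info_vec, precision):
    quad = torch.dot(value, precision @ value)
    return -0.5 * quad + torch.dot(value, info_vec) + log_normalizer


def condition_on_trailing(value, log_normalizer, info_vec, precision):
    """Bind the trailing value.size(-1) coordinates; return parameters over the rest."""
    n = info_vec.size(-1) - value.size(-1)
    P_aa = precision[:n, :n]
    P_ab = precision[:n, n:]
    P_bb = precision[n:, n:]
    # The cross term -a.T @ P_ab @ b becomes a linear term in the remaining variables.
    new_info_vec = info_vec[:n] - P_ab @ value
    # Terms that depend only on the bound value move into the log normalizer.
    new_log_normalizer = (
        log_normalizer
        + torch.dot(value, info_vec[n:])
        - 0.5 * torch.dot(value, P_bb @ value)
    )
    return new_log_normalizer, new_info_vec, P_aa


torch.manual_seed(0)
a = torch.randn(3, 3, dtype=torch.float64)
precision = a @ a.T + torch.eye(3, dtype=torch.float64)
info_vec = torch.randn(3, dtype=torch.float64)
log_normalizer = torch.tensor(0.7, dtype=torch.float64)

x = torch.randn(3, dtype=torch.float64)
left, right = x[:1], x[1:]
joint = log_density(x, log_normalizer, info_vec, precision)
bound = log_density(
    left, *condition_on_trailing(right, log_normalizer, info_vec, precision)
)
print(torch.allclose(joint, bound))
```

Nothing is renormalized here, which is the sense in which the density of the conditioned value stays in the result.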

Class: AffineNormal

An efficient representation for a conditional diagonal normal Y = X @ matrix + Normal(loc, scale).

Methods:

  • condition(value): If value spans all of Y, computes precision and info_vec directly from matrix, loc, and scale without building the full Gaussian. Otherwise falls back to the full Gaussian.
  • left_condition(value): If value spans all of X, returns an AffineNormal whose matrix has zero rows and whose loc is shifted by value @ matrix.
  • rsample(sample_shape, noise): Reparameterized sampling (only when matrix has zero rows).
  • to_gaussian(): Converts to a full Gaussian representation.
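To make the efficient conditioning on Y concrete, the sketch below derives the likelihood of X in information form from the diagonal structure and checks it against an ordinary diagonal normal log probability. It is written in plain PyTorch under the model Y = X @ matrix + Normal(loc, scale). The variable names are assumptions, and this shows the algebra rather than Pyro's exact code.

```python
import math
import torch
from torch.distributions import Independent, Normal

torch.manual_seed(0)
x_dim, y_dim = 2, 3
f64 = torch.float64
matrix = torch.randn(x_dim, y_dim, dtype=f64)
loc = torch.randn(y_dim, dtype=f64)
scale = torch.rand(y_dim, dtype=f64) + 0.5
x = torch.randn(x_dim, dtype=f64)
y = torch.randn(y_dim, dtype=f64)

# Whiten by the diagonal scale; no y_dim x y_dim matrix is ever factorized.
weights = matrix / scale          # shape (x_dim, y_dim)
delta = (y - loc) / scale         # shape (y_dim,)

# Information-form parameters of the likelihood of x given the observed y.
precision = weights @ weights.T   # shape (x_dim, x_dim)
info_vec = weights @ delta        # shape (x_dim,)
log_normalizer = (
    -0.5 * y_dim * math.log(2 * math.pi)
    - scale.log().sum()
    - 0.5 * delta.pow(2).sum()
)

actual = -0.5 * torch.dot(x, precision @ x) + torch.dot(x, info_vec) + log_normalizer
expected = Independent(Normal(x @ matrix + loc, scale), 1).log_prob(y)
print(torch.allclose(actual, expected))
```

The resulting precision has rank at most min(x_dim, y_dim), so it can be singular when y_dim is smaller than x_dim, which the information form tolerates.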

Module-level Functions

  • mvn_to_gaussian(mvn): Converts a MultivariateNormal or Independent(Normal) to a Gaussian.
  • matrix_and_gaussian_to_gaussian(matrix, y_gaussian): Constructs a conditional Gaussian for p(y|x) where y - x @ matrix ~ y_gaussian.
  • matrix_and_mvn_to_gaussian(matrix, mvn): Converts a noisy affine function to a Gaussian. Returns AffineNormal for diagonal MVN.
  • gaussian_tensordot(x, y, dims): Computes (x @ y)(a,c) = log(integral(exp(x(a,b) + y(b,c)), b)) using Cholesky-based marginalization.
  • sequential_gaussian_tensordot(gaussian): Performs parallel-scan sequential tensor contraction along the rightmost batch dimension, computing x[0] @ x[1] @ ... @ x[T-1].
  • sequential_gaussian_filter_sample(init, trans, sample_shape, noise): Draws reparameterized samples from a Markov product via parallel-scan forward-filter backward-sample.
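The integral in gaussian_tensordot can be sketched for the unbatched case as follows: embed both factors in the joint space (a, b, c), add them, then integrate out b with a Cholesky factor of its precision block. The function tensordot_sketch and the scalar model below (a ~ N(0, 1), b = F*a + noise, c = H*b + noise, with F, q, H, r chosen arbitrarily) are assumptions made for this illustration and are not taken from Pyro.

```python
import math
import torch
from torch.distributions import MultivariateNormal


def tensordot_sketch(x, y, nb):
    """Contract x over (a, b) with y over (b, c) by integrating out the nb shared variables.

    x and y are unbatched (log_normalizer, info_vec, precision) triples.
    """
    cx, vx, Px = x
    cy, vy, Py = y
    na, nc = vx.size(-1) - nb, vy.size(-1) - nb
    n = na + nb + nc
    # Embed both factors in the joint space ordered (a, b, c) and add log densities.
    P = torch.zeros(n, n, dtype=Px.dtype)
    v = torch.zeros(n, dtype=vx.dtype)
    P[: na + nb, : na + nb] += Px
    P[na:, na:] += Py
    v[: na + nb] += vx
    v[na:] += vy
    keep = torch.tensor(list(range(na)) + list(range(na + nb, n)))
    drop = torch.arange(na, na + nb)
    P_kk, P_bk, P_bb = P[keep][:, keep], P[drop][:, keep], P[drop][:, drop]
    # Integrate out b; this requires the combined b-block to be positive definite.
    L = torch.linalg.cholesky(P_bb)
    LinvB = torch.linalg.solve_triangular(L, P_bk, upper=False)
    Linvb = torch.linalg.solve_triangular(L, v[drop].unsqueeze(-1), upper=False)
    precision = P_kk - LinvB.T @ LinvB
    info_vec = v[keep] - (LinvB.T @ Linvb).squeeze(-1)
    log_normalizer = (
        cx + cy + 0.5 * nb * math.log(2 * math.pi)
        + 0.5 * Linvb.pow(2).sum() - L.diagonal().log().sum()
    )
    return log_normalizer, info_vec, precision


F, q, H, r = 0.5, 0.3, 2.0, 0.2
f64 = torch.float64
# x(a, b) = log p(a) + log p(b | a); y(b, c) = log p(c | b).
x = (
    -0.5 * math.log(2 * math.pi) - 0.5 * math.log(2 * math.pi * q),
    torch.zeros(2, dtype=f64),
    torch.tensor([[1 + F * F / q, -F / q], [-F / q, 1 / q]], dtype=f64),
)
y = (
    -0.5 * math.log(2 * math.pi * r),
    torch.zeros(2, dtype=f64),
    torch.tensor([[H * H / r, -H / r], [-H / r, 1 / r]], dtype=f64),
)
log_normalizer, info_vec, precision = tensordot_sketch(x, y, nb=1)

# The contraction should equal log p(a, c), a joint Gaussian with known covariance.
cov = torch.tensor([[1.0, H * F], [H * F, H * H * F * F + H * H * q + r]], dtype=f64)
z = torch.tensor([0.4, -1.3], dtype=f64)
actual = -0.5 * torch.dot(z, precision @ z) + torch.dot(z, info_vec) + log_normalizer
expected = MultivariateNormal(torch.zeros(2, dtype=f64), covariance_matrix=cov).log_prob(z)
print(torch.allclose(actual, expected))
```

Chaining this contraction across time steps is the operation that sequential_gaussian_tensordot is described as performing with a parallel scan.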

I/O Contract

| Function/Method | Input | Output |
|---|---|---|
| Gaussian.__init__ | log_normalizer: Tensor, info_vec: Tensor(..., D), precision: Tensor(..., D, D) | Gaussian instance |
| condition(value) | value: Tensor(..., K) | Gaussian with dim = D - K |
| marginalize(left, right) | Integers specifying variables to remove | Gaussian with reduced dim |
| event_logsumexp() | (none) | Tensor (scalar log normalizer per batch) |
| rsample(sample_shape) | sample_shape: torch.Size | Tensor(sample_shape + batch_shape + (D,)) |
| gaussian_tensordot(x, y, dims) | Two Gaussians, dims: int | Gaussian with dim = x.dim() + y.dim() - 2*dims |
| sequential_gaussian_filter_sample | init: Gaussian, trans: Gaussian, sample_shape, noise | Tensor(sample_shape + batch_shape + (duration, state_dim)) |

Usage Examples

import torch
from pyro.ops.gaussian import Gaussian, mvn_to_gaussian, gaussian_tensordot

# Create a Gaussian from a MultivariateNormal
mvn = torch.distributions.MultivariateNormal(
    torch.zeros(3), torch.eye(3)
)
g = mvn_to_gaussian(mvn)
print(g.dim())  # 3

# Evaluate log density
value = torch.randn(3)
log_p = g.log_density(value)

# Condition on last 2 variables
g_cond = g.condition(torch.randn(2))
print(g_cond.dim())  # 1

# Marginalize out the first variable
g_marg = g.marginalize(left=1)
print(g_marg.dim())  # 2

# Draw reparameterized samples
samples = g.rsample(torch.Size([100]))
print(samples.shape)  # torch.Size([100, 3])

# Tensor contraction of two Gaussians
g1 = mvn_to_gaussian(torch.distributions.MultivariateNormal(
    torch.zeros(4), torch.eye(4)
))
g2 = mvn_to_gaussian(torch.distributions.MultivariateNormal(
    torch.zeros(4), torch.eye(4)
))
g12 = gaussian_tensordot(g1, g2, dims=2)
print(g12.dim())  # 4 (= 4 + 4 - 2*2)
