Implementation:Pyro ppl Pyro OMTMultivariateNormal
| Knowledge Sources | |
|---|---|
| Domains | Probability_Distributions |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Description
OMTMultivariateNormal is a multivariate normal (Gaussian) distribution that uses Optimal Mass Transport (OMT) gradients with respect to both the location (loc) and the Cholesky factor (scale_tril) parameters. It subclasses Pyro's MultivariateNormal, which wraps PyTorch's torch.distributions.MultivariateNormal. Apart from input checks in its constructor, it overrides only the rsample method, which routes sampling through a custom autograd function.
The distinguishing feature is the gradient computation. Both the standard reparameterization trick and the OMT approach produce unbiased gradient estimates, but the OMT gradients are generally expected to have lower variance, at the cost of more computation: the gradient with respect to the Cholesky factor costs O(D^3), where D is the dimensionality.
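To show where a hand-written backward plugs into sampling, here is a minimal sketch using torch.autograd.Function. The class name ReparamMVNSample is hypothetical, and its backward deliberately implements the ordinary reparameterization gradient, not the OMT gradient; Pyro's _OMTMVNSample follows the same pattern but replaces the Cholesky-factor gradient with the OMT correction. The check at the end compares the custom backward against plain autograd using identical white noise.

```python
import torch
from torch.autograd import Function


class ReparamMVNSample(Function):
    """Hypothetical skeleton: the same forward pass as a reparameterized MVN
    sample, with a hand-written backward. The backward below is the standard
    reparameterization gradient, NOT Pyro's OMT gradient."""

    @staticmethod
    def forward(ctx, loc, scale_tril, shape):
        white = torch.randn(shape, dtype=loc.dtype, device=loc.device)
        ctx.save_for_backward(white)
        return loc + white @ scale_tril.T  # z = loc + L @ eps, per sample

    @staticmethod
    def backward(ctx, grad_output):
        (white,) = ctx.saved_tensors
        # Flatten the leftmost (sample) dimensions; the event dimension is last.
        g = grad_output.reshape(-1, grad_output.shape[-1])
        eps = white.reshape(-1, white.shape[-1])
        loc_grad = g.sum(0)             # sum over the leftmost dimensions
        L_grad = torch.tril(g.T @ eps)  # dLoss/dL[a, b] = sum_n g[n, a] * eps[n, b]
        return loc_grad, L_grad, None   # `shape` is not differentiable


torch.manual_seed(0)
L0 = torch.tensor([[1.0, 0.0, 0.0], [0.3, 0.8, 0.0], [-0.2, 0.1, 0.5]])
w = torch.randn(4, 3)  # fixed upstream gradient dLoss/dz

# Custom Function.
loc = torch.zeros(3, requires_grad=True)
L = L0.clone().requires_grad_(True)
torch.manual_seed(1)
z = ReparamMVNSample.apply(loc, L, torch.Size([4, 3]))
(z * w).sum().backward()

# Plain autograd on the same expression, reusing the same random stream.
loc_ref = torch.zeros(3, requires_grad=True)
L_ref = L0.clone().requires_grad_(True)
torch.manual_seed(1)
z_ref = loc_ref + torch.randn(4, 3) @ L_ref.T
(z_ref * w).sum().backward()

print(torch.allclose(L.grad, torch.tril(L_ref.grad)))
```

Plain autograd returns a full (D, D) gradient for L_ref, so only its lower triangle is comparable with the custom backward, which projects onto the lower triangle just as the OMT backward does.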
The custom backward pass in _OMTMVNSample computes:
- The location gradient, obtained by summing the upstream gradient over the leftmost (sample) dimensions, exactly as in the standard reparameterization trick.
- A gradient for the Cholesky factor L. This involves computing the inverse of L, performing an SVD of the inverse covariance matrix, and constructing a correction term Y from outer products in the SVD basis. The final Cholesky gradient is restricted to its lower-triangular part.
The distribution is restricted to 1-dimensional loc and 2-dimensional scale_tril tensors (no batching support).
Usage
OMTMultivariateNormal is useful in stochastic variational inference when one desires lower-variance gradient estimates for the covariance parameters of a multivariate normal variational distribution. The reduced gradient variance can lead to faster convergence, though the O(D^3) cost makes it more expensive per iteration than the standard reparameterization trick for high-dimensional problems.
Code Reference
Source Location
- File: pyro/distributions/omt_mvn.py
- Repository: pyro-ppl/pyro
Signature
```python
class OMTMultivariateNormal(MultivariateNormal):
    def __init__(self, loc, scale_tril): ...
    def rsample(self, sample_shape=torch.Size()): ...
```
Import
```python
from pyro.distributions import OMTMultivariateNormal
```
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| loc | torch.Tensor | A 1-dimensional tensor of shape (D,) specifying the mean of the distribution. |
| scale_tril | torch.Tensor | A 2-dimensional lower-triangular tensor of shape (D, D) specifying the Cholesky factor of the covariance matrix. |
Outputs
| Method | Return Type | Description |
|---|---|---|
| rsample(sample_shape=torch.Size()) | torch.Tensor | Returns a reparameterized sample of shape sample_shape + (D,); gradients flow to both loc and scale_tril through the OMT backward pass. |
| log_prob(value) | torch.Tensor | Returns the log probability density of value, with shape value.shape[:-1] (inherited from MultivariateNormal). |
Usage Examples
```python
import torch
from pyro.distributions import OMTMultivariateNormal

# Create a 3-dimensional OMT multivariate normal
loc = torch.zeros(3, requires_grad=True)
scale_tril = torch.eye(3, requires_grad=True)
omt = OMTMultivariateNormal(loc, scale_tril)

# Draw a single reparameterized sample; its backward pass uses OMT gradients
sample = omt.rsample()
print(sample.shape)  # torch.Size([3])

# Compute a loss and backpropagate
loss = sample.sum()
loss.backward()

# Gradients reach both parameters. Lower variance is a property of the
# estimator over many draws, not something visible in a single sample.
print(loc.grad)  # tensor([1., 1., 1.]) because d(sum)/d(loc) is all ones
print(scale_tril.grad)
```
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Using OMTMultivariateNormal as a variational distribution (guide)
def model():
    pyro.sample("z", dist.MultivariateNormal(torch.zeros(3), torch.eye(3)))

def guide():
    loc = pyro.param("loc", torch.zeros(3))
    scale_tril = pyro.param("scale_tril", torch.eye(3),
                            constraint=constraints.lower_cholesky)
    pyro.sample("z", dist.OMTMultivariateNormal(loc, scale_tril))

# SVI: ELBO gradients w.r.t. loc and scale_tril flow through the OMT backward pass
pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(100):
    loss = svi.step()
```
Related Pages
- Pyro_ppl_Pyro_MultivariateStudentT - Multivariate Student's t-distribution, another multivariate distribution with Cholesky parameterization
- Pyro_ppl_Pyro_GaussianScaleMixture - Another distribution with custom pathwise derivative implementation
- Pyro_ppl_Pyro_LKJ - LKJ distribution for correlation matrices, often used with multivariate normals