
Implementation:Pyro ppl Pyro OMTMultivariateNormal

From Leeroopedia


Knowledge Sources
Domains Probability_Distributions
Last Updated 2026-02-09 09:00 GMT

Overview

Description

OMTMultivariateNormal is a specialized multivariate normal (Gaussian) distribution that uses Optimal Mass Transport (OMT) gradients with respect to both the location (loc) and the Cholesky factor (scale_tril) parameters. It extends PyTorch's MultivariateNormal distribution, overriding only the rsample method to route through a custom autograd function.

The key difference lies in the gradient computation. Both the standard reparameterization trick and the OMT approach produce unbiased gradient estimates, but the OMT estimator is generally expected to have lower variance. The price is higher computational complexity: computing the gradient with respect to the Cholesky factor costs O(D^3), where D is the dimensionality.
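The pattern of routing rsample through a custom autograd function can be illustrated with torch.autograd.Function. This is a minimal sketch that assumes only PyTorch. The names _ReparamMVNSample and reparam_rsample are hypothetical, and the backward pass returns the standard reparameterization gradients, not the OMT gradients. It is the baseline that the OMT estimator replaces, not Pyro's implementation.

```python
import torch


class _ReparamMVNSample(torch.autograd.Function):
    """Hypothetical custom Function for z = loc + eps @ scale_tril.T.

    The backward pass returns the *standard* reparameterization
    gradients, not the OMT gradients used by OMTMultivariateNormal.
    """

    @staticmethod
    def forward(ctx, loc, scale_tril, shape):
        # Draw white noise eps ~ N(0, I) with the requested output shape.
        white = torch.randn(shape, dtype=loc.dtype, device=loc.device)
        ctx.save_for_backward(white)
        return loc + white @ scale_tril.T

    @staticmethod
    def backward(ctx, grad_output):
        (white,) = ctx.saved_tensors
        dim = grad_output.shape[-1]
        # Flatten all leading (sample) dimensions into one.
        g = grad_output.reshape(-1, dim)
        eps = white.reshape(-1, dim)
        # d loss / d loc: sum the incoming gradient over samples.
        loc_grad = g.sum(0)
        # d z_a / d L_ab = eps_b, summed over samples; keep the lower triangle.
        scale_tril_grad = (g.T @ eps).tril()
        # The non-tensor `shape` argument gets no gradient.
        return loc_grad, scale_tril_grad, None


def reparam_rsample(loc, scale_tril, sample_shape=torch.Size()):
    """Hypothetical rsample: the output shape is sample_shape + loc.shape."""
    shape = torch.Size(sample_shape) + loc.shape
    return _ReparamMVNSample.apply(loc, scale_tril, shape)
```

Only the backward method would change if the OMT correction were swapped in. The forward sample stays the same, so the two estimators differ only in how gradients are assigned to loc and scale_tril.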

The custom backward pass in _OMTMVNSample computes:

  1. The location gradient, obtained by summing the incoming gradient over all leftmost (sample) dimensions.
  2. A gradient for the Cholesky factor L. It involves computing the inverse of L, performing an SVD of the inverse covariance matrix (L L^T)^{-1}, and constructing a correction term Y from outer products in the SVD basis. The final Cholesky gradient is restricted to its lower triangle, matching the structure of L.
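The linear-algebra ingredients named in the steps listed above can be computed with standard PyTorch calls. This is a minimal sketch assuming a lower-triangular scale_tril of shape (D, D) and an incoming gradient of shape (..., D). The helper name omt_building_blocks is hypothetical, and the sketch deliberately stops before forming the correction term Y, whose exact construction is not given on this page.

```python
import torch


def omt_building_blocks(grad_output, scale_tril):
    """Hypothetical helper returning intermediate quantities only.

    It does NOT assemble the correction term Y or the final OMT gradient.
    """
    dim = scale_tril.shape[-1]
    # Step 1: location gradient, summed over all leftmost (sample) dimensions.
    loc_grad = grad_output.reshape(-1, dim).sum(0)

    # Inverse of the lower-triangular Cholesky factor via a triangular solve.
    identity = torch.eye(dim, dtype=scale_tril.dtype, device=scale_tril.device)
    L_inv = torch.linalg.solve_triangular(scale_tril, identity, upper=False)

    # Sigma = L L^T, hence Sigma^{-1} = L^{-T} L^{-1}.
    sigma_inv = L_inv.T @ L_inv

    # SVD of the symmetric positive-definite Sigma^{-1}.
    U, S, Vh = torch.linalg.svd(sigma_inv)
    return loc_grad, L_inv, sigma_inv, U, S, Vh
```

Because Sigma^{-1} is symmetric positive definite, its SVD coincides (up to round-off) with its eigendecomposition, so U supplies an orthonormal basis in which the covariance structure is diagonal.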

The distribution is restricted to 1-dimensional loc and 2-dimensional scale_tril tensors (no batching support).

Usage

OMTMultivariateNormal is useful in stochastic variational inference when one desires lower-variance gradient estimates for the covariance parameters of a multivariate normal variational distribution. The reduced gradient variance can lead to faster convergence. However, the O(D^3) cost of the backward pass makes each iteration more expensive than with the standard reparameterization trick, and this overhead grows with the dimensionality D.

Code Reference

Source Location

Signature

class OMTMultivariateNormal(MultivariateNormal):
    def __init__(self, loc, scale_tril)

Import

from pyro.distributions import OMTMultivariateNormal

I/O Contract

Inputs

Parameter | Type | Description
loc | torch.Tensor | A 1-dimensional tensor of shape (D,) specifying the mean of the distribution.
scale_tril | torch.Tensor | A 2-dimensional lower-triangular tensor of shape (D, D) specifying the Cholesky factor of the covariance matrix.
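When starting from a covariance matrix rather than its Cholesky factor, a valid scale_tril can be obtained with torch.linalg.cholesky. This is a sketch; the covariance values below are illustrative assumptions.

```python
import torch

# Illustrative covariance matrix (symmetric positive definite), D = 3.
cov = torch.tensor([[4.0, 1.2, 0.4],
                    [1.2, 2.0, 0.3],
                    [0.4, 0.3, 1.0]])

loc = torch.zeros(3)                     # shape (D,): 1-dimensional, as required
scale_tril = torch.linalg.cholesky(cov)  # shape (D, D): lower-triangular, positive diagonal

# Sanity checks matching the input contract.
assert loc.dim() == 1 and scale_tril.dim() == 2
assert torch.equal(scale_tril, scale_tril.tril())
assert torch.allclose(scale_tril @ scale_tril.T, cov, atol=1e-5)

# With Pyro installed, these tensors can then be passed directly:
# OMTMultivariateNormal(loc, scale_tril)
```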

Outputs

Method | Return Type | Description
rsample(sample_shape) | torch.Tensor | Returns a reparameterized sample of shape sample_shape + (D,) with OMT gradients flowing through both loc and scale_tril.
log_prob(value) | torch.Tensor | Returns the log probability density (inherited from MultivariateNormal).
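Because log_prob is inherited, it evaluates the usual Gaussian log-density. The sketch below writes that density in terms of the Cholesky factor L (with Sigma = L L^T) and compares it with torch.distributions.MultivariateNormal. The helper mvn_log_prob is hypothetical and not part of Pyro.

```python
import math

import torch
from torch.distributions import MultivariateNormal


def mvn_log_prob(value, loc, scale_tril):
    """Hypothetical helper: Gaussian log-density via the Cholesky factor L."""
    dim = loc.shape[-1]
    # Stack all points as columns: shape (D, N).
    diff = (value - loc).reshape(-1, dim).T
    # Solve L w = (x - mu); the Mahalanobis term is then ||w||^2.
    w = torch.linalg.solve_triangular(scale_tril, diff, upper=False)
    maha = (w ** 2).sum(0).reshape(value.shape[:-1])
    # log|Sigma| = 2 * sum(log(diag(L))).
    log_det = 2.0 * torch.log(torch.diagonal(scale_tril)).sum()
    return -0.5 * (dim * math.log(2 * math.pi) + log_det + maha)


# Compare with PyTorch's MultivariateNormal on a few random points.
torch.manual_seed(0)
loc = torch.tensor([0.5, -1.0, 2.0], dtype=torch.float64)
L = torch.tensor([[1.0, 0.0, 0.0],
                  [0.5, 2.0, 0.0],
                  [0.1, -0.3, 1.5]], dtype=torch.float64)
x = torch.randn(5, 3, dtype=torch.float64)
reference = MultivariateNormal(loc, scale_tril=L).log_prob(x)
print(torch.allclose(mvn_log_prob(x, loc, L), reference))
```

Working with L directly avoids forming or inverting Sigma: one triangular solve gives the Mahalanobis term, and the log-determinant comes from the diagonal of L.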

Usage Examples

import torch
from pyro.distributions import OMTMultivariateNormal

# Create a 3-dimensional OMT multivariate normal
loc = torch.zeros(3, requires_grad=True)
scale_tril = torch.eye(3, requires_grad=True)

mvn = OMTMultivariateNormal(loc, scale_tril)

# Draw a single reparameterized sample; gradients flow through the OMT backward pass
sample = mvn.rsample()
print(sample.shape)  # torch.Size([3])

# Draw a batch of samples: the shape is sample_shape + (D,)
samples = mvn.rsample(torch.Size([100]))
print(samples.shape)  # torch.Size([100, 3])

# Compute a loss and backpropagate
loss = sample.sum()
loss.backward()

# Both parameters now carry OMT gradients; scale_tril.grad is lower-triangular
print(loc.grad)  # tensor([1., 1., 1.])
print(scale_tril.grad)

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Using OMTMultivariateNormal as a variational distribution (guide)
def model():
    pyro.sample("z", dist.MultivariateNormal(torch.zeros(3), torch.eye(3)))

def guide():
    loc = pyro.param("loc", torch.zeros(3))
    scale_tril = pyro.param("scale_tril", torch.eye(3),
                            constraint=dist.constraints.lower_cholesky)
    pyro.sample("z", dist.OMTMultivariateNormal(loc, scale_tril))

# SVI: ELBO gradients with respect to loc and scale_tril use the OMT estimator
pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(1000):
    elbo_loss = svi.step()
