
Implementation:Pyro ppl Pyro AutoMultivariateNormal Guide

From Leeroopedia


Metadata

Field     Value
Sources   Repo: Pyro
Domains   Bayesian_Inference, Variational_Inference
Updated   2026-02-09
Type      Class

Overview

AutoMultivariateNormal is an automatic guide that constructs a multivariate Normal variational distribution with a full covariance matrix over the entire flattened latent space. It uses a Cholesky parameterization to maintain positive definiteness and captures posterior correlations between all pairs of latent variables.

Description

The AutoMultivariateNormal class inherits from AutoContinuous and provides a full-rank Gaussian variational family. It flattens all continuous latent variables in the model into a single vector and applies a multivariate Normal distribution with Cholesky-factored covariance.

When first called, the guide traces the model to discover latent sites, transforms them to unconstrained space, and flattens them into a single latent vector of dimension d. It then creates three sets of learnable parameters:

  • loc: A d-dimensional mean vector, initialized via init_loc_fn (default: init_to_median).
  • scale: A d-dimensional positive vector of per-variable scale factors, initialized to init_scale (default: 0.1), constrained via softplus_positive.
  • scale_tril: A d x d unit lower triangular matrix (ones on the diagonal), constrained via unit_lower_cholesky.
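As a concrete illustration of the flattening step, here is a minimal sketch in plain Python (not Pyro's API; the site names and shapes are hypothetical examples) of how the latent dimension d and the parameter shapes are determined:

```python
from math import prod

# Hypothetical unconstrained shapes of a model's continuous latent sites
site_shapes = {"loc": (), "scale": (), "weights": (3,)}

# Flattened latent dimension: sum of each site's element count
d = sum(prod(shape) for shape in site_shapes.values())  # 1 + 1 + 3 = 5

loc_shape = (d,)            # mean vector
scale_shape = (d,)          # per-dimension positive scales
scale_tril_shape = (d, d)   # unit lower-triangular factor
print(d, loc_shape, scale_shape, scale_tril_shape)
```

Note that scalar sites contribute one dimension each, so even a small model with a handful of vector-valued sites can produce a sizeable d x d factor.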

The effective Cholesky factor of the covariance matrix is computed as:

L = diag(scale) * scale_tril

This separates the marginal scales from the correlation structure, yielding the covariance:

Sigma = L L^T

During the forward pass, sampling uses the reparameterization trick via the LowerCholeskyAffine transform. The flattened samples are then unflattened and transformed back to the constrained space of each individual latent site.
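The parameterization and sampling steps above can be sketched in plain Python (a minimal illustration of the math, not Pyro's tensor code; the numeric values are arbitrary):

```python
import random

d = 3
loc = [0.0, 1.0, -1.0]          # learnable mean vector
scale = [0.5, 0.2, 0.8]         # learnable positive per-dimension scales
scale_tril = [[1.0, 0.0, 0.0],  # unit lower triangular: ones on the diagonal
              [0.3, 1.0, 0.0],
              [0.1, 0.2, 1.0]]

# Effective Cholesky factor: L = diag(scale) @ scale_tril
L = [[scale[i] * scale_tril[i][j] for j in range(d)] for i in range(d)]

# Reparameterization trick: z = loc + L @ eps, with eps ~ N(0, I)
eps = [random.gauss(0.0, 1.0) for _ in range(d)]
z = [loc[i] + sum(L[i][j] * eps[j] for j in range(d)) for i in range(d)]
```

Because the randomness enters only through eps, gradients of an ELBO estimate flow through loc, scale, and scale_tril, which is what makes stochastic variational inference with this guide possible.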

Usage

Import this class when you need a variational approximation that captures posterior correlations between latent variables, particularly for models with moderate-dimensional latent spaces where the O(d^2) parameter cost is acceptable.
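For intuition about that O(d^2) cost, here is a rough parameter count (my own accounting, assuming only the strictly lower-triangular entries of scale_tril are free parameters) compared against a mean-field guide such as AutoNormal:

```python
def full_rank_params(d):
    # loc (d) + scale (d) + strictly lower-triangular entries (d*(d-1)/2)
    return 2 * d + d * (d - 1) // 2

def mean_field_params(d):
    # loc (d) + scale (d) only
    return 2 * d

for d in (10, 100, 1000):
    print(d, full_rank_params(d), mean_field_params(d))
```

At d = 1000 the full-rank guide needs roughly 500k parameters versus 2k for mean-field, which is why full covariance is usually reserved for moderate-dimensional latent spaces.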

Code Reference

Source Location

Repository
pyro-ppl/pyro
File
pyro/infer/autoguide/guides.py
Lines
L844--907

Signature

class AutoMultivariateNormal(AutoContinuous):
    scale_constraint = constraints.softplus_positive
    scale_tril_constraint = constraints.unit_lower_cholesky

    def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):

Import

from pyro.infer.autoguide import AutoMultivariateNormal

I/O Contract

Inputs

  • model (callable, required): A Pyro model function containing pyro.sample statements for continuous latent variables.
  • init_loc_fn (callable, optional): Per-site initialization function for the mean vector (default: init_to_median).
  • init_scale (float, optional): Initial scale for the diagonal of the Cholesky factor (default: 0.1; must be > 0).

Outputs

  • return value (dict): Maps each sample site name (str) to a sampled latent value (Tensor), transformed from the joint multivariate Normal back to that site's constrained space.

Internal Parameters Created

  • self.loc (nn.Parameter, shape (d,)): Mean vector of the multivariate Normal in unconstrained flattened space.
  • self.scale (PyroParam, shape (d,)): Per-dimension scale factors (positive, via softplus_positive).
  • self.scale_tril (PyroParam, shape (d, d)): Unit lower triangular correlation factor (via unit_lower_cholesky).

Key Methods

  • get_posterior() -> MultivariateNormal: Returns the full multivariate Normal posterior with scale_tril = diag(scale) * scale_tril.
  • get_transform() -> LowerCholeskyAffine: Returns the affine transform used for the reparameterization trick.
  • get_base_dist() -> Normal(0, 1).to_event(1): Returns the standard Normal base distribution for sampling.

Usage Examples

Basic Usage with SVI

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal

# Define a model with correlated latent variables
def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    scale = pyro.sample("scale", dist.LogNormal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, scale), obs=data)

# Construct the full-rank guide
guide = AutoMultivariateNormal(model)

# Set up SVI
optimizer = pyro.optim.Adam({"lr": 0.005})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# Training loop
data = torch.randn(100) + 3.0
for step in range(2000):
    loss = svi.step(data)

Extracting Posterior Covariance

# After training, access the full posterior distribution
posterior = guide.get_posterior()

# Mean vector
print("Mean:", posterior.loc)

# Full covariance matrix
print("Covariance:", posterior.covariance_matrix)

# Correlation between latent variables
scale_tril = guide.scale[..., None] * guide.scale_tril
cov = scale_tril @ scale_tril.T
std = cov.diag().sqrt()
corr = cov / (std[:, None] * std[None, :])
print("Correlation matrix:", corr)

Quantile Extraction

# Get quantiles of the posterior (inherited from AutoContinuous)
quantiles = guide.quantiles([0.05, 0.5, 0.95])
for name, values in quantiles.items():
    print(f"{name}: 5%={values[0].item():.3f}, "
          f"median={values[1].item():.3f}, "
          f"95%={values[2].item():.3f}")
