Implementation: pyro-ppl/pyro AutoMultivariateNormal Guide
Metadata
| Field | Value |
|---|---|
| Sources | Repo: Pyro |
| Domains | Bayesian_Inference, Variational_Inference |
| Updated | 2026-02-09 |
| Type | Class |
Overview
AutoMultivariateNormal is an automatic guide that constructs a multivariate Normal variational distribution with a full covariance matrix over the entire flattened latent space. It uses a Cholesky parameterization to maintain positive definiteness and captures posterior correlations between all pairs of latent variables.
Description
The AutoMultivariateNormal class inherits from AutoContinuous and provides a full-rank Gaussian variational family. It flattens all continuous latent variables in the model into a single vector and applies a multivariate Normal distribution with Cholesky-factored covariance.
When first called, the guide traces the model to discover latent sites, transforms them to unconstrained space, and flattens them into a single latent vector of dimension d. It then creates three sets of learnable parameters:
- `loc`: A d-dimensional mean vector, initialized via `init_loc_fn` (default: `init_to_median`).
- `scale`: A d-dimensional positive vector of per-variable scale factors, initialized to `init_scale` (default: 0.1), constrained via `softplus_positive`.
- `scale_tril`: A d x d unit lower triangular matrix (ones on the diagonal), constrained via `unit_lower_cholesky`.
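The flattening step and the three parameter groups can be sketched in plain NumPy. This is illustrative only, not Pyro's implementation: the site names and shapes below are hypothetical, and Pyro stores these as torch parameters with constraints applied automatically.

```python
import numpy as np

# Hypothetical unconstrained shapes of the model's latent sites.
# d is the total number of unconstrained scalars across all sites.
site_shapes = {"loc": (), "scale": (), "weights": (3,)}
d = sum(int(np.prod(shape)) for shape in site_shapes.values())
print(d)  # 5

# The three parameter groups the guide then creates (initial values shown;
# in Pyro, loc is filled by init_loc_fn rather than zeros):
loc = np.zeros(d)           # d-dimensional mean vector
scale = np.full(d, 0.1)     # per-dimension scales, init_scale = 0.1
scale_tril = np.eye(d)      # unit lower triangular, identity at init
```

At initialization the guide is therefore a diagonal Gaussian; correlations are learned as the strictly lower triangular entries of `scale_tril` move away from zero.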
The effective Cholesky factor of the covariance matrix is computed as:
L = diag(scale) * scale_tril
This separates the marginal scales from the correlation structure, yielding the covariance:
Sigma = L L^T
During the forward pass, sampling uses the reparameterization trick via the LowerCholeskyAffine transform. The flattened samples are then unflattened and transformed back to the constrained space of each individual latent site.
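The forward pass above can be sketched in NumPy. This is a minimal illustration of the math, not Pyro's code: Pyro performs the same computation on torch tensors via the LowerCholeskyAffine transform, and the parameter values here are arbitrary placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4

# Placeholder "learned" parameters
loc = rng.normal(size=d)
scale = np.abs(rng.normal(size=d)) + 0.1                        # positive scales
scale_tril = np.tril(rng.normal(size=(d, d)), -1) + np.eye(d)   # unit lower tri

# Effective Cholesky factor: L = diag(scale) * scale_tril
L = scale[:, None] * scale_tril

# Reparameterization trick: z = loc + L @ eps, with eps ~ N(0, I),
# so gradients flow through loc, scale, and scale_tril.
eps = rng.standard_normal(d)
z = loc + L @ eps

# Implied covariance Sigma = L L^T is positive definite by construction
Sigma = L @ L.T
assert np.all(np.linalg.eigvalsh(Sigma) > 0)
```

Note how the diagonal of `L` equals `scale` exactly, because `scale_tril` has ones on its diagonal: the marginal scales and the correlation structure are parameterized separately.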
Usage
Import this class when you need a variational approximation that captures posterior correlations between latent variables, particularly for models with moderate-dimensional latent spaces where the O(d^2) parameter cost is acceptable.
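The O(d^2) cost is concrete: the guide stores d values for `loc`, d for `scale`, and d(d-1)/2 free entries in the strictly lower triangle of `scale_tril`, versus 2d for a diagonal (AutoDiagonalNormal-style) guide. A quick count:

```python
def full_rank_params(d):
    # loc (d) + scale (d) + strictly lower triangular entries of scale_tril
    return d + d + d * (d - 1) // 2

def mean_field_params(d):
    # loc (d) + scale (d) only
    return 2 * d

for d in (10, 100, 1000):
    print(d, full_rank_params(d), mean_field_params(d))
```

At d = 1000 the full-rank guide already carries about half a million parameters, which is why it is best suited to moderate-dimensional latent spaces.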
Code Reference
Source Location
- Repository: pyro-ppl/pyro
- File: pyro/infer/autoguide/guides.py
- Lines: L844–L907
Signature
    class AutoMultivariateNormal(AutoContinuous):
        scale_constraint = constraints.softplus_positive
        scale_tril_constraint = constraints.unit_lower_cholesky

        def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):
Import
from pyro.infer.autoguide import AutoMultivariateNormal
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `model` | callable | Yes | A Pyro model function containing `pyro.sample` statements for continuous latent variables |
| `init_loc_fn` | callable | No | Per-site initialization function for the mean vector (default: `init_to_median`) |
| `init_scale` | float | No | Initial scale for the diagonal of the Cholesky factor (default: 0.1, must be > 0) |
Outputs
| Name | Type | Description |
|---|---|---|
| return value | dict | Dictionary mapping sample site names (str) to sampled latent values (Tensor), transformed from the joint multivariate Normal back to constrained space per site |
Internal Parameters Created
| Attribute | Type | Shape | Description |
|---|---|---|---|
| `self.loc` | nn.Parameter | (d,) | Mean vector of the multivariate Normal in unconstrained flattened space |
| `self.scale` | PyroParam | (d,) | Per-dimension scale factors (positive, via `softplus_positive`) |
| `self.scale_tril` | PyroParam | (d, d) | Unit lower triangular correlation matrix (via `unit_lower_cholesky`) |
Key Methods
| Method | Returns | Description |
|---|---|---|
| `get_posterior()` | MultivariateNormal | Returns the full multivariate Normal posterior with scale_tril = diag(scale) * scale_tril |
| `get_transform()` | LowerCholeskyAffine | Returns the affine transform for the reparameterization trick |
| `get_base_dist()` | Normal(0, 1).to_event(1) | Returns a standard Normal base distribution for sampling |
Usage Examples
Basic Usage with SVI
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal
# Define a model with correlated latent variables
def model(data):
loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
scale = pyro.sample("scale", dist.LogNormal(0.0, 1.0))
with pyro.plate("data", len(data)):
pyro.sample("obs", dist.Normal(loc, scale), obs=data)
# Construct the full-rank guide
guide = AutoMultivariateNormal(model)
# Set up SVI
optimizer = pyro.optim.Adam({"lr": 0.005})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
# Training loop
data = torch.randn(100) + 3.0
for step in range(2000):
loss = svi.step(data)
Extracting Posterior Covariance
# After training, access the full posterior distribution
posterior = guide.get_posterior()
# Mean vector
print("Mean:", posterior.loc)
# Full covariance matrix
print("Covariance:", posterior.covariance_matrix)
# Correlation between latent variables
scale_tril = guide.scale[..., None] * guide.scale_tril
cov = scale_tril @ scale_tril.T
std = cov.diag().sqrt()
corr = cov / (std[:, None] * std[None, :])
print("Correlation matrix:", corr)
Quantile Extraction
# Get quantiles of the posterior (inherited from AutoContinuous)
quantiles = guide.quantiles([0.05, 0.5, 0.95])
for name, values in quantiles.items():
print(f"{name}: 5%={values[0].item():.3f}, "
f"median={values[1].item():.3f}, "
f"95%={values[2].item():.3f}")