Principle: Pyro Full-Rank Variational Inference
Metadata
| Field | Value |
|---|---|
| Page Type | Principle |
| Knowledge Sources | Paper (Stochastic Variational Inference, Hoffman et al. 2013), Repo (Pyro) |
| Domains | Bayesian_Inference, Variational_Inference |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
Full-rank variational inference approximates the posterior distribution with a multivariate Normal distribution that has a full (dense) covariance matrix, parameterized via a Cholesky factorization. Unlike the mean-field approximation, this captures correlations between latent variables, providing a more accurate posterior approximation at the cost of quadratically more parameters.
Description
Full-rank variational inference uses a multivariate Normal as the variational family:
q(z) = MultivariateNormal(mu, Sigma)
where mu is a d-dimensional mean vector and Sigma is a d x d positive definite covariance matrix. The covariance is parameterized through its Cholesky decomposition:
Sigma = L L^T
where L is a lower triangular matrix with positive diagonal entries. This parameterization ensures positive definiteness by construction and provides a numerically stable representation.
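The positive-definiteness guarantee can be illustrated with a small NumPy sketch (the dimension and random values are arbitrary): any lower-triangular matrix whose diagonal has been forced positive is a valid Cholesky factor, and the product L L^T is positive definite by construction.

```python
import numpy as np

# Hypothetical 3-dimensional latent space.
d = 3
rng = np.random.default_rng(0)

# Any lower-triangular matrix with a positive diagonal is a valid
# Cholesky factor; exponentiate the diagonal to enforce positivity.
L = np.tril(rng.normal(size=(d, d)))
L[np.diag_indices(d)] = np.exp(L[np.diag_indices(d)])

# Sigma = L L^T is symmetric positive definite by construction.
Sigma = L @ L.T
assert np.all(np.linalg.eigvalsh(Sigma) > 0)
```

Because the Cholesky factor with positive diagonal is unique, `np.linalg.cholesky(Sigma)` recovers `L` exactly, which is what makes this an unconstrained parameterization of the covariance.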
Comparison with Mean-Field
| Property | Mean-Field | Full-Rank |
|---|---|---|
| Covariance structure | Diagonal | Full (dense) |
| Parameters | O(d) (2d total) | O(d^2) (d + d(d+1)/2 total) |
| Posterior correlations | Not captured | Captured |
| Computational cost per step | O(d) | O(d^2) to O(d^3) |
| Memory | O(d) | O(d^2) |
| Posterior variance accuracy | Often underestimates | More accurate |
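The parameter counts in the table can be checked with a short helper (the function names are illustrative, not part of any library):

```python
def mean_field_params(d):
    # d means + d marginal standard deviations
    return 2 * d

def full_rank_params(d):
    # d means + a lower-triangular Cholesky factor with d(d+1)/2 entries
    return d + d * (d + 1) // 2

# The gap widens quadratically with dimension.
for d in (10, 100, 1000):
    print(d, mean_field_params(d), full_rank_params(d))
```

At d = 1000 the full-rank guide already carries over 500,000 parameters versus 2,000 for mean-field, which is the concrete reason the table marks full-rank as impractical in high dimensions.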
Cholesky Parameterization
The Cholesky factor L is decomposed into a diagonal scaling and a unit lower triangular matrix:
L = diag(scale) * L_unit
where scale is a vector of positive values (the marginal standard deviations) and L_unit is a lower triangular matrix with ones on the diagonal. This decomposition separates the scale of each variable from the correlation structure, improving optimization dynamics.
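A minimal NumPy sketch of this decomposition (values are arbitrary): the diagonal of the combined factor equals the scale vector, while the strictly lower triangle of L_unit carries the correlation structure.

```python
import numpy as np

d = 4
rng = np.random.default_rng(1)

scale = np.exp(rng.normal(size=d))                          # positive scale vector
L_unit = np.tril(rng.normal(size=(d, d)), -1) + np.eye(d)   # ones on the diagonal

# Combined Cholesky factor: each row scaled by the corresponding scale entry.
L = np.diag(scale) @ L_unit

# The diagonal of L is exactly the scale vector, so scale and
# correlation structure can be optimized with separate dynamics.
assert np.allclose(np.diag(L), scale)
```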
Flattened Latent Space
Full-rank VI requires all latent variables to be combined into a single vector. The implementation:
- Transforms each latent variable to unconstrained space using the appropriate bijective transform.
- Flattens all unconstrained latent values into a single vector of dimension d.
- Applies the multivariate Normal in this flattened space.
- Unflattens and transforms back to constrained space for each individual latent site.
This flattening is necessary because a full covariance matrix must operate on a single vector space.
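The four steps can be sketched in NumPy for a toy model with two latent sites, a real-valued location and a positive scale (the transforms and Cholesky factor here are illustrative assumptions, not Pyro's actual internals):

```python
import numpy as np

rng = np.random.default_rng(2)

# Two hypothetical latent sites: real-valued "loc" and positive "scale".
loc, scale = 0.5, 2.0

# Steps 1-2: map each site to unconstrained space (identity for loc,
# log for scale) and flatten into a single vector of dimension d = 2.
unconstrained = np.array([loc, np.log(scale)])

# Step 3: a single multivariate Normal acts on the flattened vector.
mu = unconstrained
L = np.array([[0.3, 0.0],
              [0.1, 0.2]])          # assumed 2x2 Cholesky factor
z = mu + L @ rng.normal(size=2)

# Step 4: unflatten and invert each site's transform.
loc_sample = z[0]
scale_sample = np.exp(z[1])         # inverse of log is exp
assert scale_sample > 0
```

Note that the off-diagonal entry of L induces correlation between the two sites in the guide, which is exactly what a diagonal (mean-field) factor cannot represent.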
Tradeoffs
- Captures correlations: The full covariance matrix models posterior dependencies between all pairs of latent variables, leading to tighter ELBO bounds and more accurate posterior approximations.
- Quadratic scaling: The number of parameters grows as d(d+1)/2 + d, making this approach impractical for models with very large numbers of latent variables (e.g., thousands or more).
- Better calibrated uncertainty: By modeling correlations, full-rank VI typically provides better calibrated credible regions than mean-field, which tends to underestimate posterior variance.
- Slower convergence: The larger parameter space can make optimization slower and more sensitive to learning rate and initialization.
Usage
Full-rank variational inference is applied when:
- Posterior correlations matter: Models where latent variables are strongly correlated under the posterior (e.g., hierarchical models with shared hyperparameters).
- The latent space is moderate-dimensional: Typically up to a few hundred dimensions, where the O(d^2) parameter cost is manageable.
- Accuracy is prioritized over speed: Applications requiring well-calibrated posterior uncertainty rather than just point estimates or marginal variances.
- Mean-field is insufficient: When diagnostic checks (e.g., comparing ELBO values or posterior predictive checks) indicate that the mean-field approximation is too crude.
In Pyro, full-rank variational inference is implemented by the AutoMultivariateNormal guide, which automatically constructs a multivariate Normal guide with Cholesky-parameterized covariance over the flattened latent space.
Theoretical Basis
Multivariate Normal Variational Family
The full-rank Gaussian variational family is:
Q_FR = { q : q(z) = MultivariateNormal(mu, Sigma) }
The ELBO for this family is:
ELBO = E_q[log p(x, z)] - E_q[log q(z)]
The entropy term has a closed form:
H[q] = (d/2) log(2 * pi * e) + (1/2) log det(Sigma) = (d/2) log(2 * pi * e) + sum_i log L_ii
where L_ii are the diagonal entries of the Cholesky factor. This makes the entropy gradient straightforward to compute.
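The identity follows from log det(Sigma) = log det(L L^T) = 2 sum_i log L_ii, and can be verified numerically (the covariance below is an arbitrary positive definite example):

```python
import numpy as np

rng = np.random.default_rng(3)
d = 3
A = rng.normal(size=(d, d))
Sigma = A @ A.T + d * np.eye(d)      # an arbitrary positive definite covariance
L = np.linalg.cholesky(Sigma)

# Entropy via the log-determinant of Sigma ...
H_det = 0.5 * d * np.log(2 * np.pi * np.e) + 0.5 * np.log(np.linalg.det(Sigma))

# ... equals the Cholesky form, since log det(Sigma) = 2 * sum_i log L_ii.
H_chol = 0.5 * d * np.log(2 * np.pi * np.e) + np.sum(np.log(np.diag(L)))

assert np.isclose(H_det, H_chol)
```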
Reparameterization
Samples from the multivariate Normal are generated via the reparameterization trick:
z = mu + L * epsilon, where epsilon ~ Normal(0, I)
This allows gradients of the ELBO with respect to mu and L to flow through the sampling operation, enabling efficient stochastic gradient optimization.
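A NumPy sketch of the reparameterized sampler (mean and Cholesky factor are arbitrary): with enough samples, the empirical moments of z = mu + L epsilon recover the target mean and covariance.

```python
import numpy as np

rng = np.random.default_rng(4)
mu = np.array([1.0, -1.0])
L = np.array([[1.0, 0.0],
              [0.8, 0.6]])
Sigma = L @ L.T

# Reparameterized samples: a deterministic, differentiable transform
# of standard-normal noise (row-wise form of z = mu + L @ eps).
eps = rng.normal(size=(100_000, 2))
z = mu + eps @ L.T

emp_mean = z.mean(axis=0)
emp_cov = np.cov(z, rowvar=False)
```

Because the randomness lives entirely in `eps`, gradients with respect to `mu` and `L` pass through the transform unchanged, which is what enables low-variance stochastic ELBO gradients.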
Relationship to Laplace Approximation
The full-rank Gaussian approximation is closely related to the Laplace approximation, which fits a Gaussian at the MAP estimate using the Hessian of the negative log posterior. However, full-rank VI optimizes both the mean and covariance jointly to maximize the ELBO, rather than fixing the mean at the mode. This typically yields a better approximation, especially when the posterior is skewed or when the mode and mean differ significantly.
Expressiveness Hierarchy
The hierarchy of Gaussian variational families in order of increasing expressiveness:
- Delta (MAP): d parameters, no uncertainty.
- Mean-field Normal: 2d parameters, diagonal covariance.
- Low-rank Normal: 2d + d*r parameters (mean, diagonal, and a rank-r factor), low-rank plus diagonal covariance.
- Full-rank Normal: d + d(d+1)/2 parameters, full covariance.
Each step up the hierarchy provides a more flexible approximation at increased computational cost.