Implementation:Pyro ppl Pyro AVFMultivariateNormal
| Knowledge Sources | |
|---|---|
| Domains | Probability_Distributions |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Description
AVFMultivariateNormal is a multivariate normal (Gaussian) distribution that incorporates transport-equation-inspired control variates known as adaptive velocity fields (AVF). It extends Pyro's standard MultivariateNormal distribution by introducing a learned control variate parameter that reduces the variance of gradient estimates during stochastic optimization.
The distribution keeps the same forward sampling semantics as a standard multivariate normal but overrides the reparameterized sampling (rsample) method with a custom autograd function (_AVFMVNSample). During the backward pass, this function applies infinitesimal rotation-based control variates, modulated by the control_var parameter, to yield lower-variance pathwise gradient estimates. The control variate parameter is a 2 x L x D tensor (where L is an arbitrary positive integer and D is the dimensionality of the distribution). It must be learned (adapted) to achieve the variance reduction, and in a typical use case it is adapted concurrently with the loc and scale_tril parameters.
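Why can a rotation-based term be added to the gradient without biasing it? The sketch below checks this numerically in plain PyTorch under stated assumptions. It takes the per-sample rotation term to be the antisymmetrized outer product xi_ab = (gL)_a eps_b - (gL)_b eps_a, where eps is the white noise, g is the gradient of the cost with respect to z, and gL = g^T scale_tril is the gradient with respect to eps. The quadratic cost, the fixed loc and scale_tril values, and all variable names are illustrative. By Stein's lemma, E[(gL)_a eps_b] equals the expected Hessian of the cost with respect to eps, which is symmetric, so E[xi] = 0. Any fixed elementwise weighting of xi (here, hypothetically, products built from the two L x D halves of control_var) therefore also has zero mean. This illustrates the principle and is not a transcription of avf_mvn.py.

```python
import torch

torch.manual_seed(0)
D, N = 3, 400_000
loc = torch.tensor([0.5, -1.0, 0.2], dtype=torch.float64)
scale_tril = torch.tensor([[1.0, 0.0, 0.0],
                           [0.5, 1.0, 0.0],
                           [-0.3, 0.2, 0.8]], dtype=torch.float64)

eps = torch.randn(N, D, dtype=torch.float64)  # white noise, one row per sample
z = loc + eps @ scale_tril.T                   # reparameterized samples
g = 2.0 * z                                    # d/dz of the illustrative cost sum(z**2)
gL = g @ scale_tril                            # row j: gradient of the cost w.r.t. eps_j

# Monte Carlo mean of the outer products (gL)_a * eps_b.
outer_mean = (gL.unsqueeze(-1) * eps.unsqueeze(-2)).mean(0)
# Antisymmetrized ("infinitesimal rotation") term xi_ab = (gL)_a eps_b - (gL)_b eps_a.
xi_mean = outer_mean - outer_mean.T

# Hypothetical control_var with L = 2 components; its halves B and C form a
# fixed D x D weight matrix, so the weighted correction also has zero mean.
control_var = torch.full((2, 2, D), 0.1, dtype=torch.float64)
B, C = control_var
weights = (B.unsqueeze(-1) * C.unsqueeze(-2)).sum(0)
correction_mean = xi_mean * weights

print(outer_mean)  # approximately 2 * scale_tril.T @ scale_tril (symmetric, nonzero)
print(xi_mean)     # approximately the zero matrix
```

The symmetric part carries the actual gradient signal. Only the antisymmetric part can be reweighted without introducing bias, which is what makes it usable as a control variate whose weights are then tuned to minimize variance.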
The implementation uses a custom torch.autograd.Function subclass (_AVFMVNSample) that defines explicit forward and backward passes. The forward pass performs the standard reparameterization trick (sampling white noise and transforming via the Cholesky factor), while the backward pass computes modified gradients that include the control variate correction terms.
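A minimal sketch of that forward/backward structure in plain PyTorch follows. It is hypothetical in several ways: the class name PlainMVNSample is illustrative, the white noise is passed in as an argument (rather than drawn inside forward) so the result is deterministic, and the backward pass returns only the ordinary reparameterization gradients, with no control-variate correction.

```python
import torch


class PlainMVNSample(torch.autograd.Function):
    """Hypothetical sketch: reparameterized MVN sample with a hand-written backward.

    No control-variate term is applied; this only mirrors the structure.
    """

    @staticmethod
    def forward(ctx, loc, scale_tril, white):
        # z = loc + scale_tril @ eps for each row eps of `white` (shape: N x D).
        ctx.save_for_backward(scale_tril, white)
        return loc + white @ scale_tril.T

    @staticmethod
    def backward(ctx, grad_output):
        scale_tril, white = ctx.saved_tensors
        g = grad_output.reshape(-1, grad_output.shape[-1])  # (N, D)
        eps = white.reshape(-1, white.shape[-1])            # (N, D)
        # dz/dloc is the identity, so sum the incoming gradient over samples.
        loc_grad = g.sum(0)
        # dz_a/dL_ab = eps_b, so the gradient is sum_j g_ja * eps_jb; keep only
        # the lower triangle because scale_tril is lower-triangular.
        scale_tril_grad = torch.tril(g.T @ eps)
        return loc_grad, scale_tril_grad, None  # no gradient for the noise


# Example usage: 5 samples in D = 4 dimensions with a quadratic cost.
torch.manual_seed(0)
D = 4
loc = torch.zeros(D, requires_grad=True)
scale_tril = torch.eye(D).requires_grad_()
white = torch.randn(5, D)
z = PlainMVNSample.apply(loc, scale_tril, white)
torch.pow(z, 2.0).sum().backward()
```

Comparing these gradients with the ones PyTorch's autograd computes for loc + white @ scale_tril.T gives identical loc gradients and matching lower triangles for scale_tril. A control-variate variant would add its correction to scale_tril_grad inside backward and return a gradient for the control parameter as well.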
Usage
This distribution is used in variational inference settings where reducing the variance of gradient estimates is critical for stable and efficient optimization. It is particularly useful when working with multivariate normal variational families in stochastic variational inference (SVI). The control_var parameter should be initialized and optimized alongside the distribution parameters using a separate or joint optimizer.
Code Reference
Source Location
pyro/distributions/avf_mvn.py
Signature
```python
class AVFMultivariateNormal(MultivariateNormal):
    def __init__(self, loc, scale_tril, control_var):
        ...
```
Import
from pyro.distributions import AVFMultivariateNormal
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| loc | torch.Tensor | D-dimensional mean vector. Must be 1-dimensional. |
| scale_tril | torch.Tensor | Cholesky factor of the covariance matrix. Must be a D x D lower-triangular matrix. |
| control_var | torch.Tensor | A 2 x L x D tensor that parameterizes the control variate, where L is an arbitrary positive integer. This parameter needs to be learned (adapted) to achieve lower-variance gradients. |
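To make the shape requirements concrete, here is a hypothetical validation helper (check_avf_mvn_args is not part of Pyro) that enforces the documented input contract in plain PyTorch. It returns D and L on success and raises ValueError otherwise; whether AVFMultivariateNormal itself performs every one of these checks, and how it reports failures, is not assumed here.

```python
import torch


def check_avf_mvn_args(loc, scale_tril, control_var):
    """Hypothetical helper: validate shapes against the documented contract.

    Returns (D, L) on success; raises ValueError on any violation.
    """
    if loc.dim() != 1:
        raise ValueError("loc must be 1-dimensional")
    D = loc.shape[0]
    if scale_tril.shape != (D, D):
        raise ValueError(f"scale_tril must be {D} x {D}")
    if not torch.equal(scale_tril, torch.tril(scale_tril)):
        raise ValueError("scale_tril must be lower-triangular")
    if control_var.dim() != 3 or control_var.shape[0] != 2 or control_var.shape[2] != D:
        raise ValueError(f"control_var must have shape 2 x L x {D}")
    if control_var.shape[1] < 1:
        raise ValueError("L must be a positive integer")
    return D, control_var.shape[1]


# Example: a valid argument set with D = 5 and L = 3.
print(check_avf_mvn_args(torch.zeros(5), torch.eye(5), torch.ones(2, 3, 5)))  # (5, 3)
```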
Outputs
| Method | Return Type | Description |
|---|---|---|
| rsample(sample_shape) | torch.Tensor | Draws reparameterized samples from the distribution with adaptive velocity field control variates applied during the backward pass. |
| log_prob(value) | torch.Tensor | Evaluates the log probability density at the given value (inherited from MultivariateNormal). |
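Because log_prob is inherited, it evaluates the ordinary multivariate normal log density log p(x) = -1/2 ||L^{-1}(x - loc)||^2 - sum_i log L_ii - (D/2) log(2 pi), with L = scale_tril. A minimal sketch of that formula computed directly from the Cholesky factor follows; the helper name mvn_log_prob is hypothetical, and torch.distributions.MultivariateNormal(loc, scale_tril=...).log_prob can be used to cross-check it.

```python
import math

import torch


def mvn_log_prob(value, loc, scale_tril):
    """Hypothetical helper: log density of N(loc, scale_tril @ scale_tril.T).

    `value` has shape (..., D); the result has shape (...).
    """
    D = loc.shape[-1]
    batch_shape = value.shape[:-1]
    # Stack the residuals as columns: shape (D, N).
    diff = (value - loc).reshape(-1, D).T
    # Solve scale_tril @ y = diff; the Mahalanobis term is then ||y||^2.
    y = torch.linalg.solve_triangular(scale_tril, diff, upper=False)
    maha = (y ** 2).sum(0)
    # log det(Sigma) = 2 * sum(log diag(scale_tril)); the density uses half of it.
    half_log_det = scale_tril.diagonal().log().sum()
    log_p = -0.5 * maha - half_log_det - 0.5 * D * math.log(2.0 * math.pi)
    return log_p.reshape(batch_shape)
```

Solving the triangular system avoids forming the covariance matrix or its inverse explicitly, which is the usual reason to parameterize by scale_tril.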
Usage Examples
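```python
import torch
from pyro.distributions import AVFMultivariateNormal

D = 5  # dimensionality of the distribution
L = 3  # number of control-variate components (any positive integer)

# Fixed distribution parameters; only the control variate is adapted here.
loc = torch.zeros(D)
scale_tril = torch.eye(D)

# Learnable control variate of shape 2 x L x D.
control_var = torch.full((2, L, D), 0.1, requires_grad=True)
opt_cv = torch.optim.Adam([control_var], lr=0.1, betas=(0.5, 0.999))

for _ in range(1000):
    d = AVFMultivariateNormal(loc, scale_tril, control_var)
    z = d.rsample()
    cost = torch.pow(z, 2.0).sum()
    cost.backward()  # the custom backward also produces a gradient for control_var
    opt_cv.step()
    opt_cv.zero_grad()
```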
Related Pages
- Pyro_ppl_Pyro_Distribution_Base -- Base distribution class for all Pyro distributions
- Pyro_ppl_Pyro_MixtureOfDiagNormals -- Another distribution with custom pathwise derivatives
- Pyro_ppl_Pyro_Constraints -- Constraint definitions used by distribution parameters