Implementation:Pyro ppl Pyro AVFMultivariateNormal

From Leeroopedia


Knowledge Sources
Domains Probability_Distributions
Last Updated 2026-02-09 09:00 GMT

Overview

Description

AVFMultivariateNormal is a multivariate normal (Gaussian) distribution that incorporates transport-equation-inspired control variates known as adaptive velocity fields (AVF). It extends Pyro's standard MultivariateNormal distribution by introducing a learned control variate parameter that reduces the variance of gradient estimates during stochastic optimization.

The distribution maintains the same forward sampling semantics as a standard multivariate normal but overrides the reparameterized sampling (rsample) method with a custom autograd function (_AVFMVNSample). During the backward pass, this custom function applies infinitesimal rotation-based control variates modulated by the control_var parameter to yield lower-variance pathwise gradient estimates. The control variate parameter is a 2 x L x D tensor (where L is an arbitrary positive integer and D is the dimensionality of the distribution) that must be learned concurrently with the loc and scale_tril parameters.

_AVFMVNSample is a torch.autograd.Function subclass with explicit forward and backward passes. The forward pass performs the standard reparameterization trick: it samples white noise and transforms it with the Cholesky factor scale_tril. The backward pass computes modified gradients that include the control variate correction terms.
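
The forward/backward split described above can be illustrated with a minimal sketch. The class below, PlainMVNSample, is a hypothetical name and is not Pyro's _AVFMVNSample: it shows how a custom torch.autograd.Function can sample through the Cholesky factor and return hand-written pathwise gradients. It deliberately omits the control variate correction, so it illustrates only the structure of such a function, not the variance reduction.

```python
import torch
from torch.autograd import Function


class PlainMVNSample(Function):
    """Hypothetical stand-in: reparameterized MVN sampling with a manual backward.

    No control variate is applied here.
    """

    @staticmethod
    def forward(ctx, loc, scale_tril, shape):
        # Draw white noise eps ~ N(0, I) and map it through the Cholesky factor.
        eps = torch.randn(shape, dtype=loc.dtype, device=loc.device)
        ctx.save_for_backward(eps)
        return loc + eps @ scale_tril.t()

    @staticmethod
    def backward(ctx, grad_output):
        (eps,) = ctx.saved_tensors
        # Flatten leading sample dimensions so both gradients are plain sums.
        g = grad_output.reshape(-1, grad_output.shape[-1])
        e = eps.reshape(-1, eps.shape[-1])
        loc_grad = g.sum(0)
        # dz_a / dL_ab = eps_b, so the gradient is sum_n g_na * eps_nb,
        # restricted to the lower triangle of scale_tril.
        scale_tril_grad = torch.tril(g.t() @ e)
        # The non-tensor `shape` argument receives no gradient.
        return loc_grad, scale_tril_grad, None


D = 3
loc = torch.zeros(D, requires_grad=True)
scale_tril = torch.eye(D, requires_grad=True)
z = PlainMVNSample.apply(loc, scale_tril, torch.Size([4, D]))
z.pow(2).sum().backward()  # fills loc.grad and scale_tril.grad
```

A control-variate version would add a correction term to scale_tril_grad inside backward and return an extra gradient for the control variate parameter.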

Usage

This distribution is used in variational inference settings where reducing the variance of gradient estimates is critical for stable and efficient optimization. It is particularly useful when working with multivariate normal variational families in stochastic variational inference (SVI). The control_var parameter should be initialized and optimized alongside the distribution parameters using a separate or joint optimizer.

Code Reference

Source Location

pyro/distributions/avf_mvn.py

Signature

class AVFMultivariateNormal(MultivariateNormal):
    def __init__(self, loc, scale_tril, control_var):
        ...

Import

from pyro.distributions import AVFMultivariateNormal

I/O Contract

Inputs

Parameter | Type | Description
loc | torch.Tensor | D-dimensional mean vector. Must be 1-dimensional.
scale_tril | torch.Tensor | Cholesky factor of the covariance matrix. Must be a D x D lower-triangular matrix.
control_var | torch.Tensor | A 2 x L x D tensor that parameterizes the control variate. L is an arbitrary positive integer. This parameter needs to be learned (adapted) to achieve lower variance gradients.
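
The shape requirements listed above can be checked before constructing the distribution. The helper below is a hypothetical sketch (check_avf_shapes is not a Pyro function) that encodes only the constraints stated in this table, using plain PyTorch calls.

```python
import torch


def check_avf_shapes(loc, scale_tril, control_var):
    """Hypothetical helper (not part of Pyro): validate the documented shapes."""
    if loc.dim() != 1:
        raise ValueError("loc must be 1-dimensional")
    D = loc.shape[0]
    if scale_tril.shape != (D, D):
        raise ValueError("scale_tril must be D x D")
    if not torch.equal(scale_tril, torch.tril(scale_tril)):
        raise ValueError("scale_tril must be lower-triangular")
    if control_var.dim() != 3 or control_var.shape[0] != 2 or control_var.shape[2] != D:
        raise ValueError("control_var must have shape 2 x L x D")
    # Return the inferred sizes D and L.
    return D, control_var.shape[1]


D, L = check_avf_shapes(torch.zeros(5), torch.eye(5), torch.ones(2, 3, 5))
```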

Outputs

Method | Return Type | Description
rsample(sample_shape) | torch.Tensor | Draws reparameterized samples from the distribution with adaptive velocity field control variates applied during the backward pass.
log_prob(value) | torch.Tensor | Evaluates the log probability density at the given value (inherited from MultivariateNormal).

Usage Examples

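import torch
from pyro.distributions import AVFMultivariateNormal

D = 5  # dimensionality of the distribution
L = 3  # number of control variate components (any positive integer)

# loc and scale_tril are held fixed here; only control_var is adapted.
loc = torch.zeros(D)
scale_tril = torch.eye(D)
# Leaf tensor of shape (2, L, D) that tracks gradients.
control_var = torch.full((2, L, D), 0.1, requires_grad=True)

opt_cv = torch.optim.Adam([control_var], lr=0.1, betas=(0.5, 0.999))

for _ in range(1000):
    d = AVFMultivariateNormal(loc, scale_tril, control_var)
    z = d.rsample()
    cost = torch.pow(z, 2.0).sum()
    # The custom backward pass populates control_var.grad.
    cost.backward()
    opt_cv.step()
    opt_cv.zero_grad()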
