Principle:Pyro ppl Pyro Gradient Estimation
| Knowledge Sources | |
|---|---|
| Domains | Variational Inference, Gradient Estimation, Discrete Optimization |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Gradient estimators for non-reparameterizable and discrete distributions enable stochastic optimization of variational objectives when standard backpropagation through sampling is not possible.
Description
Variational inference maximizes the evidence lower bound (ELBO), which requires gradients of expectations with respect to the parameters of the distribution being sampled. For reparameterizable distributions (e.g., Gaussian), the reparameterization trick rewrites the sample as a differentiable function of the parameters and parameter-free noise (for a Gaussian, z = mu + sigma * eps with eps ~ Normal(0, 1)), so standard backpropagation applies. Many distributions, particularly discrete ones, admit no such differentiable sampling path.
Several strategies address this challenge:
Relaxed Straight-Through Estimator: For discrete random variables, this approach uses discrete samples in the forward pass and a continuous relaxation (e.g., Gumbel-Softmax) in the backward pass. "Straight-through" means the function value is computed from the discrete sample, while gradients flow through the continuous relaxation as if the relaxed sample had been used. The resulting gradient is biased, because the backward pass differentiates a different function from the one evaluated, but the downstream computation only ever sees valid discrete values, and the variance is typically much lower than that of the score-function estimator.
Analytical Variance Functions (AVF): For certain distributions, the gradient of the entropy or KL divergence term can be computed analytically even when the full ELBO gradient requires sampling. The AVF approach computes as much of the gradient analytically as possible and uses Monte Carlo estimation only for the remaining terms, thereby reducing variance. Pyro's `AVFMultivariateNormal` documentation expands the acronym differently, as adaptive velocity fields: transport-equation-inspired control variates for the multivariate normal.
Optimal Mass Transport (OMT) for Multivariate Normal: The OMT approach constructs a coupling between the variational distribution and the prior that minimizes the expected squared distance between paired samples; the minimum value is the squared Wasserstein-2 distance. For multivariate normals, this coupling is available in closed form and can be used to construct low-variance gradient estimators. The OMT gradient estimator draws one sample and pushes it through the optimal transport map between the two Gaussians, rather than drawing independent samples from each.
Usage
Use these gradient estimators when:
- Discrete latent variables make the reparameterization trick inapplicable.
- Standard score-function (REINFORCE) gradients have prohibitively high variance.
- Bias and variance in the gradient estimates must be balanced during optimization.
- A multivariate normal variational family is in use and lower-variance gradient estimates are wanted.
- A custom inference algorithm requires specialized gradient computation.
Theoretical Basis
Score function (REINFORCE) estimator (baseline approach):
# Goal: compute grad_phi E_{q(z|phi)}[f(z)]
# Score function estimator:
# grad_phi E[f(z)] = E[f(z) * grad_phi log q(z|phi)]
# High variance, unbiased
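The identity above can be checked numerically on a case where the exact gradient is known. The following is a minimal sketch, not Pyro code: it assumes a single Bernoulli latent with logit `phi`, and the objective `f` and all function names are hypothetical choices made for illustration.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def f(z):
    # Hypothetical downstream objective evaluated on a discrete sample.
    return (z - 0.45) ** 2


def score_function_grad(phi, num_samples, rng):
    """Monte Carlo estimate of d/dphi E_{z ~ Bernoulli(sigmoid(phi))}[f(z)]."""
    p = sigmoid(phi)
    total = 0.0
    for _ in range(num_samples):
        z = 1.0 if rng.random() < p else 0.0
        # For a Bernoulli with logit phi: d/dphi log q(z | phi) = z - p.
        total += f(z) * (z - p)
    return total / num_samples


def exact_grad(phi):
    # E[f(z)] = p * f(1) + (1 - p) * f(0), and dp/dphi = p * (1 - p).
    p = sigmoid(phi)
    return p * (1.0 - p) * (f(1.0) - f(0.0))


rng = random.Random(0)
estimate = score_function_grad(phi=0.0, num_samples=100_000, rng=rng)
print(estimate, exact_grad(0.0))
```

The estimate matches the exact gradient in expectation, but each per-sample term `f(z) * (z - p)` is several times larger in magnitude than the gradient itself, which is the variance problem the remaining estimators try to reduce.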
Gumbel-Softmax / Concrete relaxation:
# For a Categorical distribution with logits alpha:
# Gumbel trick: z = argmax_k (alpha_k + g_k) where g_k ~ Gumbel(0,1)
# Continuous relaxation (temperature tau):
# y_k = exp((alpha_k + g_k) / tau) / sum_j exp((alpha_j + g_j) / tau)
# As tau -> 0: y approaches one-hot (discrete)
# As tau -> inf: y approaches uniform
# Straight-through variant:
# Forward: z_hard = one_hot(argmax(y))
# Backward: gradients flow through y (the soft version)
# grad_phi z_hard is replaced by grad_phi y (biased, but lower variance than the score function estimator)
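The forward and backward halves of the straight-through variant can be written out without an autodiff library. This is a minimal sketch under stated assumptions: the helper names are hypothetical, the logits are arbitrary, and the Jacobian of the soft sample is computed by hand from the softmax formula rather than by backpropagation.

```python
import math
import random


def sample_gumbel(rng):
    # Clamp u away from 0 and 1 so both logarithms stay finite.
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))


def gumbel_softmax(logits, tau, rng):
    """Return (y_soft, z_hard, perturbed) for one relaxed categorical sample."""
    perturbed = [a + sample_gumbel(rng) for a in logits]
    m = max(perturbed)  # subtract the max for numerical stability
    exps = [math.exp((v - m) / tau) for v in perturbed]
    total = sum(exps)
    y_soft = [e / total for e in exps]
    k = max(range(len(y_soft)), key=y_soft.__getitem__)
    z_hard = [1.0 if i == k else 0.0 for i in range(len(y_soft))]
    return y_soft, z_hard, perturbed


def soft_jacobian(y_soft, tau):
    """d y_k / d alpha_j = y_k * (delta_kj - y_j) / tau."""
    n = len(y_soft)
    return [[y_soft[k] * ((1.0 if k == j else 0.0) - y_soft[j]) / tau
             for j in range(n)] for k in range(n)]


rng = random.Random(0)
logits = [1.0, 0.5, -0.5]
y_soft, z_hard, perturbed = gumbel_softmax(logits, tau=0.5, rng=rng)
# Straight-through: the forward pass uses z_hard, while the backward pass
# substitutes the Jacobian of y_soft for the Jacobian of z_hard, which is
# zero almost everywhere.
jac = soft_jacobian(y_soft, tau=0.5)
```

In an autodiff framework the same substitution is commonly written as `z_hard + y_soft - stop_gradient(y_soft)`; in PyTorch that is `(z_hard - y_soft).detach() + y_soft`.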
AVF for multivariate normal:
# For q(z) = Normal(mu, Sigma):
# ELBO = E_q[log p(x,z)] + H[q]
# Entropy term has analytical gradient:
# H[q] = 0.5 * log |Sigma| + 0.5 * d * (1 + log(2*pi)), with d = dim(z)
# grad_mu H = 0
# grad_Sigma H = 0.5 * Sigma^{-1} (treating the entries of Sigma as independent)
# Only the E_q[log p(x,z)] term needs MC estimation
# This decomposition reduces variance
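The analytic entropy gradient can be verified against finite differences. The sketch below assumes NumPy is available; the function names and the 2x2 covariance are hypothetical, and the entries of Sigma are perturbed independently, matching the unconstrained derivative stated above.

```python
import numpy as np


def gaussian_entropy(sigma):
    """Entropy of N(mu, sigma): 0.5 * log|sigma| + 0.5 * d * (1 + log(2*pi))."""
    d = sigma.shape[0]
    sign, logdet = np.linalg.slogdet(sigma)
    assert sign > 0, "covariance must have positive determinant"
    return 0.5 * logdet + 0.5 * d * (1.0 + np.log(2.0 * np.pi))


def entropy_grad_sigma(sigma):
    """Analytic gradient of the entropy with respect to sigma."""
    return 0.5 * np.linalg.inv(sigma)


def finite_difference_grad(fn, sigma, eps=1e-6):
    """Central differences, perturbing each entry of sigma independently."""
    grad = np.zeros_like(sigma)
    for i in range(sigma.shape[0]):
        for j in range(sigma.shape[1]):
            step = np.zeros_like(sigma)
            step[i, j] = eps
            grad[i, j] = (fn(sigma + step) - fn(sigma - step)) / (2.0 * eps)
    return grad


sigma = np.array([[2.0, 0.3], [0.3, 1.0]])
analytic = entropy_grad_sigma(sigma)
numeric = finite_difference_grad(gaussian_entropy, sigma)
print(np.max(np.abs(analytic - numeric)))
```

Because the entropy does not depend on mu and its Sigma gradient is exact, this term contributes no Monte Carlo noise; all sampling variance comes from the E_q[log p(x,z)] term.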
Optimal Mass Transport gradient:
# For q = N(mu_q, Sigma_q) and p = N(mu_p, Sigma_p):
# The OMT map T: q -> p is:
# T(z) = mu_p + A * (z - mu_q)
# where A = Sigma_q^{-1/2} * (Sigma_q^{1/2} Sigma_p Sigma_q^{1/2})^{1/2} * Sigma_q^{-1/2}
# Instead of sampling z_q ~ q and z_p ~ p independently:
# Sample z_q ~ q, then z_p = T(z_q)
# Among all couplings with marginals q and p, (z_q, T(z_q)) minimizes E[|z_q - z_p|^2]
# This reduces variance in gradient estimates that depend on both distributions
# Wasserstein-2 distance (closed form for Gaussians):
# W_2^2(q, p) = |mu_q - mu_p|^2 + trace(Sigma_q + Sigma_p - 2*(Sigma_q^{1/2} Sigma_p Sigma_q^{1/2})^{1/2})
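The map and the distance formula can be cross-checked numerically. This is a sketch assuming NumPy, with hypothetical helper names and arbitrary 2-D parameters; it is not Pyro's `OMTMultivariateNormal` implementation. It checks that T pushes q onto p and that the cost of the coupling equals the closed-form W_2^2.

```python
import numpy as np


def spd_sqrt(m):
    """Symmetric square root of a symmetric positive-definite matrix."""
    eigvals, eigvecs = np.linalg.eigh(m)
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def omt_map_matrix(sigma_q, sigma_p):
    """A = Sq^{-1/2} (Sq^{1/2} Sp Sq^{1/2})^{1/2} Sq^{-1/2}."""
    sq_half = spd_sqrt(sigma_q)
    sq_inv_half = np.linalg.inv(sq_half)
    middle = spd_sqrt(sq_half @ sigma_p @ sq_half)
    return sq_inv_half @ middle @ sq_inv_half


def w2_squared(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form squared Wasserstein-2 distance between two Gaussians."""
    sq_half = spd_sqrt(sigma_q)
    cross = spd_sqrt(sq_half @ sigma_p @ sq_half)
    return float(np.sum((mu_q - mu_p) ** 2)
                 + np.trace(sigma_q + sigma_p - 2.0 * cross))


mu_q, sigma_q = np.array([0.0, 1.0]), np.array([[1.0, 0.2], [0.2, 0.5]])
mu_p, sigma_p = np.array([1.0, -1.0]), np.array([[2.0, -0.3], [-0.3, 1.5]])

A = omt_map_matrix(sigma_q, sigma_p)

# Couple the samples: draw z_q ~ q, then push it through T to get z_p.
rng = np.random.default_rng(0)
z_q = rng.multivariate_normal(mu_q, sigma_q, size=50_000)
z_p = mu_p + (z_q - mu_q) @ A.T
empirical_cost = float(np.mean(np.sum((z_q - z_p) ** 2, axis=1)))

# Transport cost implied by the coupling, computed exactly from A.
identity = np.eye(2)
coupling_cost = float(np.sum((mu_q - mu_p) ** 2)
                      + np.trace((identity - A) @ sigma_q @ (identity - A).T))
print(coupling_cost, w2_squared(mu_q, sigma_q, mu_p, sigma_p), empirical_cost)
```

The matrix A is symmetric and satisfies A Sigma_q A^T = Sigma_p, so the pushed-forward samples have exactly the target mean and covariance while staying as close as possible to the originals in expected squared distance.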