Principle: Amortized Variational Inference (pyro-ppl/Pyro)
Metadata
| Field | Value |
|---|---|
| Principle ID | Pyro_ppl_Pyro_Amortized_Variational_Inference |
| Title | Amortized Variational Inference |
| Project | Pyro (pyro-ppl/pyro) |
| Domains | Deep_Learning, Variational_Inference, Generative_Models |
| Implementation | Pyro_ppl_Pyro_VAE_Encoder_Decoder_Pattern |
| Repository | https://github.com/pyro-ppl/pyro |
Summary
Amortized Variational Inference is the principle of using neural networks to amortize (share) the cost of inference across data points. Instead of optimizing separate variational parameters for each observation, a single neural network (the encoder or inference network) is trained to map any input data point directly to the parameters of its variational distribution. This is the foundation of Variational Autoencoders (VAEs) and enables efficient inference on unseen data without re-optimization.
Motivation
In standard variational inference, each data point x_i requires its own set of variational parameters phi_i to define q(z_i | phi_i). For N data points, this means N * dim(phi) parameters to optimize. This approach has two critical limitations:
- Scalability: The number of variational parameters grows linearly with the dataset size, making it prohibitive for large datasets.
- Generalization: After training, performing inference on a new (unseen) data point requires running a new optimization procedure from scratch.
Amortized inference addresses both limitations by parameterizing the variational distribution as:
- q(z | x) = q(z | f_phi(x))
where f_phi is a neural network (encoder) with shared parameters phi. The encoder learns a mapping from data space to variational parameter space. Once trained, inference on any new data point is a single forward pass through the encoder.
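The parameter-count difference can be made concrete with a small sketch in plain PyTorch (the layer sizes below are illustrative, chosen to match the 784-dimensional MNIST-style data and 50-dimensional latent used later in this page):

```python
import torch
import torch.nn as nn

N, x_dim, z_dim = 10_000, 784, 50

# Non-amortized: one (loc, log_scale) pair per data point -> grows linearly with N.
per_point_params = nn.Parameter(torch.zeros(N, 2 * z_dim))

# Amortized: a single encoder f_phi shared across all data points.
encoder = nn.Sequential(
    nn.Linear(x_dim, 400), nn.Softplus(), nn.Linear(400, 2 * z_dim)
)
n_amortized = sum(p.numel() for p in encoder.parameters())

print(per_point_params.numel())  # 1000000 -- scales with dataset size
print(n_amortized)               # 354100 -- fixed cost, independent of N
```

At N = 10,000 the per-point parameterization already exceeds the encoder's size; at a million data points it would be a thousand times larger, while the amortized cost stays constant.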
Core Concepts
The Encoder-Decoder Framework
Amortized variational inference typically involves two neural networks:
- Encoder (Inference Network): Maps observed data x to the parameters of the variational distribution q(z | x). For example, in a VAE with a Gaussian variational family, the encoder outputs (z_loc, z_scale).
- Decoder (Generative Network): Maps latent codes z to the parameters of the observation likelihood p(x | z). For example, in an image VAE, the decoder outputs the parameters of a Bernoulli or Gaussian distribution over pixel values.
Relationship to Model and Guide
In Pyro's terminology:
| Component | Pyro Concept | Neural Network | Distribution |
|---|---|---|---|
| Generative model | model() | Decoder | p(x \| z) * p(z) |
| Variational distribution | guide() | Encoder | q(z \| x) |
The model defines the generative process (prior + likelihood), and the guide defines the approximate posterior. Both are Python callables that use pyro.sample and pyro.module.
The ELBO Objective
Training optimizes the Evidence Lower Bound (ELBO):
- ELBO = E_{q(z|x)}[log p(x|z)] - KL(q(z|x) || p(z))
The first term encourages accurate reconstruction (data fidelity), and the second term regularizes the variational distribution to stay close to the prior.
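For a Gaussian variational family with a standard normal prior, the KL term has a closed form, and a single-sample Monte Carlo estimate of the ELBO can be sketched as follows (`decoder_log_prob` is a hypothetical callable standing in for the decoder's reconstruction log-likelihood):

```python
import torch

def gaussian_kl(loc, scale):
    # Analytic KL( N(loc, scale^2) || N(0, I) ), summed over latent dims.
    return 0.5 * (loc**2 + scale**2 - 2 * torch.log(scale) - 1).sum(-1)

def elbo_estimate(x, loc, scale, decoder_log_prob):
    # Single-sample Monte Carlo ELBO via the reparameterization trick.
    z = loc + scale * torch.randn_like(scale)
    return decoder_log_prob(x, z) - gaussian_kl(loc, scale)

# When q(z|x) equals the prior N(0, I), the KL term vanishes.
print(gaussian_kl(torch.zeros(2), torch.ones(2)))             # tensor(0.)
print(gaussian_kl(torch.tensor([1.0]), torch.tensor([1.0])))  # tensor(0.5000)
```

In practice Pyro's `Trace_ELBO` computes this objective automatically from the model and guide, so this arithmetic never needs to be written by hand.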
Registration with Pyro
Neural network modules used inside Pyro models and guides must be registered via pyro.module(). This ensures that:
- Their parameters are included in the Pyro parameter store
- Gradients flow correctly through the ELBO computation
- The parameters are managed by the SVI optimizer
```python
def guide(x):
    pyro.module("encoder", encoder_net)
    z_loc, z_scale = encoder_net(x)
    pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
```
How It Works
The amortized inference training loop:
- Forward pass through encoder: x -> (z_loc, z_scale)
- Sample latent code: z ~ Normal(z_loc, z_scale)
- Forward pass through decoder: z -> reconstruction parameters
- Compute ELBO: reconstruction_log_prob - KL_divergence
- Backpropagate and update: Both encoder and decoder parameters are updated jointly
After training:
- Encoding: Any new data point can be mapped to its approximate posterior in a single forward pass
- Decoding: Latent codes can be mapped to generated data
- Posterior predictive: Combine encoding and decoding for reconstruction and prediction
Example
```python
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

class Encoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)  # mean
        self.fc22 = nn.Linear(hidden_dim, z_dim)  # log-scale
        self.softplus = nn.Softplus()

    def forward(self, x):
        x = x.reshape(-1, 784)
        hidden = self.softplus(self.fc1(x))
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale

class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)
        self.softplus = nn.Softplus()

    def forward(self, z):
        hidden = self.softplus(self.fc1(z))
        loc_img = torch.sigmoid(self.fc21(hidden))
        return loc_img

# Instantiate networks
encoder = Encoder(z_dim=50, hidden_dim=400)
decoder = Decoder(z_dim=50, hidden_dim=400)

# Generative model p(x, z) = p(x|z)p(z)
def model(x):
    pyro.module("decoder", decoder)
    with pyro.plate("data", x.shape[0]):
        z_loc = torch.zeros(x.shape[0], 50, device=x.device)
        z_scale = torch.ones(x.shape[0], 50, device=x.device)
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        loc_img = decoder(z)
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1),
                    obs=x.reshape(-1, 784))

# Amortized guide q(z|x)
def guide(x):
    pyro.module("encoder", encoder)
    with pyro.plate("data", x.shape[0]):
        z_loc, z_scale = encoder(x)
        pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

# Train
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
```
Relationship to Other Principles
- Pyro_ppl_Pyro_Posterior_Predictive_Analysis -- After training an amortized model, the Predictive class can generate posterior predictive samples by passing the trained model and guide (encoder-based) to Predictive.
- Pyro_ppl_Pyro_Enumeration_Configuration -- In semi-supervised VAEs or discrete mixture VAEs, amortized inference can be combined with enumeration of discrete latent variables.
Related Pages
Implemented By
References
- Kingma, D.P. & Welling, M., "Auto-Encoding Variational Bayes", 2014. https://arxiv.org/abs/1312.6114
- Rezende, D.J., Mohamed, S. & Wierstra, D., "Stochastic Backpropagation and Approximate Inference in Deep Generative Models", 2014. https://arxiv.org/abs/1401.4082
- Pyro VAE tutorial: https://pyro.ai/examples/vae.html