
Principle:Pyro ppl Pyro Amortized Variational Inference

From Leeroopedia


Metadata

Field Value
Principle ID Pyro_ppl_Pyro_Amortized_Variational_Inference
Title Amortized Variational Inference
Project Pyro (pyro-ppl/pyro)
Domains Deep_Learning, Variational_Inference, Generative_Models
Implementation Pyro_ppl_Pyro_VAE_Encoder_Decoder_Pattern
Repository https://github.com/pyro-ppl/pyro

Summary

Amortized Variational Inference is the principle of using neural networks to amortize (share) the cost of inference across data points. Instead of optimizing separate variational parameters for each observation, a single neural network (the encoder or inference network) is trained to map any input data point directly to the parameters of its variational distribution. This is the foundation of Variational Autoencoders (VAEs) and enables efficient inference on unseen data without re-optimization.

Motivation

In standard variational inference, each data point x_i requires its own set of variational parameters phi_i to define q(z_i | phi_i). For N data points, this means N * dim(phi) parameters to optimize. This approach has two critical limitations:

  1. Scalability: The number of variational parameters grows linearly with the dataset size, making it prohibitive for large datasets.
  2. Generalization: After training, performing inference on a new (unseen) data point requires running a new optimization procedure from scratch.

Amortized inference addresses both limitations by parameterizing the variational distribution as:

q(z | x) = q(z | f_phi(x))

where f_phi is a neural network (encoder) with shared parameters phi. The encoder learns a mapping from data space to variational parameter space. Once trained, inference on any new data point is a single forward pass through the encoder.
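The contrast can be sketched in a few lines of PyTorch. This is an illustrative sketch, not Pyro API: the sizes, the `per_point_params` tensor, and the two-layer encoder are hypothetical stand-ins chosen to mirror the MNIST-style example later on this page.

```python
import torch
import torch.nn as nn

N, x_dim, z_dim = 1000, 784, 50

# Classic VI: one (loc, log-scale) pair per data point -> parameters grow with N
per_point_params = nn.Parameter(torch.zeros(N, 2 * z_dim))

# Amortized VI: a single encoder f_phi shared by all data points
encoder = nn.Sequential(
    nn.Linear(x_dim, 400), nn.Softplus(), nn.Linear(400, 2 * z_dim)
)

# Inference on an unseen point is one forward pass, no re-optimization
x_new = torch.randn(1, x_dim)
z_loc, z_log_scale = encoder(x_new).chunk(2, dim=-1)

# Parameter counts: the per-point scheme scales with N, the encoder does not
n_per_point = per_point_params.numel()                    # N * 2 * z_dim
n_encoder = sum(p.numel() for p in encoder.parameters())  # fixed in N
```

Doubling the dataset size doubles `n_per_point` but leaves `n_encoder` unchanged, which is the sense in which the inference cost is amortized.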

Core Concepts

The Encoder-Decoder Framework

Amortized variational inference typically involves two neural networks:

  • Encoder (Inference Network): Maps observed data x to the parameters of the variational distribution q(z | x). For example, in a VAE with a Gaussian variational family, the encoder outputs (z_loc, z_scale).
  • Decoder (Generative Network): Maps latent codes z to the parameters of the observation likelihood p(x | z). For example, in an image VAE, the decoder outputs the parameters of a Bernoulli or Gaussian distribution over pixel values.

Relationship to Model and Guide

In Pyro's terminology:

Component                 Pyro Concept  Neural Network  Distribution
Generative model          model()       Decoder         p(x | z) * p(z)
Variational distribution  guide()       Encoder         q(z | x)

The model defines the generative process (prior + likelihood), and the guide defines the approximate posterior. Both are Python callables that use pyro.sample and pyro.module.

The ELBO Objective

Training optimizes the Evidence Lower Bound (ELBO):

ELBO = E_{q(z|x)}[log p(x|z)] - KL(q(z|x) || p(z))

The first term encourages accurate reconstruction (data fidelity), and the second term regularizes the variational distribution to stay close to the prior.
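A single-sample Monte Carlo estimate of this objective can be computed directly with `torch.distributions`. This sketch is independent of Pyro (which computes the ELBO internally via `Trace_ELBO`); the variational parameters and the stand-in likelihood term are hypothetical values that would normally come from the encoder and decoder.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
z_dim = 5

# Hypothetical variational parameters (in a VAE these come from the encoder)
q = Normal(torch.zeros(z_dim), torch.full((z_dim,), 0.5))  # q(z | x)
p = Normal(torch.zeros(z_dim), torch.ones(z_dim))          # prior p(z)

z = q.rsample()  # reparameterized sample, so gradients flow through z

# Stand-in for log p(x | z); in a real VAE this is the log-prob of x
# under the decoder's output distribution
log_px_given_z = Normal(z, 1.0).log_prob(torch.zeros(z_dim)).sum()

# ELBO = E_q[log p(x|z)] - KL(q || p), with the Gaussian KL in closed form
elbo = log_px_given_z - kl_divergence(q, p).sum()
```

Maximizing this quantity trades off the reconstruction term against the KL penalty exactly as described above.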

Registration with Pyro

Neural network modules used inside Pyro models and guides must be registered via pyro.module(). This ensures that:

  • Their parameters are included in the Pyro parameter store
  • Gradients flow correctly through the ELBO computation
  • The parameters are managed by the SVI optimizer

def guide(x):
    pyro.module("encoder", encoder_net)
    z_loc, z_scale = encoder_net(x)
    pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

How It Works

The amortized inference training loop:

  1. Forward pass through encoder: x -> (z_loc, z_scale)
  2. Sample latent code: z ~ Normal(z_loc, z_scale)
  3. Forward pass through decoder: z -> reconstruction parameters
  4. Compute ELBO: reconstruction_log_prob - KL_divergence
  5. Backpropagate and update: Both encoder and decoder parameters are updated jointly

After training:

  • Encoding: Any new data point can be mapped to its approximate posterior in a single forward pass
  • Decoding: Latent codes can be mapped to generated data
  • Posterior predictive: Combine encoding and decoding for reconstruction and prediction

Example

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

class Encoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)  # mean
        self.fc22 = nn.Linear(hidden_dim, z_dim)  # log std dev (exponentiated to get the scale)
        self.softplus = nn.Softplus()

    def forward(self, x):
        x = x.reshape(-1, 784)
        hidden = self.softplus(self.fc1(x))
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale

class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)
        self.softplus = nn.Softplus()

    def forward(self, z):
        hidden = self.softplus(self.fc1(z))
        loc_img = torch.sigmoid(self.fc21(hidden))
        return loc_img

# Instantiate networks
encoder = Encoder(z_dim=50, hidden_dim=400)
decoder = Decoder(z_dim=50, hidden_dim=400)

# Generative model p(x, z) = p(x|z)p(z)
def model(x):
    pyro.module("decoder", decoder)
    with pyro.plate("data", x.shape[0]):
        z_loc = torch.zeros(x.shape[0], 50, device=x.device)
        z_scale = torch.ones(x.shape[0], 50, device=x.device)
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        loc_img = decoder(z)
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1),
                     obs=x.reshape(-1, 784))

# Amortized guide q(z|x)
def guide(x):
    pyro.module("encoder", encoder)
    with pyro.plate("data", x.shape[0]):
        z_loc, z_scale = encoder(x)
        pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

# Train: each svi.step(x) call runs model and guide, estimates the ELBO,
# and takes one gradient step on encoder and decoder parameters jointly
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
x = torch.bernoulli(torch.rand(256, 784))  # toy binary minibatch; real data
loss = svi.step(x)                         # would come from a DataLoader

# After training: amortized inference is a single forward pass
z_loc, z_scale = encoder(x)   # approximate posterior parameters for x
x_recon = decoder(z_loc)      # reconstruction from the posterior mean

