Implementation:Pyro ppl Pyro VAE Encoder Decoder Pattern
Metadata
| Field | Value |
|---|---|
| Implementation ID | Pyro_ppl_Pyro_VAE_Encoder_Decoder_Pattern |
| Title | VAE Encoder-Decoder Pattern |
| Type | Pattern Doc (user-defined architecture, not a library API) |
| Project | Pyro (pyro-ppl/pyro) |
| Reference File | examples/vae/vae.py, Lines 1-256 |
| Implements | Pyro_ppl_Pyro_Amortized_Variational_Inference |
| Repository | https://github.com/pyro-ppl/pyro |
Summary
The VAE Encoder-Decoder Pattern describes the canonical architecture for implementing Variational Autoencoders in Pyro. This is a user-defined pattern (not a fixed library API) consisting of an Encoder(nn.Module) and Decoder(nn.Module) that are registered with Pyro via pyro.module() and used within model and guide functions. The reference implementation in examples/vae/vae.py demonstrates this pattern for MNIST digit generation.
Pattern Overview
The VAE pattern consists of four components:
- Encoder (`nn.Module`): Maps data x to variational parameters (z_loc, z_scale)
- Decoder (`nn.Module`): Maps latent code z to reconstruction parameters
- Model function: Defines the generative process p(x, z) = p(x|z)p(z) using the Decoder
- Guide function: Defines the variational distribution q(z|x) using the Encoder
Encoder Interface
```python
class Encoder(nn.Module):
    """
    Maps input data to variational parameters for the approximate posterior.
    Input: x of shape (batch_size, data_dim)
    Output: (z_loc, z_scale) each of shape (batch_size, z_dim)
    """
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # Define layers
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, z_dim)  # log-std head
        self.softplus = nn.Softplus()

    def forward(self, x):
        x = x.reshape(-1, 784)
        hidden = self.softplus(self.fc1(x))
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale
```
Key Design Decisions
- Two output heads: One for the mean (`fc21`) and one for the scale (`fc22`) of the variational Gaussian
- Positive scale: The scale (standard deviation) is kept strictly positive by applying `torch.exp()` to the raw network output
- Softplus activation: Used instead of ReLU to maintain smooth gradients, which helps the reparameterization trick
- Reshape input: MNIST images (28x28) are flattened to 784-dimensional vectors
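As a quick sanity check, the sketch below duplicates the encoder so it runs standalone and verifies the two design decisions above: both heads emit `(batch_size, z_dim)` tensors, and `torch.exp` makes the scale strictly positive.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)   # mean head
        self.fc22 = nn.Linear(hidden_dim, z_dim)   # log-std head
        self.softplus = nn.Softplus()

    def forward(self, x):
        x = x.reshape(-1, 784)                     # flatten 28x28 images
        hidden = self.softplus(self.fc1(x))
        return self.fc21(hidden), torch.exp(self.fc22(hidden))

enc = Encoder(z_dim=50, hidden_dim=400)
x = torch.rand(16, 1, 28, 28)          # a fake MNIST mini-batch
z_loc, z_scale = enc(x)
assert z_loc.shape == (16, 50) and z_scale.shape == (16, 50)
assert (z_scale > 0).all()             # exp() guarantees a valid scale
```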
Decoder Interface
```python
class Decoder(nn.Module):
    """
    Maps latent codes to reconstruction parameters for the observation model.
    Input: z of shape (batch_size, z_dim)
    Output: loc_img of shape (batch_size, data_dim)
    """
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)
        self.softplus = nn.Softplus()

    def forward(self, z):
        hidden = self.softplus(self.fc1(z))
        loc_img = torch.sigmoid(self.fc21(hidden))
        return loc_img
```
Key Design Decisions
- Sigmoid output: Produces values in [0, 1] suitable for Bernoulli likelihood (binary pixel values)
- Architecture mirrors encoder: Symmetric structure with one hidden layer
- Output represents distribution parameters: The output is the Bernoulli probability for each pixel, not the pixel value itself
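A standalone sketch (duplicating the decoder so it runs on its own) confirms the output is a valid set of Bernoulli parameters: one probability in [0, 1] per pixel.

```python
import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)
        self.softplus = nn.Softplus()

    def forward(self, z):
        hidden = self.softplus(self.fc1(z))
        return torch.sigmoid(self.fc21(hidden))   # Bernoulli prob per pixel

dec = Decoder(z_dim=50, hidden_dim=400)
z = torch.randn(16, 50)                # e.g. draws from the standard Normal prior
loc_img = dec(z)
assert loc_img.shape == (16, 784)
assert ((0 <= loc_img) & (loc_img <= 1)).all()   # valid Bernoulli parameters
```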
Registration with Pyro
Both networks must be registered with Pyro using pyro.module() to ensure their parameters are tracked:
```python
# In the model function:
pyro.module("decoder", self.decoder)

# In the guide function:
pyro.module("encoder", self.encoder)
```
This registration:
- Adds the module's parameters to the Pyro parameter store
- Allows the SVI optimizer to update the network weights
- Enables serialization and loading of trained models
Model Function (Generative Process)
```python
def model(self, x):
    # Register decoder with Pyro
    pyro.module("decoder", self.decoder)
    with pyro.plate("data", x.shape[0]):
        # Prior p(z) = Normal(0, I)
        z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
        z_scale = torch.ones(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        # Likelihood p(x|z) via decoder
        loc_img = self.decoder(z)
        pyro.sample(
            "obs",
            dist.Bernoulli(loc_img, validate_args=False).to_event(1),
            obs=x.reshape(-1, 784),
        )
        return loc_img
```
Key Elements
- `pyro.plate("data", ...)`: Declares that data points are conditionally independent given z
- `.to_event(1)`: Treats the z_dim dimensions as a single multivariate event (not independent batch dimensions)
- Standard Normal prior: p(z) = Normal(0, I) -- the simplest prior choice
- Bernoulli likelihood: Appropriate for binary image data (MNIST)
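Pyro's `.to_event(1)` wraps a distribution the same way `torch.distributions.Independent` does (Pyro distributions are built on `torch.distributions`), so its effect on batch and event shapes can be checked with plain PyTorch:

```python
import torch
from torch.distributions import Normal, Independent

loc = torch.zeros(16, 50)      # batch_size = 16, z_dim = 50
scale = torch.ones(16, 50)

base = Normal(loc, scale)
assert base.batch_shape == (16, 50) and base.event_shape == ()

# .to_event(1) in Pyro ~ Independent(base, 1): the last dim joins the
# event, so log_prob sums over z_dim instead of returning it per-element.
joint = Independent(base, 1)
assert joint.batch_shape == (16,) and joint.event_shape == (50,)

z = joint.sample()
assert joint.log_prob(z).shape == (16,)   # one log-density per data point
```

This is exactly why the model's `pyro.plate("data", ...)` sees one latent variable per data point: the plate indexes the remaining batch dimension, while the 50 latent dimensions form a single event.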
Guide Function (Variational Distribution)
```python
def guide(self, x):
    # Register encoder with Pyro
    pyro.module("encoder", self.encoder)
    with pyro.plate("data", x.shape[0]):
        # Use encoder to get variational parameters
        z_loc, z_scale = self.encoder(x)
        # Sample from variational distribution q(z|x)
        pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
```
Key Elements
- Amortization: The encoder maps each x_i to its own (z_loc_i, z_scale_i) in a single forward pass
- Site name matching: The guide's `"latent"` site must match the model's `"latent"` site
- Same plate structure: The guide's `pyro.plate("data", ...)` must match the model's
VAE Wrapper Class
The reference implementation wraps encoder, decoder, model, and guide in a single nn.Module:
```python
class VAE(nn.Module):
    def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
        super().__init__()
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)
        self.z_dim = z_dim
        if use_cuda:
            self.cuda()

    def model(self, x):
        ...  # uses self.decoder

    def guide(self, x):
        ...  # uses self.encoder

    def reconstruct_img(self, x):
        z_loc, z_scale = self.encoder(x)
        z = dist.Normal(z_loc, z_scale).sample()
        loc_img = self.decoder(z)
        return loc_img
```
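Note that `reconstruct_img` uses `.sample()`, which is fine at inference time because no gradients are needed. During training, Pyro's SVI instead draws reparameterized samples (the torch `rsample()` path) so that gradients flow back through `z_loc` and `z_scale` into the encoder. The difference is visible in plain PyTorch:

```python
import torch
from torch.distributions import Normal

z_loc = torch.zeros(16, 50, requires_grad=True)   # stand-in encoder output
q = Normal(z_loc, torch.ones(16, 50))

z_r = q.rsample()   # reparameterized draw: z = loc + scale * eps
z_s = q.sample()    # plain draw, computed under torch.no_grad()

assert z_r.requires_grad       # gradients can reach the encoder
assert not z_s.requires_grad   # fine for reconstruct_img, not for training
```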
Configuration Parameters
| Parameter | Default | Description |
|---|---|---|
| `z_dim` | 50 | Dimensionality of the latent space. Controls the capacity of the latent representation. |
| `hidden_dim` | 400 | Number of hidden units in the encoder and decoder. Controls network capacity. |
| `data_dim` | 784 | Dimensionality of the input data (28 * 28 for MNIST). |
| `use_cuda` | False | Whether to move model parameters to GPU. |
Training Loop
```python
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# assumes num_epochs and an MNIST train_loader are defined
vae = VAE(z_dim=50, hidden_dim=400)
optimizer = Adam({"lr": 1e-3})
elbo = Trace_ELBO()
svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for x, _ in train_loader:
        epoch_loss += svi.step(x)
    avg_loss = epoch_loss / len(train_loader.dataset)
    print(f"Epoch {epoch}: avg loss = {avg_loss:.4f}")
```
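Conceptually, each `svi.step(x)` draws z ~ q(z|x) with the reparameterization trick, evaluates log p(x|z) + log p(z) - log q(z|x), and takes one optimizer step on the negative ELBO. The torch-only sketch below (hypothetical toy networks, not the Pyro API) makes that loop explicit:

```python
import torch
import torch.nn as nn
from torch.distributions import Normal, Bernoulli, Independent

torch.manual_seed(0)
z_dim, hidden = 8, 32
enc = nn.Sequential(nn.Linear(784, hidden), nn.Softplus())
loc_head = nn.Linear(hidden, z_dim)
logstd_head = nn.Linear(hidden, z_dim)
dec = nn.Sequential(nn.Linear(z_dim, hidden), nn.Softplus(),
                    nn.Linear(hidden, 784), nn.Sigmoid())
params = (list(enc.parameters()) + list(loc_head.parameters())
          + list(logstd_head.parameters()) + list(dec.parameters()))
opt = torch.optim.Adam(params, lr=1e-3)

x = torch.rand(16, 784).round()            # fake binarized mini-batch

def elbo(x):
    h = enc(x)
    q = Independent(Normal(loc_head(h), torch.exp(logstd_head(h))), 1)
    z = q.rsample()                        # reparameterized sample of z
    prior = Independent(Normal(torch.zeros_like(z), torch.ones_like(z)), 1)
    lik = Independent(Bernoulli(probs=dec(z)), 1)
    # log p(x|z) + log p(z) - log q(z|x), summed over the batch
    return (lik.log_prob(x) + prior.log_prob(z) - q.log_prob(z)).sum()

loss = -elbo(x)                            # svi.step minimizes -ELBO
opt.zero_grad(); loss.backward(); opt.step()
print(float(loss))
```

Pyro's `Trace_ELBO` automates exactly this bookkeeping by tracing the model and guide, which is why the model and guide only need matching sample site names.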
Extending the Pattern
The Encoder-Decoder pattern is highly flexible. Common extensions include:
- Convolutional architectures: Replace linear layers with Conv2d/ConvTranspose2d for spatial data
- Different likelihoods: Use `dist.Normal` for continuous data, `dist.Categorical` for discrete data
- Conditional VAE (CVAE): Pass labels or other conditioning variables to both encoder and decoder
- Hierarchical VAE: Stack multiple encoder-decoder pairs for multi-scale latent representations
- Beta-VAE: Scale the KL term by a factor beta to control disentanglement
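For instance, the Beta-VAE extension reweights the KL term between the variational posterior and the prior. For diagonal Gaussians that term has a closed form, sketched here in plain PyTorch (in Pyro the same effect is usually achieved by scaling the latent sample site, e.g. with `pyro.poutine.scale`):

```python
import torch
from torch.distributions import Normal, kl_divergence

z_loc = torch.randn(16, 50)            # stand-in encoder outputs
z_scale = torch.rand(16, 50) + 0.5
q = Normal(z_loc, z_scale)
p = Normal(torch.zeros_like(z_loc), torch.ones_like(z_scale))

kl = kl_divergence(q, p).sum(-1)       # analytic KL per data point, shape (16,)
beta = 4.0                             # beta > 1 pressures q toward the prior
weighted_kl = beta * kl

assert kl.shape == (16,)
assert (kl > 0).all()                  # KL is non-negative; > 0 for random loc
```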
Related Pages
Implements Principle
- Pyro_ppl_Pyro_Amortized_Variational_Inference
Related Implementations
- Pyro_ppl_Pyro_Predictive_SVI -- Use `Predictive(vae.model, guide=vae.guide, num_samples=N)` to generate posterior predictive samples from a trained VAE.
- Pyro_ppl_Pyro_Config_Enumerate -- For VAEs with discrete latent variables (e.g., discrete mixture components), enumeration can be used to marginalize out discrete choices.
- Environment:Pyro_ppl_Pyro_CUDA_GPU_Acceleration