
Implementation:Pyro ppl Pyro VAE Encoder Decoder Pattern

From Leeroopedia


Metadata

Implementation ID: Pyro_ppl_Pyro_VAE_Encoder_Decoder_Pattern
Title: VAE Encoder-Decoder Pattern
Type: Pattern Doc (user-defined architecture, not a library API)
Project: Pyro (pyro-ppl/pyro)
Reference File: examples/vae/vae.py, lines 1-256
Implements: Pyro_ppl_Pyro_Amortized_Variational_Inference
Repository: https://github.com/pyro-ppl/pyro

Summary

The VAE Encoder-Decoder Pattern describes the canonical architecture for implementing Variational Autoencoders in Pyro. This is a user-defined pattern (not a fixed library API) consisting of an Encoder(nn.Module) and Decoder(nn.Module) that are registered with Pyro via pyro.module() and used within model and guide functions. The reference implementation in examples/vae/vae.py demonstrates this pattern for MNIST digit generation.

Pattern Overview

The VAE pattern consists of four components:

  1. Encoder (nn.Module): Maps data x to variational parameters (z_loc, z_scale)
  2. Decoder (nn.Module): Maps latent code z to reconstruction parameters
  3. Model function: Defines the generative process p(x, z) = p(x|z)p(z) using the Decoder
  4. Guide function: Defines the variational distribution q(z|x) using the Encoder

Encoder Interface

import torch
import torch.nn as nn

class Encoder(nn.Module):
    """
    Maps input data to variational parameters for the approximate posterior.

    Input: x of shape (batch_size, data_dim)
    Output: (z_loc, z_scale) each of shape (batch_size, z_dim)
    """
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # Define layers
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)   # mean head
        self.fc22 = nn.Linear(hidden_dim, z_dim)   # log-std head
        self.softplus = nn.Softplus()

    def forward(self, x):
        x = x.reshape(-1, 784)
        hidden = self.softplus(self.fc1(x))
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale

Key Design Decisions

  • Two output heads: One for the mean (fc21) and one for the scale (fc22) of the variational Gaussian
  • Positive scale: The scale (standard deviation) is ensured to be positive via torch.exp() applied to the raw network output
  • Softplus activation: The reference implementation uses Softplus rather than ReLU for the hidden layer; it behaves similarly but has smooth, everywhere-nonzero gradients
  • Reshape input: MNIST images (28x28) are flattened to 784-dimensional vectors
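A quick sanity check of the interface above (the class is repeated so the snippet runs standalone; the batch size of 16 is arbitrary):

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)   # mean head
        self.fc22 = nn.Linear(hidden_dim, z_dim)   # log-std head
        self.softplus = nn.Softplus()

    def forward(self, x):
        x = x.reshape(-1, 784)
        hidden = self.softplus(self.fc1(x))
        return self.fc21(hidden), torch.exp(self.fc22(hidden))

enc = Encoder(z_dim=50, hidden_dim=400)
x = torch.rand(16, 1, 28, 28)        # a batch of 16 MNIST-sized images
z_loc, z_scale = enc(x)
assert z_loc.shape == (16, 50)       # one mean vector per data point
assert z_scale.shape == (16, 50)
assert (z_scale > 0).all()           # exp() guarantees a valid positive scale
```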

Decoder Interface

class Decoder(nn.Module):
    """
    Maps latent codes to reconstruction parameters for the observation model.

    Input: z of shape (batch_size, z_dim)
    Output: loc_img of shape (batch_size, data_dim)
    """
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)
        self.softplus = nn.Softplus()

    def forward(self, z):
        hidden = self.softplus(self.fc1(z))
        loc_img = torch.sigmoid(self.fc21(hidden))
        return loc_img

Key Design Decisions

  • Sigmoid output: Produces values in [0, 1] suitable for Bernoulli likelihood (binary pixel values)
  • Architecture mirrors encoder: Symmetric structure with one hidden layer
  • Output represents distribution parameters: The output is the Bernoulli probability for each pixel, not the pixel value itself
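The same kind of check for the decoder confirms that its output lands in [0, 1], i.e. valid Bernoulli probabilities (again repeated so it runs standalone):

```python
import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)
        self.softplus = nn.Softplus()

    def forward(self, z):
        hidden = self.softplus(self.fc1(z))
        return torch.sigmoid(self.fc21(hidden))

dec = Decoder(z_dim=50, hidden_dim=400)
z = torch.randn(16, 50)              # a batch of latent codes
loc_img = dec(z)
assert loc_img.shape == (16, 784)
# sigmoid bounds every pixel probability to [0, 1]
assert ((loc_img >= 0) & (loc_img <= 1)).all()
```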

Registration with Pyro

Both networks must be registered with Pyro using pyro.module() to ensure their parameters are tracked:

# In the model function:
pyro.module("decoder", self.decoder)

# In the guide function:
pyro.module("encoder", self.encoder)

This registration:

  • Adds the module's parameters to the Pyro parameter store
  • Allows the SVI optimizer to update the network weights
  • Enables serialization and loading of trained models

Model Function (Generative Process)

def model(self, x):
    # Register decoder with Pyro
    pyro.module("decoder", self.decoder)
    with pyro.plate("data", x.shape[0]):
        # Prior p(z) = Normal(0, I)
        z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
        z_scale = torch.ones(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        # Likelihood p(x|z) via decoder
        loc_img = self.decoder(z)
        pyro.sample(
            "obs",
            dist.Bernoulli(loc_img, validate_args=False).to_event(1),
            obs=x.reshape(-1, 784),
        )
        return loc_img

Key Elements

  • pyro.plate("data", ...): Declares that data points are conditionally independent given z
  • .to_event(1): Treats the z_dim dimensions as a single multivariate event (not independent samples)
  • Standard Normal prior: p(z) = Normal(0, I) -- the simplest prior choice
  • Bernoulli likelihood: Appropriate for binary image data (MNIST)

Guide Function (Variational Distribution)

def guide(self, x):
    # Register encoder with Pyro
    pyro.module("encoder", self.encoder)
    with pyro.plate("data", x.shape[0]):
        # Use encoder to get variational parameters
        z_loc, z_scale = self.encoder(x)
        # Sample from variational distribution q(z|x)
        pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

Key Elements

  • Amortization: The encoder maps each x_i to its own (z_loc_i, z_scale_i) in a single forward pass
  • Site name matching: The guide's "latent" site must match the model's "latent" site
  • Same plate structure: The guide's pyro.plate("data", ...) must match the model's

VAE Wrapper Class

The reference implementation wraps encoder, decoder, model, and guide in a single nn.Module:

class VAE(nn.Module):
    def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
        super().__init__()
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)
        self.z_dim = z_dim
        if use_cuda:
            self.cuda()

    def model(self, x):
        ...  # uses self.decoder

    def guide(self, x):
        ...  # uses self.encoder

    def reconstruct_img(self, x):
        z_loc, z_scale = self.encoder(x)
        z = dist.Normal(z_loc, z_scale).sample()
        loc_img = self.decoder(z)
        return loc_img

Configuration Parameters

z_dim (default: 50): Dimensionality of the latent space. Controls the capacity of the latent representation.
hidden_dim (default: 400): Number of hidden units in the encoder and decoder. Controls network capacity.
data_dim (default: 784): Dimensionality of the input data (28 * 28 for MNIST).
use_cuda (default: False): Whether to move model parameters to GPU.

Training Loop

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

vae = VAE(z_dim=50, hidden_dim=400)
optimizer = Adam({"lr": 1e-3})
elbo = Trace_ELBO()
svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for x, _ in train_loader:
        epoch_loss += svi.step(x)
    avg_loss = epoch_loss / len(train_loader.dataset)
    print(f"Epoch {epoch}: avg loss = {avg_loss:.4f}")

Extending the Pattern

The Encoder-Decoder pattern is highly flexible. Common extensions include:

  • Convolutional architectures: Replace linear layers with Conv2d/ConvTranspose2d for spatial data
  • Different likelihoods: Use dist.Normal for continuous data, dist.Categorical for discrete data
  • Conditional VAE (CVAE): Pass labels or other conditioning variables to both encoder and decoder
  • Hierarchical VAE: Stack multiple encoder-decoder pairs for multi-scale latent representations
  • Beta-VAE: Scale the KL term by a factor beta to control disentanglement
