Workflow:Pyro ppl Pyro VAE Training

From Leeroopedia


Knowledge Sources
Domains Probabilistic_Programming, Deep_Generative_Models, Variational_Autoencoders
Last Updated 2026-02-09 09:00 GMT

Overview

End-to-end process for training a Variational Autoencoder (VAE) in Pyro, combining deep neural networks with probabilistic inference for learning latent representations of data.

Description

This workflow describes the standard procedure for building and training a VAE using Pyro's probabilistic programming framework. A VAE consists of a generative model (decoder) that maps latent variables to observations and an inference network (encoder) that amortizes posterior inference by mapping observations to approximate posterior distributions over latent variables. Pyro's framework naturally accommodates this by treating the decoder as the model and the encoder as the guide. The process covers defining encoder and decoder architectures, registering neural network modules with Pyro, constructing the ELBO objective, training with minibatched data, and evaluating the learned latent representations.

Usage

Execute this workflow when you need to learn a compressed latent representation of high-dimensional data (images, text, sequences) while maintaining a probabilistic generative model. This is appropriate for unsupervised learning tasks including representation learning, data generation, anomaly detection, and semi-supervised classification. The VAE framework is especially useful when you need both a generative model and an inference mechanism, and when the latent space should have a meaningful structure.

Execution Steps

Step 1: Define Encoder and Decoder Networks

Build PyTorch neural network modules for the encoder (recognition network) and decoder (generative network). The encoder takes observations and outputs the parameters of the variational posterior (typically the mean and standard deviation of a diagonal Gaussian). The decoder takes a latent code and outputs the parameters of the observation distribution (e.g., pixel probabilities for Bernoulli observations). Both are standard torch.nn.Module subclasses.

Key considerations:

  • Encoder outputs variational parameters (loc and scale for Normal posterior)
  • Decoder outputs observation likelihood parameters
  • Network architecture depends on data type (MLP for tabular, CNN for images)
  • Use softplus or exp for positive-valued outputs like scale parameters
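A minimal sketch of the two networks described above, assuming flattened 784-dimensional inputs, one hidden layer of 400 units, and a 20-dimensional latent space. The class names Encoder and Decoder and all layer sizes are illustrative assumptions, not values taken from the workflow repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Maps an observation x to (loc, scale) of a diagonal Normal over z."""

    def __init__(self, x_dim=784, hidden_dim=400, z_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(x_dim, hidden_dim)
        self.fc_loc = nn.Linear(hidden_dim, z_dim)
        self.fc_scale = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        h = F.softplus(self.fc1(x))
        z_loc = self.fc_loc(h)
        # softplus keeps the scale strictly positive; the small offset
        # guards against underflow to exactly zero
        z_scale = F.softplus(self.fc_scale(h)) + 1e-6
        return z_loc, z_scale


class Decoder(nn.Module):
    """Maps a latent code z to Bernoulli probabilities, one per input dimension."""

    def __init__(self, z_dim=20, hidden_dim=400, x_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, x_dim)

    def forward(self, z):
        h = F.softplus(self.fc1(z))
        # sigmoid maps the outputs into (0, 1) so they can serve as probabilities
        return torch.sigmoid(self.fc_out(h))
```

For image data a convolutional encoder and decoder could replace the linear layers without changing the interface: the encoder still returns a (loc, scale) pair and the decoder still returns likelihood parameters.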

Step 2: Define the Generative Model

Write the model function that specifies the generative process. Register the decoder network with Pyro using pyro.module so its parameters are tracked. Sample latent variables from the prior distribution (typically a standard Normal), pass them through the decoder to get observation distribution parameters, and score the observed data against this distribution. Use pyro.plate to vectorize over the data batch dimension.

Key considerations:

  • Register neural networks with pyro.module for parameter tracking
  • The prior is typically a unit Normal: z ~ N(0, I)
  • Observation model depends on data type (Bernoulli for binary, Normal for continuous)
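  • pyro.plate declares the data points conditionally independent; when given the full dataset size together with a subsample_size, it also rescales the minibatch ELBO accordingly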

Step 3: Define the Inference Guide

Write the guide function that specifies the amortized variational posterior. Register the encoder network with pyro.module. For each data point, pass the observation through the encoder to get variational parameters, then sample from the variational posterior. The guide must produce samples for every latent variable in the model with matching site names.

Key considerations:

  • Guide sample site names must exactly match model latent variable names
  • Amortization means the encoder directly predicts posterior parameters per data point
  • This eliminates the need for per-data-point variational parameters
  • The reparameterization trick (applied automatically by Pyro for reparameterizable distributions such as Normal) enables backpropagation through sampling

Step 4: Configure Training Infrastructure

Set up the ELBO loss (Trace_ELBO or JitTrace_ELBO), optimizer (typically Adam), and SVI instance. Create PyTorch DataLoaders for efficient minibatch iteration over the training and test datasets. Clear the parameter store before training.

Key considerations:

  • Trace_ELBO is the standard choice; JitTrace_ELBO compiles the loss with the PyTorch JIT, which can speed up models with static structure
  • Adam optimizer with learning rate around 1e-3 is typical
  • DataLoader handles minibatching and shuffling
  • Clear param store for clean initialization

Step 5: Run the Training Loop

Iterate over epochs and minibatches, calling svi.step() on each batch. Track the training ELBO loss per epoch. Periodically evaluate on test data using svi.evaluate_loss() to monitor generalization. The ELBO loss naturally decomposes into a reconstruction term (how well the decoder reproduces the input) and a KL divergence term (how close the posterior is to the prior).
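The decomposition mentioned above can be written out for a single data point as follows. This is a sketch in common VAE notation; the symbols θ (decoder parameters) and φ (encoder parameters) are notational assumptions, not names used by this workflow.

```latex
\mathcal{L}(x) = -\mathrm{ELBO}(x)
  = \underbrace{-\,\mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right]}_{\text{reconstruction term}}
  + \underbrace{\mathrm{KL}\!\left(q_\phi(z \mid x) \,\|\, p(z)\right)}_{\text{KL term}}
```

The value returned by svi.step() is an estimate of this loss summed over the batch, so lower is better. Trace_ELBO estimates it from samples drawn from the guide, which is why reported values fluctuate from step to step.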

Key considerations:

  • Normalize loss by number of data points for interpretable metrics
  • Monitor both train and test ELBO for overfitting detection
  • Training typically requires tens to hundreds of epochs
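  • Loss should trend downward in early training; individual batches are noisy because the ELBO is estimated from samples, so a strictly monotonic decrease is not expected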

Step 6: Evaluate Latent Representations

After training, use the encoder to map data to the latent space and evaluate the learned representations. Visualize the latent space using dimensionality reduction (t-SNE or PCA). Generate new data by sampling from the prior and decoding. Perform reconstruction quality assessment by encoding then decoding held-out data.

Key considerations:

  • When labels are available, points colored by label should form visibly separated clusters in the latent space
  • Samples from the prior decoded through the model should produce realistic data
  • Reconstruction quality indicates how much information is captured
  • The Predictive class can be used for structured posterior predictive generation
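The three evaluation steps above can be sketched in plain PyTorch. The untrained stand-in networks, the toy dimensions, and the use of torch.pca_lowrank for the 2-D projection are assumptions made so the example is self-contained; in practice the trained encoder and decoder from the earlier steps would be used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
X_DIM, H_DIM, Z_DIM = 64, 32, 4  # hypothetical toy sizes


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(X_DIM, H_DIM)
        self.loc = nn.Linear(H_DIM, Z_DIM)
        self.scale = nn.Linear(H_DIM, Z_DIM)

    def forward(self, x):
        h = F.softplus(self.hidden(x))
        return self.loc(h), F.softplus(self.scale(h)) + 1e-6


# Untrained stand-ins; substitute the trained modules in real use
encoder = Encoder()
decoder = nn.Sequential(
    nn.Linear(Z_DIM, H_DIM), nn.Softplus(), nn.Linear(H_DIM, X_DIM), nn.Sigmoid()
)

# Synthetic held-out data standing in for a real test set
x_test = torch.bernoulli(torch.full((100, X_DIM), 0.3))

with torch.no_grad():
    # 1. Embed held-out data: use the posterior mean as the representation
    z_loc, z_scale = encoder(x_test)

    # 2. Project the latent means to 2-D with PCA for plotting
    _, _, v = torch.pca_lowrank(z_loc, q=2)
    z_2d = (z_loc - z_loc.mean(dim=0)) @ v

    # 3. Generate new data: sample z from the N(0, I) prior and decode
    z_prior = torch.randn(16, Z_DIM)
    generated = decoder(z_prior)

    # 4. Reconstruct: decode the posterior means and score against the inputs
    recon = decoder(z_loc)
    recon_bce = F.binary_cross_entropy(recon, x_test, reduction="sum") / len(x_test)
```

Here z_2d can be passed to any scatter-plot routine, and recon_bce is a per-data-point reconstruction error in nats.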

GitHub URL

Workflow Repository