

Workflow:Pyro ppl Pyro SVI Training

From Leeroopedia
Revision as of 11:01, 16 February 2026 by Admin (auto-imported from workflows/Pyro_ppl_Pyro_SVI_Training.md)


Domains: Probabilistic_Programming, Variational_Inference, Bayesian_Modeling
Last Updated: 2026-02-09 09:00 GMT

Overview

End-to-end process for performing Stochastic Variational Inference (SVI) in Pyro, the primary optimization-based inference method for fitting probabilistic models to data.

Description

This workflow describes the standard procedure for approximate Bayesian inference using Pyro's SVI engine. The process takes a probabilistic model (defining the joint distribution over latent variables and observations) and a guide (the variational approximation), then optimizes the guide parameters to maximize the evidence lower bound (ELBO), which is equivalent to minimizing the KL divergence from the guide to the true posterior. It covers defining model and guide functions using Pyro's primitives, selecting an appropriate ELBO objective, configuring the optimizer, running the training loop, and extracting posterior estimates. This is the most fundamental inference pattern in Pyro and serves as the foundation for more advanced techniques.
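For reference, the objective can be written out explicitly. Writing $x$ for the observations, $z$ for the latent variables, $p_\theta(x, z)$ for the model and $q_\phi(z)$ for the guide (notation introduced here, not taken from the workflow), the ELBO decomposes as:

```latex
\mathrm{ELBO}(\theta, \phi)
  = \mathbb{E}_{q_\phi(z)}\big[\log p_\theta(x, z) - \log q_\phi(z)\big]
  = \log p_\theta(x) - \mathrm{KL}\big(q_\phi(z) \,\|\, p_\theta(z \mid x)\big)
```

Because $\log p_\theta(x)$ does not depend on $\phi$, maximizing the ELBO over the guide parameters minimizes the KL divergence; Pyro's ELBO classes report the negative ELBO as a loss to be minimized.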

Usage

Execute this workflow when you have a probabilistic model with continuous latent variables and want to perform fast, scalable approximate posterior inference. This is appropriate when exact inference is intractable, MCMC is too slow for the data size, or you need gradient-based optimization of model parameters alongside inference. Typical use cases include fitting hierarchical Bayesian models, learning parameters of latent variable models, and any setting where you need a parametric approximation to the posterior.

Execution Steps

Step 1: Define the Probabilistic Model

Write a Python function (or callable) that specifies the joint distribution over latent variables and observed data using Pyro's probabilistic primitives. The model function uses pyro.sample to declare random variables with their prior distributions, pyro.plate to declare conditional independence structure for vectorized computation and data subsampling, and pyro.param for any learnable non-Bayesian parameters. Observed data is incorporated by passing the obs keyword argument to sample statements.

Key considerations:

  • Each sample site must have a unique string name
  • Use pyro.plate for data-parallel dimensions to enable minibatching
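  • The distribution at each observed site defines the likelihood, so it should match the type and support of the data (e.g. Bernoulli for binary outcomes)
  • Parameters of neural network components (torch.nn.Module) are registered with pyro.module so that the optimizer can update them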

Step 2: Define the Variational Guide

Write a guide function that specifies the variational family, i.e. the parametric approximation to the posterior. The guide is called with the same arguments as the model and must contain a matching sample statement for every unobserved sample site in the model, with the same name and a compatible shape and support. Guide parameters are declared using pyro.param with appropriate constraints. Alternatively, use one of Pyro's AutoGuide classes (AutoNormal, AutoMultivariateNormal, AutoDelta, AutoDiagonalNormal), which construct a guide automatically from the model structure.

Key considerations:

  • Guide sample sites must exactly match model latent variable names
  • Use constraints (e.g., constraints.positive) on parameters like scale
  • AutoGuide classes eliminate the need for manual guide construction
  • For complex posteriors, consider normalizing flow-based guides (AutoNormalizingFlow)

Step 3: Select the ELBO Objective

Choose the Evidence Lower Bound (ELBO) estimator that suits the model. Trace_ELBO is the standard Monte Carlo estimator and uses a single sample per step by default. TraceMeanField_ELBO uses analytic KL divergence terms where they are available, which lowers variance for mean-field guides. TraceGraph_ELBO exploits the model's conditional-independence structure to Rao-Blackwellize gradient estimates for non-reparameterizable (e.g. discrete) sample sites. JIT variants (prefixed with Jit, e.g. JitTrace_ELBO) compile the loss for faster execution.

Key considerations:

  • Trace_ELBO is the usual starting point; it also accepts non-reparameterizable (e.g. discrete) sites, but their gradient estimates have high variance
  • TraceMeanField_ELBO reduces variance when analytic KL divergences are available
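  • TraceGraph_ELBO is beneficial for models with non-reparameterizable latent variables, where it uses the dependency structure to reduce gradient variance
  • JIT compilation provides a speedup but requires a model and guide with static structure, with tensor inputs passed as positional arguments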
  • num_particles controls the number of samples for gradient estimation

Step 4: Configure the Optimizer

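Wrap a PyTorch optimizer with Pyro's optimizer interface. Pyro provides wrappers for the standard PyTorch optimizers (Adam, SGD, etc.) plus specialized variants such as ClippedAdam (with gradient clipping and learning-rate decay) and AdagradRMSProp. The wrapper creates optimizer state lazily for each parameter the first time it is seen, so parameters that first appear in the model or guide during training are handled automatically. Optionally, configure a learning rate scheduler for annealing; schedulers are wrapped the same way (e.g. pyro.optim.ExponentialLR) and must be stepped explicitly in the training loop.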

Key considerations:

  • Adam with learning rate 1e-3 to 1e-2 is a common starting point
  • ClippedAdam provides gradient clipping for stability
  • Learning rate schedulers can improve convergence
  • Pyro optimizers handle per-parameter state automatically

Step 5: Initialize SVI and Clear Parameter Store

Create the SVI object by passing the model, guide, optimizer, and loss function. Before training, call pyro.clear_param_store() to reset all learned parameters. This ensures a clean start and prevents contamination from previous runs.

Key considerations:

  • Always clear the param store before a new training run
  • The optimizer passed to SVI keeps its state (e.g. Adam moment estimates) across steps, while parameter values live in Pyro's global param store
  • For reproducibility, set random seeds (e.g. with pyro.set_rng_seed) before initialization

Step 6: Run the Training Loop

Iterate over the dataset, calling svi.step() on each minibatch. Each call computes the ELBO loss, backpropagates gradients, and updates parameters. Track the loss over time to monitor convergence. Periodically evaluate on held-out data using svi.evaluate_loss() (which computes loss without gradient updates).

Key considerations:

  • svi.step() returns the loss value for the current minibatch
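  • Divide the loss by the number of data points it represents (the full dataset size when pyro.plate rescales a subsampled minibatch) to get a per-datum value that is comparable across batch sizes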
  • Monitor for divergence (loss increasing or NaN values)
  • Use evaluate_loss() for validation without gradient computation
  • Training typically requires hundreds to thousands of iterations

Step 7: Extract Posterior Estimates

After training, extract the learned variational parameters from the param store or use the Predictive utility to generate posterior samples. Parameters can be accessed via pyro.param() calls or pyro.get_param_store(). For predictions, construct a Predictive object with the trained model and guide to generate posterior predictive samples.

Key considerations:

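  • The guide, evaluated at its learned parameters, is the approximation to the posterior distribution
  • Use Predictive for structured posterior predictive sampling
  • Parameters can be saved and restored with pyro.get_param_store().save(path) and pyro.get_param_store().load(path)
  • Point estimates can be read from location parameters (e.g. the loc of a Normal guide) or, for AutoGuides such as AutoNormal, obtained with guide.median()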
