Workflow:Pyro ppl Pyro SVI Training
| Knowledge Sources | |
|---|---|
| Domains | Probabilistic_Programming, Variational_Inference, Bayesian_Modeling |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
End-to-end process for performing Stochastic Variational Inference (SVI) in Pyro, the primary optimization-based inference method for fitting probabilistic models to data.
Description
This workflow describes the standard procedure for approximate Bayesian inference using Pyro's SVI engine. The process takes a probabilistic model (defining the joint distribution over latent variables and observations) and a guide (the variational approximation), then optimizes the guide parameters by maximizing the ELBO, which is equivalent to minimizing the KL divergence from the guide to the true posterior. It covers defining model and guide functions using Pyro's primitives, selecting an appropriate ELBO objective, configuring the optimizer, running the training loop, and extracting posterior estimates. This is the most fundamental inference pattern in Pyro and serves as the foundation for more advanced techniques.
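The relationship between the ELBO and the KL divergence can be written out as follows. The notation is an assumption introduced here for illustration: x denotes the observed data, z the latent variables, p_\theta the model, and q_\phi the guide.

```latex
\log p_\theta(x)
  = \underbrace{\mathbb{E}_{q_\phi(z)}\!\left[\log p_\theta(x, z) - \log q_\phi(z)\right]}_{\mathrm{ELBO}(\theta,\phi)}
  + \mathrm{KL}\!\left(q_\phi(z)\,\|\,p_\theta(z \mid x)\right)
```

Because the left-hand side does not depend on \phi, raising the ELBO with respect to the guide parameters lowers the KL term by the same amount. The loss reported by SVI is the negative of a Monte Carlo estimate of the ELBO.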
Usage
Execute this workflow when you have a probabilistic model with continuous latent variables and want to perform fast, scalable approximate posterior inference. This is appropriate when exact inference is intractable, MCMC is too slow for the data size, or you need gradient-based optimization of model parameters alongside inference. Typical use cases include fitting hierarchical Bayesian models, learning parameters of latent variable models, and any setting where you need a parametric approximation to the posterior.
Execution Steps
Step 1: Define the Probabilistic Model
Write a Python function (or callable) that specifies the joint distribution over latent variables and observed data using Pyro's probabilistic primitives. The model function uses pyro.sample to declare random variables with their prior distributions, pyro.plate to declare conditional independence structure for vectorized computation and data subsampling, and pyro.param for any learnable non-Bayesian parameters. Observed data is incorporated by passing the obs keyword argument to sample statements.
Key considerations:
- Each sample site must have a unique string name
- Use pyro.plate for data-parallel dimensions to enable minibatching
- The observation distribution should match the type and support of the data (for example, Bernoulli for binary outcomes, Normal for real-valued ones)
- Model parameters can be registered via pyro.module for neural network components
Step 2: Define the Variational Guide
Write a guide function that specifies the variational family, that is, the parametric approximation to the posterior. The guide must accept the same arguments as the model and contain a matching sample statement for every unobserved sample site in the model (same name, same shape). Guide parameters are declared using pyro.param with appropriate constraints. Alternatively, use one of Pyro's AutoGuide classes in pyro.infer.autoguide (AutoNormal, AutoMultivariateNormal, AutoDelta, AutoDiagonalNormal), which construct a guide automatically from the model structure.
Key considerations:
- Guide sample sites must exactly match model latent variable names
- Use constraints (e.g., constraints.positive) on parameters like scale
- AutoGuide classes eliminate the need for manual guide construction
- For complex posteriors, consider normalizing flow-based guides (AutoNormalizingFlow)
Step 3: Select the ELBO Objective
Choose the appropriate Evidence Lower Bound (ELBO) estimator for the model characteristics. Trace_ELBO is the standard single-sample estimator suitable for most continuous models. TraceMeanField_ELBO uses analytic KL terms where available for lower variance. TraceGraph_ELBO exploits conditional independence for Rao-Blackwellized gradient estimates. JIT variants (prefixed with Jit) compile the loss for faster execution.
Key considerations:
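- Trace_ELBO is the usual default and a reasonable first choice for models with continuous latent variables
- TraceMeanField_ELBO reduces variance when analytic KL divergences are available, but it assumes a mean-field guide whose latent sites are all reparameterized
- TraceGraph_ELBO is most useful when the guide contains non-reparameterizable (for example, discrete) sample sites, where its finer-grained use of dependency structure can lower gradient variance
- JIT compilation provides speedup but requires compatible model structure
- num_particles controls the number of samples used per gradient estimate; more particles lower variance at a higher cost per step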
Step 4: Configure the Optimizer
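Wrap a PyTorch optimizer with Pyro's optimizer interface, passing the optimizer arguments as a dictionary (for example {"lr": 1e-3}). Pyro provides wrappers for the standard PyTorch optimizers (Adam, SGD, etc.) plus specialized variants such as ClippedAdam (with gradient clipping) and AdagradRMSProp. Because Pyro parameters can be created dynamically during execution, the wrapper creates and tracks optimizer state for each parameter as it appears, covering both model and guide parameters. Optionally configure a learning rate scheduler for annealing.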
Key considerations:
- Adam with learning rate 1e-3 to 1e-2 is a common starting point
- ClippedAdam provides gradient clipping for stability
- Learning rate schedulers can improve convergence
- Pyro optimizers handle per-parameter state automatically
Step 5: Initialize SVI and Clear Parameter Store
Create the SVI object by passing the model, guide, optimizer, and loss function. Before training, call pyro.clear_param_store() to reset all learned parameters. This ensures a clean start and prevents contamination from previous runs.
Key considerations:
- Always clear the param store before a new training run
- The SVI object is stateful: its optimizer keeps per-parameter state across steps
- For reproducibility, set random seeds before initialization
Step 6: Run the Training Loop
Iterate over the dataset, calling svi.step() on each minibatch. Each call computes the ELBO loss, backpropagates gradients, and updates parameters. Track the loss over time to monitor convergence. Periodically evaluate on held-out data using svi.evaluate_loss() (which computes loss without gradient updates).
Key considerations:
- svi.step() returns the loss value for the current minibatch
- Normalize loss by dataset size for comparable metrics across batch sizes
- Monitor for divergence (loss increasing or NaN values)
- Use evaluate_loss() for validation without gradient computation
- Training typically requires hundreds to thousands of iterations
Step 7: Extract Posterior Estimates
After training, extract the learned variational parameters from the param store or use the Predictive utility to generate posterior samples. Parameters can be accessed via pyro.param() calls or pyro.get_param_store(). For predictions, construct a Predictive object with the trained model and guide to generate posterior predictive samples.
Key considerations:
- Guide parameters approximate the posterior distribution
- Use Predictive for structured posterior predictive sampling
- Parameters can be saved/loaded via param store serialization
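- Point estimates can be read from the guide parameters (for example, the location parameter of a Normal guide) or, for AutoGuides such as AutoNormal, from the guide's median() method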