Principle: Pyro SVI Training Loop
Metadata
| Field | Value |
|---|---|
| Page Type | Principle |
| Knowledge Sources | Repo (Pyro) |
| Domains | Variational_Inference, Optimization |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
The iterative training loop pattern for Stochastic Variational Inference, where each iteration computes the ELBO, backpropagates gradients, and updates variational parameters through the step() method, with optional loss-only evaluation via evaluate_loss() for convergence monitoring.
Description
The SVI training loop is the standard pattern for performing variational inference in Pyro. It centers on two complementary operations provided by the SVI class: step() for training and evaluate_loss() for monitoring.
The step() Method: Training
The step() method is the workhorse of the SVI training loop. Each call performs a complete optimization iteration that both computes the ELBO loss and applies parameter updates in a single call. Internally, it executes the following sequence:
- Trace parameter sites: Uses `poutine.trace(param_only=True)` to discover all parameter sites by running the loss computation. This captures every `pyro.param` site accessed during model and guide execution.
- Compute loss and gradients: Calls `self.loss_and_grads(model, guide, *args, **kwargs)`, which estimates the ELBO and backpropagates gradients through all registered parameters in a single operation.
- Collect unconstrained parameters: Gathers all parameter values from the Pyro parameter store in their unconstrained (transformed) form, which is the space in which optimization occurs.
- Apply optimizer: Calls `self.optim(params)`, which applies the configured PyTorch optimizer (e.g., Adam) to update parameter values based on accumulated gradients.
- Zero gradients: Calls `zero_grads(params)` to reset all gradients to zero, preparing for the next iteration.
The method returns the scalar loss estimate as a float, enabling the caller to track training progress.
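The sequence above can be sketched on a toy problem in plain Python. This is a schematic of the loss-and-grads / optimize / zero-grads cycle, not Pyro's actual implementation; the helper names mirror the steps, and the quadratic loss and learning rate are illustrative choices:

```python
# Schematic of SVI.step()'s internal cycle on a toy problem:
# minimize loss(x) = (x - 3)^2 by manual gradient descent.
# All helpers are illustrative stand-ins, not Pyro internals.

params = {"x": {"value": 0.0, "grad": 0.0}}

def loss_and_grads():
    """Compute the loss and accumulate gradients (analytic here)."""
    x = params["x"]["value"]
    params["x"]["grad"] += 2.0 * (x - 3.0)
    return (x - 3.0) ** 2

def optim(ps, lr=0.1):
    """Apply a gradient-descent update to every parameter."""
    for p in ps.values():
        p["value"] -= lr * p["grad"]

def zero_grads(ps):
    """Reset accumulated gradients for the next iteration."""
    for p in ps.values():
        p["grad"] = 0.0

def step():
    loss = loss_and_grads()  # compute loss and backprop-style grads
    optim(params)            # update parameters from gradients
    zero_grads(params)       # clear gradients for the next call
    return float(loss)

losses = [step() for _ in range(50)]
```

As in Pyro, each `step()` returns the scalar loss computed *before* the update, so the first entry reflects the initial parameters.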
The evaluate_loss() Method: Monitoring
The evaluate_loss() method computes the ELBO loss without computing or applying gradients. It wraps the loss computation in a torch.no_grad() context, which:
- Avoids the computational cost of building the autograd graph.
- Prevents any parameter updates from occurring.
- Provides a clean estimate of the current ELBO for monitoring purposes.
This method is essential for convergence monitoring and validation, as it allows measuring the ELBO on held-out data without affecting the training state.
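Assuming PyTorch is available, the effect of the `no_grad()` context can be seen directly: a loss computed inside it carries no autograd graph, so nothing can be backpropagated and no parameters can change.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

# Training-style computation: builds the autograd graph.
train_loss = (x - 1.0) ** 2
print(train_loss.requires_grad)  # True

# evaluate_loss-style computation: no graph, no side effects.
with torch.no_grad():
    eval_loss = (x - 1.0) ** 2
print(eval_loss.requires_grad)  # False
```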
Training Loop Patterns
The SVI training loop typically follows one of these patterns:
Basic Loop
The simplest pattern iterates for a fixed number of steps, calling step() each iteration:
- Initialize the parameter store with `pyro.clear_param_store()`.
- Construct the SVI object with model, guide, optimizer, and loss.
- Loop for N iterations, calling `svi.step(data)` and optionally logging the returned loss.
Loop with Validation
A more robust pattern monitors convergence using evaluate_loss():
- Periodically compute `svi.evaluate_loss(val_data)` on validation data.
- Track whether the validation loss is improving.
- Apply early stopping when validation loss stops decreasing.
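The validation pattern can be sketched in plain Python. Here `evaluate` is a hypothetical stand-in for `svi.evaluate_loss(val_data)`, simulated as a noisy loss that decays and then plateaus; the patience window and improvement threshold are illustrative:

```python
import random

random.seed(0)

def evaluate(step_num):
    """Stand-in for svi.evaluate_loss(val_data): a noisy loss
    that plateaus around 1.0 after an initial decay."""
    return 1.0 + 5.0 * 0.97 ** step_num + random.gauss(0.0, 0.05)

best = float("inf")
patience, bad_checks = 10, 0
stopped_at = None

for step_num in range(2000):
    # ... svi.step(train_data) would go here ...
    if step_num % 5 == 0:              # periodic validation check
        val_loss = evaluate(step_num)
        if val_loss < best - 1e-3:     # meaningful improvement
            best, bad_checks = val_loss, 0
        else:
            bad_checks += 1
        if bad_checks >= patience:     # early stopping triggers
            stopped_at = step_num
            break
```

Because each evaluation is noisy, the improvement threshold and patience window together prevent a single lucky or unlucky estimate from driving the stopping decision.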
Key Considerations
- Convergence monitoring: The ELBO returned by `step()` is a noisy Monte Carlo estimate. Smoothing (e.g., an exponential moving average) is recommended for assessing convergence trends.
- Learning rate scheduling: Pyro supports learning rate schedulers (e.g., `PyroLRScheduler`) that can be stepped alongside the SVI loop to reduce the learning rate over time for finer convergence.
- Early stopping: Since the ELBO is stochastic, early stopping criteria should be based on smoothed loss values with a patience window to avoid premature termination due to noise.
- Parameter initialization: Calling `pyro.clear_param_store()` before constructing SVI ensures a clean slate. Failing to do so can carry over stale parameters from previous runs.
- Mini-batch training: When using `pyro.plate` with subsampling, each `step()` call processes a different mini-batch, and the ELBO is automatically scaled to account for the full dataset size.
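The smoothing recommendation above can be sketched with a plain exponential moving average over a simulated noisy loss trace (the decay curve and noise level are illustrative):

```python
import random

random.seed(1)

# Simulated noisy per-step ELBO losses: true trend decays from 10 to ~2.
raw = [2.0 + 8.0 * 0.99 ** t + random.gauss(0.0, 0.5) for t in range(500)]

# Exponential moving average: ema_t = beta * ema_{t-1} + (1 - beta) * loss_t
beta = 0.95
ema = [raw[0]]
for loss in raw[1:]:
    ema.append(beta * ema[-1] + (1.0 - beta) * loss)
```

Reading convergence off `ema` rather than `raw` avoids reacting to single-sample Monte Carlo noise; larger `beta` gives a smoother but laggier trend.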
Usage
The SVI training loop is used when:
- Standard variational inference: Running the core SVI optimization loop over a fixed or adaptive number of iterations to fit a variational posterior.
- Convergence monitoring: Tracking ELBO values over training to determine when the variational approximation has converged, using `evaluate_loss()` for clean validation metrics.
- Hyperparameter tuning: Comparing different optimizers, learning rates, or ELBO estimators by running multiple training loops and comparing convergence behavior.
- Production training pipelines: Integrating SVI into larger training workflows with learning rate scheduling, checkpointing, and early stopping.
Theoretical Basis
Stochastic Optimization for Variational Inference
The SVI training loop implements stochastic gradient ascent on the ELBO. At each iteration, a single sample (or small number of samples) from the guide provides a noisy but unbiased estimate of the ELBO gradient. The Robbins-Monro conditions for stochastic approximation guarantee convergence to a local optimum when the learning rate schedule satisfies:
- Sum of learning rates diverges (ensures exploration).
- Sum of squared learning rates converges (ensures convergence).
In practice, adaptive optimizers like Adam satisfy these conditions approximately and provide robust convergence for most models.
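The two conditions can be checked numerically for the classic schedule eta_t = 1/t: partial sums of eta_t grow without bound (harmonic series), while partial sums of eta_t^2 stay bounded (by pi^2/6):

```python
import math

def partial_sums(n):
    """Partial sums of 1/t and 1/t^2 for the schedule eta_t = 1/t."""
    s, s2 = 0.0, 0.0
    for t in range(1, n + 1):
        s += 1.0 / t          # sum of learning rates: diverges like log(n)
        s2 += 1.0 / t ** 2    # sum of squares: converges to pi^2/6
    return s, s2

s_small, s2_small = partial_sums(1_000)
s_large, s2_large = partial_sums(1_000_000)
```

Going from a thousand to a million terms adds roughly ln(1000) ≈ 6.9 to the first sum but almost nothing to the second, which is exactly the divergent/convergent split the Robbins-Monro conditions require.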
Loss-Only Evaluation
The evaluate_loss() method computes:
L = -ELBO = -E_{q(z)}[log p(x, z) - log q(z)]
under torch.no_grad(), which disables gradient tracking. This is mathematically identical to the loss computed during training but without the side effects of gradient accumulation and parameter updates. It provides an unbiased estimate of the current ELBO suitable for monitoring.
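The expectation in this formula can be checked on a toy conjugate model where the ELBO has a closed form. This stdlib-only sketch (the model p(z) = N(0, 1), likelihood p(x|z) = N(z, 1) with x = 0, and guide parameters are all illustrative choices) estimates E_q[log p(x, z) - log q(z)] by Monte Carlo:

```python
import math
import random

random.seed(0)

mu, sigma, x = 0.3, 0.8, 0.0  # guide q(z) = N(mu, sigma); observed datum x

def log_normal(v, loc, scale):
    """Log density of N(loc, scale) at v."""
    return -0.5 * math.log(2 * math.pi * scale ** 2) \
           - (v - loc) ** 2 / (2 * scale ** 2)

# Monte Carlo: ELBO ~= (1/N) * sum_i [log p(x, z_i) - log q(z_i)], z_i ~ q
n = 100_000
total = 0.0
for _ in range(n):
    z = random.gauss(mu, sigma)
    log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
    total += log_joint - log_normal(z, mu, sigma)
mc_elbo = total / n

# Closed form for this Gaussian toy model:
# E_q[log p(z)] + E_q[log p(x|z)] + H(q), each term analytic.
exact_elbo = (-math.log(2 * math.pi) - (mu ** 2 + sigma ** 2)
              + 0.5 * math.log(2 * math.pi * math.e * sigma ** 2))
```

With a large sample count the Monte Carlo estimate lands close to the analytic value, illustrating that the quantity `evaluate_loss()` reports is an unbiased (if noisy) estimate of the same ELBO optimized during training.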