

Principle:Pyro ppl Pyro SVI Training Loop

From Leeroopedia


Metadata

Page Type: Principle
Knowledge Sources: Repo (Pyro)
Domains: Variational_Inference, Optimization
Last Updated: 2026-02-09 12:00 GMT

Overview

The iterative training loop pattern for Stochastic Variational Inference, where each iteration computes the ELBO, backpropagates gradients, and updates variational parameters through the step() method, with optional loss-only evaluation via evaluate_loss() for convergence monitoring.

Description

The SVI training loop is the standard pattern for performing variational inference in Pyro. It centers on two complementary operations provided by the SVI class: step() for training and evaluate_loss() for monitoring.

The step() Method: Training

The step() method is the workhorse of the SVI training loop. Each call performs one complete optimization iteration: it estimates the ELBO loss and applies the resulting parameter updates. Internally, it executes the following sequence:

  1. Trace parameter sites: Wraps the loss computation in poutine.trace(param_only=True), so that every pyro.param site accessed while the model and guide run is recorded.
  2. Compute loss and gradients: Inside that trace, calls self.loss_and_grads(self.model, self.guide, *args, **kwargs), which estimates the ELBO loss and backpropagates gradients to all parameters involved, in a single pass.
  3. Collect unconstrained parameters: Gathers the unconstrained tensors behind the parameter sites recorded in the trace. Optimization happens in this unconstrained space; constrained values (e.g., positive scales) are obtained by transforming them.
  4. Apply optimizer: Calls self.optim(params), which applies the configured PyTorch optimizer (e.g., Adam) to update parameter values based on accumulated gradients.
  5. Zero gradients: Calls zero_grads(params) to reset all gradients to zero, preparing for the next iteration.

The method returns the loss estimate (the negative of the ELBO estimate) as a Python float, enabling the caller to track training progress.

The evaluate_loss() Method: Monitoring

The evaluate_loss() method computes the ELBO loss without computing or applying gradients. It wraps the loss computation in a torch.no_grad() context, which:

  • Avoids the computational cost of building the autograd graph.
  • Prevents any parameter updates from occurring.
  • Provides a clean estimate of the current ELBO for monitoring purposes.

This makes the method well suited to convergence monitoring and validation: it measures the loss on held-out data without affecting the training state.

Training Loop Patterns

The SVI training loop typically follows one of these patterns:

Basic Loop

The simplest pattern iterates for a fixed number of steps, calling step() each iteration:

  • Initialize the parameter store with pyro.clear_param_store().
  • Construct the SVI object with model, guide, optimizer, and loss.
  • Loop for N iterations, calling svi.step(data) and optionally logging the returned loss.

Loop with Validation

A more robust pattern monitors convergence using evaluate_loss():

  • Periodically compute svi.evaluate_loss(val_data) on validation data.
  • Track whether the validation loss is improving.
  • Apply early stopping when validation loss stops decreasing.

Key Considerations

  • Convergence monitoring: The value returned by step() is a noisy Monte Carlo estimate of the loss (the negative ELBO), computed before that step's parameter update. Smoothing (e.g., an exponential moving average) is recommended for assessing convergence trends.
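  • Learning rate scheduling: Pyro wraps PyTorch learning rate schedulers as PyroLRScheduler objects (e.g., pyro.optim.ExponentialLR). The scheduler is passed to SVI in place of a plain optimizer, and its step() method is called alongside the SVI loop (typically once per iteration or epoch) to reduce the learning rate over time for finer convergence.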
  • Early stopping: Since the ELBO is stochastic, early stopping criteria should be based on smoothed loss values with a patience window to avoid premature termination due to noise.
  • Parameter initialization: Calling pyro.clear_param_store() before constructing SVI ensures a clean slate. Failing to do so can carry over stale parameters from previous runs.
  • Mini-batch training: When pyro.plate is given a subsample_size, each step() call draws a different random mini-batch, and the log-probability terms inside the plate are scaled by size / subsample_size so that the ELBO estimate remains unbiased for the full dataset.

Usage

The SVI training loop is used when:

  • Standard variational inference: Running the core SVI optimization loop over a fixed or adaptive number of iterations to fit a variational posterior.
  • Convergence monitoring: Tracking ELBO values over training to determine when the variational approximation has converged, using evaluate_loss() for clean validation metrics.
  • Hyperparameter tuning: Comparing different optimizers, learning rates, or ELBO estimators by running multiple training loops and comparing convergence behavior.
  • Production training pipelines: Integrating SVI into larger training workflows with learning rate scheduling, checkpointing, and early stopping.

Theoretical Basis

Stochastic Optimization for Variational Inference

The SVI training loop implements stochastic gradient ascent on the ELBO (equivalently, stochastic gradient descent on the loss -ELBO). At each iteration, a single sample (or a small number of samples) from the guide provides a noisy but unbiased estimate of the ELBO gradient. Under standard regularity conditions, the Robbins-Monro conditions for stochastic approximation guarantee convergence to a local optimum (more precisely, a stationary point of the ELBO) when the learning rate schedule satisfies:

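Σ_t ρ_t = ∞, where ρ_t is the learning rate at iteration t: the steps are large enough in total to reach an optimum from any starting point.
Σ_t ρ_t² < ∞: the steps shrink fast enough for the gradient noise to average out, so the iterates settle instead of fluctuating indefinitely.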
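As a worked check of the two conditions, writing ρ_t for the learning rate at step t and ρ_0 > 0 for its initial value: the classic schedule ρ_t = ρ_0 / (t + 1) satisfies both, as does the more general polynomial decay with exponent κ in (1/2, 1], while a constant learning rate fails the second condition.

```latex
\rho_t = \frac{\rho_0}{t+1}:\quad
\sum_{t=0}^{\infty} \rho_t = \rho_0 \sum_{n=1}^{\infty} \frac{1}{n} = \infty,
\qquad
\sum_{t=0}^{\infty} \rho_t^2 = \rho_0^2 \sum_{n=1}^{\infty} \frac{1}{n^2} = \frac{\pi^2 \rho_0^2}{6} < \infty

\rho_t = \rho_0\,(t+\tau)^{-\kappa},\ \tau > 0,\ \kappa \in \left(\tfrac{1}{2}, 1\right]:\quad
\sum_{t} \rho_t = \infty \ \text{since}\ \kappa \le 1,
\qquad
\sum_{t} \rho_t^2 < \infty \ \text{since}\ 2\kappa > 1

\rho_t = \rho_0 \ \text{(constant)}:\quad \sum_{t} \rho_t^2 = \sum_{t} \rho_0^2 = \infty
```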

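In practice, Adam with a constant learning rate does not strictly satisfy these conditions, but it provides robust convergence for most models; pairing it with a decaying learning rate schedule brings training closer to the Robbins-Monro regime.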

Loss-Only Evaluation

The evaluate_loss() method computes:

L = -ELBO = -E_q(z)[log p(x, z) - log q(z)]

under torch.no_grad(), which disables gradient tracking. This is the same loss estimator used during training, but without the side effects of gradient accumulation and parameter updates. It provides an unbiased Monte Carlo estimate of the current loss (the negative ELBO) that is suitable for monitoring.
