Implementation: Pyro SVI Engine (pyro-ppl)
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | Repo (Pyro) |
| Domains | Variational_Inference, Bayesian_Inference |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
Concrete tool for performing stochastic variational inference provided by the Pyro probabilistic programming library.
Description
SVI (Stochastic Variational Inference) is the main class that orchestrates ELBO-based variational inference in Pyro. It extends TracePosterior and provides a unified interface for combining a probabilistic model, a variational guide, an ELBO loss, and a stochastic optimizer into an iterative training loop.
The class manages the full optimization cycle:
- Parameter discovery: Uses `poutine.trace(param_only=True)` to identify all `pyro.param` sites registered during guide and model execution.
- Loss and gradient computation: Delegates to the configured loss function (e.g., `Trace_ELBO`) to compute the ELBO estimate and backpropagate gradients through all parameters.
- Parameter collection: Gathers all unconstrained parameter values from the `ParamStoreDict`.
- Optimizer step: Calls the `PyroOptim` wrapper, which applies the underlying PyTorch optimizer to update parameter values.
- Gradient zeroing: Zeros the gradients on all parameters to prepare for the next iteration.
The class also provides `evaluate_loss()`, which computes the ELBO without computing or applying gradients; this is useful for monitoring convergence on held-out validation data.
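The optimization cycle above can be sketched in plain Python. Everything here is an illustrative stand-in, not Pyro code: a dict plays the role of the `ParamStoreDict`, and a quadratic loss with a hand-written gradient plays the role of the ELBO estimator.

```python
# Illustrative stand-ins only: a dict as the "parameter store" and a
# toy quadratic loss with an analytic gradient as the "ELBO" machinery.
params = {"loc": 5.0}   # unconstrained parameter values
grads = {"loc": 0.0}    # gradient buffer
lr = 0.1

def loss_and_grads(p):
    # Phases 1-2: compute the loss and "backpropagate" into the buffers.
    loss = (p["loc"] - 2.0) ** 2
    grads["loc"] = 2.0 * (p["loc"] - 2.0)
    return loss

def step():
    loss = loss_and_grads(params)
    for name in params:                    # phase 3: collect parameters
        params[name] -= lr * grads[name]   # phase 4: optimizer update (plain SGD)
        grads[name] = 0.0                  # phase 5: zero gradients
    return loss

def evaluate_loss(p):
    # Analogue of evaluate_loss(): report the loss without
    # writing to the gradient buffers.
    return (p["loc"] - 2.0) ** 2

for _ in range(100):
    step()
print(round(params["loc"], 2))  # -> 2.0, the minimum of the toy loss
```

The real class delegates each phase to Pyro machinery (`poutine.trace`, the ELBO object, the `ParamStoreDict`, and the `PyroOptim` wrapper), but the control flow of `step()` follows this shape.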
Code Reference
Source Location
- Repository: `pyro-ppl/pyro`
- File: `pyro/infer/svi.py`
- Lines: L16–L162
Signature
```python
class SVI(TracePosterior):
    def __init__(self, model, guide, optim, loss, loss_and_grads=None,
                 num_samples=0, num_steps=0, **kwargs):
```
Import
```python
from pyro.infer import SVI
```
I/O Contract
Constructor Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `model` | callable | Yes | A Pyro model function defining the joint distribution over observed and latent variables via `pyro.sample` and `pyro.param` statements. |
| `guide` | callable | Yes | A Pyro guide function defining the variational approximation. Must contain `pyro.sample` sites matching every unobserved sample site in the model. |
| `optim` | `PyroOptim` | Yes | A Pyro optimizer wrapper (e.g., `pyro.optim.Adam`) that manages per-parameter optimizer state. |
| `loss` | `ELBO` instance | Yes | An ELBO loss object (e.g., `Trace_ELBO()`, `TraceGraph_ELBO()`) that defines how to estimate the evidence lower bound. |
| `loss_and_grads` | callable / `None` | No | Optional custom function for computing the loss and gradients. If `None`, defaults to `loss.loss_and_grads`. |
| `num_samples` | int | No | Number of posterior samples for the `TracePosterior` interface. Defaults to 0. |
| `num_steps` | int | No | Number of optimization steps for the `TracePosterior` interface. Defaults to 0. |
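The `loss_and_grads` fallback described in the table can be pictured as a simple default-resolution step. The `TinyELBO` class and `resolve_loss_and_grads` helper below are hypothetical illustrations, not Pyro code:

```python
class TinyELBO:
    # Hypothetical stand-in for an ELBO object exposing loss_and_grads.
    def loss_and_grads(self, *args, **kwargs):
        return 42.0

def resolve_loss_and_grads(loss, loss_and_grads=None):
    # Mirrors the documented default: fall back to loss.loss_and_grads
    # when no custom callable is supplied.
    return loss_and_grads if loss_and_grads is not None else loss.loss_and_grads

default_fn = resolve_loss_and_grads(TinyELBO())
print(default_fn())  # -> 42.0 (the ELBO object's own method is used)

custom_fn = resolve_loss_and_grads(TinyELBO(), lambda *a: 0.0)
print(custom_fn())   # -> 0.0 (the user-supplied callable wins)
```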
Key Methods
| Method | Signature | Returns | Description |
|---|---|---|---|
| `step` | `step(*args, **kwargs)` | float | Takes a single gradient step: computes the loss, backpropagates, updates parameters, and zeros gradients. Returns the loss estimate. |
| `evaluate_loss` | `evaluate_loss(*args, **kwargs)` | float | Evaluates the ELBO loss without gradient computation (runs under `torch.no_grad()`). Returns the loss estimate. |
Outputs
| Output | Type | Description |
|---|---|---|
| `SVI` instance | `SVI` | An object with `step()` and `evaluate_loss()` methods for running variational inference. |
Usage Examples
Basic SVI Training Loop
```python
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Clear the parameter store before training
pyro.clear_param_store()

# Initialize SVI with model, guide, optimizer, and loss
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())

# Run training loop
for step in range(1000):
    loss = svi.step(data)
    if step % 100 == 0:
        print(f"Step {step} : loss = {loss:.4f}")
```
Training with Validation Monitoring
```python
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam

pyro.clear_param_store()
svi = SVI(model, guide, ClippedAdam({"lr": 0.005}), Trace_ELBO())

for step in range(2000):
    train_loss = svi.step(train_data)
    if step % 100 == 0:
        # Evaluate loss without computing gradients
        val_loss = svi.evaluate_loss(val_data)
        print(f"Step {step} : train loss = {train_loss:.4f}, "
              f"val loss = {val_loss:.4f}")
```