Implementation:Pyro ppl Pyro SVI Engine

Metadata

Page Type: Implementation (API Doc)
Knowledge Sources: Repo (Pyro)
Domains: Variational_Inference, Bayesian_Inference
Last Updated: 2026-02-09 12:00 GMT

Overview

SVI is the class in the Pyro probabilistic programming library that runs stochastic variational inference.

Description

SVI (Stochastic Variational Inference) is the main class that orchestrates ELBO-based variational inference in Pyro. It extends TracePosterior and provides a unified interface for combining a probabilistic model, a variational guide, an ELBO loss, and a stochastic optimizer into an iterative training loop.

Each call to step() runs one full optimization cycle:

  1. Parameter discovery: Wraps the loss computation in poutine.trace(param_only=True), which records every pyro.param site reached while the model and guide execute.
  2. Loss and gradient computation: Inside that trace, delegates to the configured loss (e.g., Trace_ELBO.loss_and_grads), which estimates the loss (the negative ELBO) and backpropagates gradients to all parameters.
  3. Parameter collection: Gathers the unconstrained tensors behind the recorded param sites; these are the tensors held in the global ParamStoreDict.
  4. Optimizer step: Calls the PyroOptim wrapper on those tensors. The wrapper lazily creates one underlying PyTorch optimizer per parameter and uses it to update that parameter in place.
  5. Gradient zeroing: Zeros the gradients of those parameters to prepare for the next iteration.

The class also provides evaluate_loss(), which estimates the loss (the negative ELBO for ELBO-based losses) inside torch.no_grad(). No gradients are computed and no parameters are updated, which makes it useful for monitoring convergence on validation data.

Code Reference

Source Location

Repository: pyro-ppl/pyro
File: pyro/infer/svi.py
Lines: 16-162

Signature

class SVI(TracePosterior):
    def __init__(self, model, guide, optim, loss, loss_and_grads=None,
                 num_samples=0, num_steps=0, **kwargs):

Import

from pyro.infer import SVI

I/O Contract

Constructor Inputs

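model (callable, required): A Pyro model function defining the joint distribution over observed and latent variables via pyro.sample statements (and, optionally, pyro.param statements).
guide (callable, required): A Pyro guide function defining the variational approximation. It is called with the same arguments as the model and must contain a pyro.sample site, with the same name, for every unobserved sample site in the model.
optim (PyroOptim, required): A Pyro optimizer wrapper (e.g., pyro.optim.Adam) that manages per-parameter optimizer state.
loss (ELBO instance or callable, required): Usually an ELBO loss object (e.g., Trace_ELBO(), TraceGraph_ELBO(), TraceEnum_ELBO()) that defines how the evidence lower bound is estimated. A plain callable with signature loss(model, guide, *args, **kwargs) that returns the loss is also accepted.
loss_and_grads (callable or None, optional): Custom function for computing the loss and its gradients. It is only used when loss is not an ELBO instance, because an ELBO always supplies its own loss_and_grads method. If it is None in that case, SVI uses a default that calls loss(...) and backpropagates the result.
num_samples (int, optional): Number of posterior samples for the TracePosterior interface. Defaults to 0. Deprecated in recent Pyro releases in favor of pyro.infer.Predictive.
num_steps (int, optional): Number of steps for the TracePosterior interface. Defaults to 0. Deprecated in recent Pyro releases; call step() in a loop instead.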

Key Methods

step(*args, **kwargs) -> float: Takes a single gradient step: computes the loss, backpropagates, updates the parameters, and zeros their gradients. The arguments are forwarded to both the model and the guide. Returns the loss estimate.
evaluate_loss(*args, **kwargs) -> float: Evaluates the loss inside a torch.no_grad() context, so no gradients are computed and no parameters are updated. The arguments are forwarded to both the model and the guide. Returns the loss estimate.

Outputs

SVI instance (SVI): An object with step() and evaluate_loss() methods for running variational inference.

Usage Examples

Basic SVI Training Loop

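import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Illustrative data, model, and guide; replace with your own
data = torch.randn(100) + 3.0

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

def guide(data):
    q_loc = pyro.param("q_loc", torch.tensor(0.0))
    q_scale = pyro.param("q_scale", torch.tensor(1.0), constraint=constraints.positive)
    pyro.sample("loc", dist.Normal(q_loc, q_scale))

# Clear the parameter store before training
pyro.clear_param_store()

# Initialize SVI with model, guide, optimizer, and loss
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())

# Run training loop
for step in range(1000):
    loss = svi.step(data)
    if step % 100 == 0:
        print(f"Step {step} : loss = {loss:.4f}")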

Training with Validation Monitoring

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam

# Assumes `model`, `guide`, `train_data`, and `val_data` are defined elsewhere
pyro.clear_param_store()

svi = SVI(model, guide, ClippedAdam({"lr": 0.005}), Trace_ELBO())

for step in range(2000):
    train_loss = svi.step(train_data)

    if step % 100 == 0:
        # Evaluate loss without computing gradients
        val_loss = svi.evaluate_loss(val_data)
        print(f"Step {step} : train loss = {train_loss:.4f}, val loss = {val_loss:.4f}")
