Implementation:Pyro ppl Pyro SVI Step

From Leeroopedia


Metadata

Page Type: Implementation (API Doc)
Knowledge Sources: Repo (Pyro)
Domains: Variational_Inference, Optimization
Last Updated: 2026-02-09 12:00 GMT

Overview

Concrete tools for executing a single SVI training step and evaluating the ELBO loss without gradients, provided as methods on the Pyro SVI class.

Description

The step() and evaluate_loss() methods on the SVI class implement the two core operations of the SVI training loop: gradient-based parameter updates and gradient-free loss evaluation, respectively.

step() Method

The step() method performs one complete optimization iteration. Its internal execution proceeds as follows:

  1. Trace parameter sites: Wraps the loss computation in poutine.trace(param_only=True) to capture all pyro.param sites that are accessed during model and guide execution. This is necessary because Pyro parameters may be created dynamically.
  2. Compute loss and backpropagate: Calls self.loss_and_grads(self.model, self.guide, *args, **kwargs), which runs the model and guide, computes the negative ELBO, and calls .backward() to populate .grad attributes on all parameter tensors.
  3. Collect unconstrained parameters: Iterates over all traced parameter sites and retrieves their unconstrained values from the Pyro parameter store. The unconstrained values are the raw tensors on which the optimizer operates (before any constraint transforms are applied).
  4. Apply optimizer: Calls self.optim(params), where params is a dictionary mapping parameter names to unconstrained tensors. The PyroOptim wrapper creates or retrieves per-parameter optimizer state and calls the underlying PyTorch optimizer's step.
  5. Zero gradients: Calls zero_grads(params) to set all .grad attributes to zero, preventing gradient accumulation across iterations.

The method returns the loss as a Python float.
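The five-step control flow above can be sketched without Pyro, using a toy parameter store and an analytic gradient in place of .backward(). All names below (param_store, loss_and_grads, apply_optimizer, zero_grads) and the quadratic loss are illustrative assumptions, not Pyro's actual internals:

```python
# Toy parameter store: name -> {"value": float, "grad": float}.
param_store = {"loc": {"value": 5.0, "grad": 0.0}}

def loss_and_grads(params):
    # Steps 1-2: run the "model/guide", compute the loss, populate grads.
    # A quadratic loss 0.5 * (loc - 2)^2 with its analytic gradient
    # stands in for the negative ELBO and .backward().
    loc = params["loc"]["value"]
    params["loc"]["grad"] = loc - 2.0
    return 0.5 * (loc - 2.0) ** 2

def apply_optimizer(params, lr=0.1):
    # Steps 3-4: update the raw (unconstrained) values, here via plain SGD.
    for site in params.values():
        site["value"] -= lr * site["grad"]

def zero_grads(params):
    # Step 5: reset gradients so they do not accumulate across iterations.
    for site in params.values():
        site["grad"] = 0.0

def step(params):
    loss = loss_and_grads(params)
    apply_optimizer(params)
    zero_grads(params)
    return float(loss)  # loss returned as a Python float

for _ in range(100):
    loss = step(param_store)
# param_store["loc"]["value"] converges toward the optimum at 2.0
```

Omitting the zero_grads call would make each iteration add its gradient on top of the previous one, which is exactly the accumulation that step 5 prevents.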

evaluate_loss() Method

The evaluate_loss() method computes the ELBO without building the computation graph or updating any parameters. It wraps the loss computation in a torch.no_grad() context manager, which:

  • Disables gradient computation, reducing memory usage and computation time.
  • Prevents any .grad attributes from being populated.
  • Leaves the parameter store and optimizer state completely unchanged.

This method accepts the same *args and **kwargs as step(), which are forwarded to the model and guide callables. It returns the loss as a Python float.
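The contrast with step() can be shown in a minimal, Pyro-free sketch: an evaluate_loss() analogue reads parameter values but writes nothing back, mirroring the guarantee that torch.no_grad() provides around the ELBO computation. The names here are illustrative, not Pyro internals:

```python
import copy

# Toy parameter store: name -> {"value": float, "grad": float}.
param_store = {"loc": {"value": 5.0, "grad": 0.0}}

def evaluate_loss(params):
    # Read-only loss computation: no grads written, no values updated,
    # standing in for the torch.no_grad() context in SVI.evaluate_loss().
    loc = params["loc"]["value"]
    return float(0.5 * (loc - 2.0) ** 2)

snapshot = copy.deepcopy(param_store)
loss = evaluate_loss(param_store)
assert param_store == snapshot  # store and gradients are untouched
```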

Code Reference

Source Location

Repository: pyro-ppl/pyro
File: pyro/infer/svi.py
Lines (step): L134–162
Lines (evaluate_loss): L119–132

Signatures

# step method
def step(self, *args, **kwargs):
    """
    Take a gradient step on the loss function (and target any learned
    parameters in both the model and guide).

    :param args: arguments to the model and guide.
    :param kwargs: keyword arguments to the model and guide.
    :returns: estimate of the loss
    :rtype: float
    """
# evaluate_loss method
def evaluate_loss(self, *args, **kwargs):
    """
    Evaluate the loss without taking a gradient step.

    :param args: arguments to the model and guide.
    :param kwargs: keyword arguments to the model and guide.
    :returns: estimate of the loss
    :rtype: float
    """

Import

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Both methods are accessed via an SVI instance
# (model, guide, and data are assumed to be defined elsewhere).
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
loss = svi.step(data)               # training step: updates parameters
val_loss = svi.evaluate_loss(data)  # evaluation only: no updates

I/O Contract

step() Inputs

*args (any, optional): Positional arguments forwarded to both the model and guide callables (e.g., observed data tensors).
**kwargs (any, optional): Keyword arguments forwarded to both the model and guide callables.

step() Outputs

loss (float): A scalar estimate of the ELBO loss (the negated ELBO). Lower values indicate a better fit.

evaluate_loss() Inputs

*args (any, optional): Positional arguments forwarded to both the model and guide callables (e.g., validation data tensors).
**kwargs (any, optional): Keyword arguments forwarded to both the model and guide callables.

evaluate_loss() Outputs

loss (float): A scalar estimate of the ELBO loss computed without gradient tracking. Useful for validation and convergence monitoring.

Usage Examples

Basic Training Loop with step()

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# model, guide, and data are assumed to be defined elsewhere
# (any Pyro model/guide pair with matching sample sites).
pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())

for step in range(1000):
    loss = svi.step(data)
    if step % 100 == 0:
        print(f"Step {step} : loss = {loss:.4f}")

Convergence Monitoring with evaluate_loss()

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# model, guide, train_data, and val_data are assumed to be defined elsewhere.
pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())

best_val_loss = float("inf")
patience_counter = 0

for step in range(5000):
    train_loss = svi.step(train_data)

    if step % 50 == 0:
        val_loss = svi.evaluate_loss(val_data)
        print(f"Step {step} : train = {train_loss:.4f}, val = {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= 10:
                print("Early stopping.")
                break
