Implementation: Pyro (pyro-ppl) SVI Step
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | Repo (Pyro) |
| Domains | Variational_Inference, Optimization |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
Concrete tools for executing a single SVI training step and evaluating the ELBO loss without gradients, provided as methods on the Pyro SVI class.
Description
The step() and evaluate_loss() methods on the SVI class implement the two core operations of the SVI training loop: gradient-based parameter updates and gradient-free loss evaluation, respectively.
step() Method
The step() method performs one complete optimization iteration. Its internal execution proceeds as follows:
- Trace parameter sites: Wraps the loss computation in `poutine.trace(param_only=True)` to capture all `pyro.param` sites accessed during model and guide execution. This is necessary because Pyro parameters may be created dynamically.
- Compute loss and backpropagate: Calls `self.loss_and_grads(self.model, self.guide, *args, **kwargs)`, which runs the model and guide, computes the negative ELBO, and calls `.backward()` to populate the `.grad` attributes of all parameter tensors.
- Collect unconstrained parameters: Iterates over the traced parameter sites and retrieves their unconstrained values from the Pyro parameter store. The unconstrained values are the raw tensors on which the optimizer operates (before any constraint transforms are applied).
- Apply optimizer: Calls `self.optim(params)`, where `params` is the collection of unconstrained parameter tensors. The `PyroOptim` wrapper creates or retrieves per-parameter optimizer state and calls the underlying PyTorch optimizer's step.
- Zero gradients: Calls `zero_grads(params)` to reset all `.grad` attributes, preventing gradient accumulation across iterations.
The method returns the loss as a Python float.
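The sequence above can be illustrated with a small, self-contained analogue in plain Python (no Pyro or PyTorch dependency). Here a single scalar parameter and a quadratic loss stand in for the negative ELBO; `ToySVI`, `param_store`, and `grad_store` are illustrative names invented for this sketch, not Pyro internals.

```python
# Toy analogue of SVI.step()'s internal sequence, assuming a single
# scalar parameter and a quadratic stand-in for the negative ELBO.
# ToySVI, param_store, and grad_store are illustrative names only.

class ToySVI:
    def __init__(self, lr=0.1):
        self.lr = lr
        self.param_store = {"loc": 5.0}   # unconstrained parameter values
        self.grad_store = {"loc": 0.0}    # stand-in for .grad attributes

    def _loss_and_grads(self, target):
        # "Run the model and guide": loss = (loc - target)^2,
        # with gradient 2 * (loc - target) written as a .backward() analogue.
        loc = self.param_store["loc"]
        loss = (loc - target) ** 2
        self.grad_store["loc"] = 2.0 * (loc - target)
        return loss

    def step(self, target):
        loss = self._loss_and_grads(target)        # trace + backprop
        for name in self.param_store:              # collect unconstrained params
            # apply the optimizer (plain SGD here)
            self.param_store[name] -= self.lr * self.grad_store[name]
            # zero gradients to avoid accumulation across iterations
            self.grad_store[name] = 0.0
        return float(loss)                         # loss returned as a Python float

svi = ToySVI()
losses = [svi.step(target=1.0) for _ in range(50)]
```

Repeated calls drive the parameter toward the target and the returned losses toward zero, mirroring how repeated `svi.step(...)` calls minimize the negative ELBO.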
evaluate_loss() Method
The evaluate_loss() method computes the ELBO without building the computation graph or updating any parameters. It wraps the loss computation in a torch.no_grad() context manager, which:
- Disables gradient computation, reducing memory usage and computation time.
- Prevents any `.grad` attributes from being populated.
- Leaves the parameter store and optimizer state completely unchanged.

This method accepts the same `*args` and `**kwargs` as step(), which are forwarded to the model and guide callables. It returns the loss as a Python float.
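The evaluate_loss() contract can be demonstrated with a minimal stdlib-only sketch: the loss is computed and returned as a float, while parameters and gradients are deliberately left untouched. The names here (`params`, `grads`, `evaluate_loss`) are illustrative, not Pyro's.

```python
# Toy illustration of the evaluate_loss() contract: compute and return the
# loss as a float without writing gradients or mutating any state.
# (Analogue of wrapping the computation in torch.no_grad().)

params = {"loc": 5.0}
grads = {"loc": 0.0}

def evaluate_loss(target):
    # Read-only: the parameters are used, but grads and params are never touched.
    return float((params["loc"] - target) ** 2)

before = (dict(params), dict(grads))
loss = evaluate_loss(target=1.0)
after = (dict(params), dict(grads))
# before == after: parameter store and gradient state are unchanged
```

This is why evaluate_loss() is safe to call on validation data mid-training: it cannot perturb the optimization state.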
Code Reference
Source Location
- Repository: `pyro-ppl/pyro`
- File: `pyro/infer/svi.py`
- Lines (step): L134–L162
- Lines (evaluate_loss): L119–L132
Signatures
```python
# step method
def step(self, *args, **kwargs):
    """
    Take a gradient step on the loss function (and update any learned
    parameters in both the model and guide).

    :param args: arguments to the model and guide.
    :param kwargs: keyword arguments to the model and guide.
    :returns: estimate of the loss
    :rtype: float
    """

# evaluate_loss method
def evaluate_loss(self, *args, **kwargs):
    """
    Evaluate the loss without taking a gradient step.

    :param args: arguments to the model and guide.
    :param kwargs: keyword arguments to the model and guide.
    :returns: estimate of the loss
    :rtype: float
    """
```
Import
```python
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Access via an SVI instance
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
loss = svi.step(data)               # training step
val_loss = svi.evaluate_loss(data)  # evaluation only
```
I/O Contract
step() Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `*args` | any | No | Positional arguments forwarded to both the model and guide callables (e.g., observed data tensors). |
| `**kwargs` | any | No | Keyword arguments forwarded to both the model and guide callables. |
step() Outputs
| Output | Type | Description |
|---|---|---|
| loss | float | A scalar estimate of the ELBO loss (negated ELBO). Lower values indicate better fit. |
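Because the returned loss is a stochastic Monte Carlo estimate, raw per-step values can be noisy; a running average is often smoother for monitoring. A minimal stdlib-only smoother (illustrative, not part of the Pyro API):

```python
from collections import deque

# Running-average smoother for noisy per-step loss estimates.
# LossSmoother is an illustrative helper, not a Pyro class.
class LossSmoother:
    def __init__(self, window=100):
        self.buf = deque(maxlen=window)  # keep the last `window` losses

    def update(self, loss):
        self.buf.append(loss)
        return sum(self.buf) / len(self.buf)  # mean over the window

sm = LossSmoother(window=3)
out = [sm.update(x) for x in [4.0, 2.0, 0.0, 0.0]]
# out == [4.0, 3.0, 2.0, 0.666...]: early values average over fewer points
```

In a training loop, one would feed each `svi.step(...)` return value through `update()` and log the smoothed value instead of the raw one.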
evaluate_loss() Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `*args` | any | No | Positional arguments forwarded to both the model and guide callables (e.g., validation data tensors). |
| `**kwargs` | any | No | Keyword arguments forwarded to both the model and guide callables. |
evaluate_loss() Outputs
| Output | Type | Description |
|---|---|---|
| loss | float | A scalar estimate of the ELBO loss computed without gradient tracking. Useful for validation and convergence monitoring. |
Usage Examples
Basic Training Loop with step()
```python
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())

for step in range(1000):
    loss = svi.step(data)
    if step % 100 == 0:
        print(f"Step {step}: loss = {loss:.4f}")
```
Convergence Monitoring with evaluate_loss()
```python
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())

best_val_loss = float("inf")
patience_counter = 0

for step in range(5000):
    train_loss = svi.step(train_data)
    if step % 50 == 0:
        val_loss = svi.evaluate_loss(val_data)
        print(f"Step {step}: train = {train_loss:.4f}, val = {val_loss:.4f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= 10:
                print("Early stopping.")
                break
```