Implementation:Pyro ppl Pyro Trace ELBO Loss
| Field | Value |
|---|---|
| Sources | Pyro |
| Domains | Variational_Inference, Optimization |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
Trace_ELBO is a concrete implementation of the Evidence Lower Bound (ELBO) estimator for stochastic variational inference in Pyro, using Monte Carlo sampling of execution traces to estimate the ELBO and its gradients.
Description
The Trace_ELBO class implements the standard Monte Carlo ELBO estimator. For each particle, it runs the guide (approximate posterior) to generate a trace of sampled latent variables, then replays the model against that trace to compute log probabilities. The per-particle ELBO estimate is the model's total log probability minus the guide's total log probability over the trace. These estimates are averaged over num_particles particles, and a single particle is used by default.
The estimator includes partial Rao-Blackwellization to reduce gradient variance when non-reparameterizable random variables (for example, discrete latents) are present. It is partial because it only exploits the conditional independence information declared by pyro.plate contexts. For finer-grained variance reduction, consider TraceGraph_ELBO.
The class supports four primary interfaces:
- loss(): Returns a scalar estimate of the negative ELBO, detached from the computation graph
- differentiable_loss(): Returns a differentiable surrogate loss for use with PyTorch autograd
- loss_and_grads(): Computes the loss and performs the backward pass on the surrogate loss in a single call
- __call__(model, guide): Returns an ELBOModule for PyTorch-native training with standard torch.optim optimizers
Usage
Import Trace_ELBO to use as the loss function for Pyro's SVI class, or call it directly with a model-guide pair to obtain a PyTorch Module.
Code Reference
Source Location
- Repository: pyro-ppl/pyro
- File: pyro/infer/trace_elbo.py
- Lines: L32-159
- Base class: ELBO in pyro/infer/elbo.py (L30-239)
Signature
```python
class Trace_ELBO(ELBO):
    def __init__(
        self,
        num_particles=1,
        max_plate_nesting=float('inf'),
        vectorize_particles=False,
        strict_enumeration_warning=True,
        ignore_jit_warnings=False,
        jit_options=None,
        retain_graph=None,
        tail_adaptive_beta=-1.0,
    ):
        ...
```
Import
```python
from pyro.infer import Trace_ELBO
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| num_particles | int | No | Number of particles (samples) used to form the ELBO estimator (default: 1) |
| max_plate_nesting | int or float | No | Bound on the maximum number of nested pyro.plate contexts (default: inf; if left unset, Pyro may guess a value by running the model and guide once) |
| vectorize_particles | bool | No | Whether to vectorize the ELBO computation over particles using a plate dimension; requires static model and guide structure (default: False) |
| strict_enumeration_warning | bool | No | Whether to warn about possible misuse of enumeration (default: True) |
| ignore_jit_warnings | bool | No | Whether to ignore JIT tracer warnings (default: False) |
| jit_options | dict or None | No | Options passed to torch.jit.trace (default: None) |
| retain_graph | bool or None | No | Whether to retain the autograd graph during the backward pass (default: None) |
| tail_adaptive_beta | float | No | Exponent for the tail-adaptive ELBO variant, used by TraceTailAdaptive_ELBO (default: -1.0) |
Outputs
| Method | Return Type | Description |
|---|---|---|
| loss(model, guide, *args, **kwargs) | float | Scalar estimate of the negative ELBO (detached) |
| differentiable_loss(model, guide, *args, **kwargs) | torch.Tensor | Differentiable surrogate loss for autograd |
| loss_and_grads(model, guide, *args, **kwargs) | float | Scalar estimate of the negative ELBO; also runs the backward pass on the surrogate loss |
| __call__(model, guide) | ELBOModule | PyTorch Module wrapping the model, guide, and ELBO for native training |
Usage Examples
Standard SVI Usage
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Toy observations (illustrative only)
data = torch.randn(100) * 2.0 + 3.0

# Define model and guide
def model(data):
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0),
                       constraint=constraints.positive)
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, scale), obs=data)

def guide(data):
    pass  # No latent variables, so this reduces to maximum-likelihood fitting of loc and scale

# Create SVI with Trace_ELBO
pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO(num_particles=10))
for step in range(1000):
    loss = svi.step(data)  # estimate of the negative ELBO
```
PyTorch-Native Training
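```python
import torch
from pyro.infer import Trace_ELBO

# `model` and `guide` are expected to be torch.nn.Module instances (for example a
# pyro.nn.PyroModule and an autoguide): ELBOModule.parameters() only collects
# parameters registered on those modules, not plain pyro.param() values.
elbo = Trace_ELBO(num_particles=10)(model, guide)
elbo(data)  # one forward pass so lazily initialized guides create their parameters
optimizer = torch.optim.Adam(elbo.parameters(), lr=0.001)
for step in range(1000):
    optimizer.zero_grad()
    loss = elbo(data)  # differentiable negative ELBO (torch.Tensor)
    loss.backward()
    optimizer.step()
```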