
Implementation:Pyro ppl Pyro Trace ELBO Loss

From Leeroopedia


Field Value
Sources Pyro
Domains Variational_Inference, Optimization
Last Updated 2026-02-09 12:00 GMT

Overview

Trace_ELBO is a concrete implementation of the Evidence Lower Bound (ELBO) estimator for stochastic variational inference in Pyro, using Monte Carlo sampling of execution traces to estimate the ELBO and its gradients.

Description

The Trace_ELBO class implements the standard single-sample Monte Carlo ELBO estimator. It runs the guide (approximate posterior) to generate a trace of sampled latent variables, then replays the model against that trace to compute log probabilities. The ELBO is estimated as the difference between the model's total log probability and the guide's total log probability across the trace.
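The estimator described above can be sketched in plain Python, independent of Pyro, for a conjugate toy model (z ~ Normal(0, 1), x | z ~ Normal(z, 1)); the function names and the toy model are illustrative, not part of Pyro's API. The sketch also shows a useful sanity check: when the guide equals the exact posterior N(x/2, 1/2), every per-sample term log p(x, z) - log q(z) collapses to the exact log evidence log N(x; 0, 2), so the estimate has zero Monte Carlo variance.

```python
import math
import random

def log_normal(x, mu, var):
    """Log density of Normal(mu, var) evaluated at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

def elbo_estimate(x, q_mu, q_var, num_particles=100, seed=0):
    """Monte Carlo ELBO for: z ~ N(0, 1), x | z ~ N(z, 1), guide q(z) = N(q_mu, q_var)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_particles):
        z = rng.gauss(q_mu, math.sqrt(q_var))  # sample the latent from the guide
        log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)  # log p(x, z)
        total += log_joint - log_normal(z, q_mu, q_var)              # minus log q(z)
    return total / num_particles

x = 1.5
# With the exact posterior N(x/2, 1/2) as guide, the ELBO equals the
# log evidence log N(x; 0, 2) exactly, for every sampled z.
est = elbo_estimate(x, q_mu=x / 2, q_var=0.5)
exact = log_normal(x, 0.0, 2.0)
```

With an imperfect guide the estimate falls below the log evidence by (approximately) the KL divergence from guide to posterior, which is what optimizing the ELBO drives down.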

The estimator includes partial Rao-Blackwellization for reducing gradient variance when non-reparameterizable random variables are present. This partial Rao-Blackwellization uses conditional independence information marked by pyro.plate contexts. For more fine-grained variance reduction, users should consider TraceGraph_ELBO.
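The variance problem this machinery targets can be seen in a standalone sketch (plain Python, not Pyro code; all names are illustrative): for a reparameterizable Normal latent, the pathwise (reparameterized) gradient estimator of d/dmu E_q[z^2] has much lower variance than the score-function (REINFORCE) estimator, which is the kind of estimator Pyro must fall back on for non-reparameterizable sites and where Rao-Blackwellization pays off.

```python
import random
import statistics

def grad_samples(mu, n=20000, seed=0):
    """Per-sample gradient estimates of d/dmu E_{z ~ N(mu, 1)}[z^2] (true value: 2 * mu)."""
    rng = random.Random(seed)
    score, pathwise = [], []
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)
        z = mu + eps                    # reparameterized sample
        score.append(z * z * (z - mu))  # REINFORCE: f(z) * d log q(z) / d mu
        pathwise.append(2.0 * z)        # pathwise: d f(z) / d mu with z = mu + eps
    return score, pathwise

score, pathwise = grad_samples(mu=0.0)
var_score = statistics.variance(score)        # theoretical value: E[z^6] = 15
var_pathwise = statistics.variance(pathwise)  # theoretical value: Var(2z) = 4
```

Both estimators are unbiased, but the score-function estimator's variance here is nearly four times larger, which is why variance-reduction strategies matter for non-reparameterizable sites.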

The class supports four primary interfaces:

  • loss(): Returns a scalar estimate of the loss (the negative ELBO), detached from the computation graph
  • differentiable_loss(): Returns a differentiable surrogate loss for use with PyTorch autograd
  • loss_and_grads(): Computes the loss and performs the backward pass in a single call
  • __call__(model, guide): Returns an ELBOModule for PyTorch-native training with standard torch.optim optimizers

Usage

Import Trace_ELBO to use as the loss function for Pyro's SVI class, or call it directly with a model-guide pair to obtain a PyTorch Module.

Code Reference

Source Location

Repository
pyro-ppl/pyro
File
pyro/infer/trace_elbo.py
Lines
L32--159
Base class
pyro/infer/elbo.py L30--239

Signature

class Trace_ELBO(ELBO):
    def __init__(
        self,
        num_particles=1,
        max_plate_nesting=float('inf'),
        vectorize_particles=False,
        strict_enumeration_warning=True,
        ignore_jit_warnings=False,
        jit_options=None,
        retain_graph=None,
        tail_adaptive_beta=-1.0,
    ):

Import

from pyro.infer import Trace_ELBO

I/O Contract

Inputs

Name Type Required Description
num_particles int No Number of particles/samples used to form the ELBO estimator (default: 1)
max_plate_nesting int or float No Bound on max number of nested pyro.plate contexts (default: inf, auto-detected)
vectorize_particles bool No Whether to vectorize ELBO computation over particles using a plate dimension (default: False)
strict_enumeration_warning bool No Whether to warn about possible misuse of enumeration (default: True)
ignore_jit_warnings bool No Flag to ignore JIT tracer warnings (default: False)
jit_options dict or None No Options to pass to torch.jit.trace (default: None)
retain_graph bool or None No Whether to retain autograd graph during backward pass (default: None)
tail_adaptive_beta float No Exponent for tail-adaptive ELBO variant (default: -1.0)
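As a rough illustration of what num_particles buys (plain Python on a toy model — z ~ N(0, 1), x | z ~ N(z, 1), guide q = the prior — not Pyro's implementation): averaging K independent single-sample ELBO terms shrinks the estimator's variance by roughly a factor of K. With the guide equal to the prior, each per-sample term simplifies to log p(x | z).

```python
import math
import random
import statistics

def log_normal(x, mu, var):
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

def elbo_estimates(x, num_particles, reps, rng):
    """Repeated K-particle ELBO estimates for z ~ N(0, 1), x | z ~ N(z, 1), guide q = N(0, 1)."""
    out = []
    for _ in range(reps):
        total = 0.0
        for _ in range(num_particles):
            z = rng.gauss(0.0, 1.0)         # sample from the guide (the prior here)
            total += log_normal(x, z, 1.0)  # log p(x, z) - log q(z) reduces to log p(x | z)
        out.append(total / num_particles)
    return out

rng = random.Random(0)
x = 1.5
var_1 = statistics.variance(elbo_estimates(x, 1, 2000, rng))
var_10 = statistics.variance(elbo_estimates(x, 10, 2000, rng))
# var_10 comes out roughly one tenth of var_1
```

Note that each K-particle estimate costs K model/guide executions (or one vectorized execution when vectorize_particles=True), so num_particles trades compute for gradient-noise reduction.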

Outputs

Method Return Type Description
loss(model, guide, *args, **kwargs) float Scalar estimate of the negative ELBO (detached)
differentiable_loss(model, guide, *args, **kwargs) torch.Tensor Differentiable surrogate loss for autograd
loss_and_grads(model, guide, *args, **kwargs) float Scalar estimate of the negative ELBO; also performs backward pass on the surrogate loss
__call__(model, guide) ELBOModule PyTorch Module wrapping model, guide, and ELBO for native training

Usage Examples

Standard SVI Usage

import torch
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

data = torch.randn(100)  # observed data

# Define model and guide
def model(data):
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0),
                        constraint=pyro.distributions.constraints.positive)
    with pyro.plate("data", len(data)):
        pyro.sample("obs", pyro.distributions.Normal(loc, scale), obs=data)

def guide(data):
    pass  # No latent variables in this simple example

# Create SVI with Trace_ELBO
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO(num_particles=10))

for step in range(1000):
    loss = svi.step(data)

PyTorch-Native Training

import torch
import pyro
from pyro.infer import Trace_ELBO

# model, guide, and data as defined in the previous example
elbo = Trace_ELBO(num_particles=10)(model, guide)

optimizer = torch.optim.Adam(elbo.parameters(), lr=0.001)
for step in range(1000):
    optimizer.zero_grad()
    loss = elbo(data)
    loss.backward()
    optimizer.step()
