Implementation:Pyro ppl Pyro TraceGraph ELBO

From Leeroopedia
Revision as of 16:26, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Pyro_ppl_Pyro_TraceGraph_ELBO.md)

Overview

The tracegraph_elbo module (pyro/infer/tracegraph_elbo.py) implements TraceGraph_ELBO, an ELBO-based objective for stochastic variational inference. It uses the fine-grained conditional dependency information recorded in Pyro execution traces to reduce the variance of the gradient estimator. The construction follows the framework of "Gradient Estimation Using Stochastic Computation Graphs" (Schulman et al.), specialized to the ELBO.

The key difference from the basic Trace_ELBO is Rao-Blackwellization: for each non-reparameterizable sample site, only the downstream cost terms that actually depend on that site enter its score-function (REINFORCE) gradient term. Cost terms that do not depend on the site add noise but no signal, so dropping them lowers variance compared with using the full ELBO as the cost for every site.
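The effect of restricting each site to its own downstream costs can be seen in a toy setting outside Pyro. The sketch below is illustrative only: the two independent Bernoulli sites, the cost functions cost_1 and cost_2, and the helper names are assumptions made for this example and are not part of the module. It compares a score-function estimate of the gradient with respect to p1 using the full cost against one using only the cost that depends on z1.

```python
import random
from statistics import mean, variance


def score(z, p):
    """Derivative of log Bernoulli(z | p) with respect to p."""
    return z / p - (1 - z) / (1 - p)


def cost_1(z1):
    # Depends only on site z1 (hypothetical cost term).
    return 1.0 + 2.0 * z1


def cost_2(z2):
    # Depends only on site z2, so it carries no signal about p1.
    return 10.0 * z2


def estimate(num_samples=20000, p1=0.3, p2=0.5, seed=0):
    """Draw per-sample gradient estimates for d/dp1 E[cost_1 + cost_2]."""
    rng = random.Random(seed)
    full, rao_blackwell = [], []
    for _ in range(num_samples):
        z1 = 1 if rng.random() < p1 else 0
        z2 = 1 if rng.random() < p2 else 0
        s = score(z1, p1)
        full.append(s * (cost_1(z1) + cost_2(z2)))  # full cost for the site
        rao_blackwell.append(s * cost_1(z1))        # downstream cost only
    return full, rao_blackwell


full, rb = estimate()
# Both estimators target cost_1(1) - cost_1(0) = 2.0; the second is far less noisy.
print("full cost:      ", mean(full), variance(full))
print("downstream only:", mean(rb), variance(rb))
```

Both estimators are unbiased for the same gradient; the only difference is the extra variance contributed by the unrelated cost term.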

The module provides:

  • TraceGraph_ELBO -- The main ELBO implementation with graph-based variance reduction.
  • JitTraceGraph_ELBO -- A JIT-compiled variant using torch.jit.trace.
  • TrackNonReparam -- A Messenger that tracks non-reparameterizable sample sites using provenance tracking.

The module also contains internal functions for baseline construction (neural network baselines, decaying average baselines, and user-provided baseline values) and downstream cost computation.
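Why a baseline helps can be checked numerically without Pyro. The following is a minimal sketch under stated assumptions: a single Bernoulli site, a hypothetical cost function with a large constant offset, and a hand-picked baseline close to the expected cost. Subtracting a baseline that does not depend on the sampled value leaves the expected gradient unchanged but can shrink its variance considerably.

```python
import random
from statistics import mean, variance


def score(z, p):
    """Derivative of log Bernoulli(z | p) with respect to p."""
    return z / p - (1 - z) / (1 - p)


def cost(z):
    # Hypothetical cost with a large offset that inflates the variance.
    return 100.0 + z


def gradient_samples(baseline, num_samples=20000, p=0.3, seed=0):
    """Per-sample score-function estimates of d/dp E[cost(z)]."""
    rng = random.Random(seed)
    samples = []
    for _ in range(num_samples):
        z = 1 if rng.random() < p else 0
        samples.append(score(z, p) * (cost(z) - baseline))
    return samples


no_baseline = gradient_samples(baseline=0.0)
with_baseline = gradient_samples(baseline=100.3)  # roughly E[cost(z)] for p = 0.3

# The true gradient is cost(1) - cost(0) = 1.0 in both cases.
print("no baseline:  ", mean(no_baseline), variance(no_baseline))
print("with baseline:", mean(with_baseline), variance(with_baseline))
```

The three baseline kinds listed above differ only in how the subtracted value is obtained: predicted by a network, tracked as a running average, or supplied directly.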

Code Reference

File: pyro/infer/tracegraph_elbo.py

Key Classes

Class | Parent | Description
TraceGraph_ELBO | ELBO | ELBO implementation using dependency-graph-based variance reduction for non-reparameterizable sites.
JitTraceGraph_ELBO | TraceGraph_ELBO | JIT-compiled version. Requires static model structure, and all tensor inputs must be passed via *args.
TrackNonReparam | Messenger | Effect handler that annotates values from non-reparameterizable sample sites with provenance information.

Key Internal Functions

Function | Description
_get_baseline_options | Extracts baseline configuration from site["infer"]["baseline"].
_construct_baseline | Constructs baselines (neural network, decaying average, or user-provided) to reduce REINFORCE variance.
_compute_downstream_costs | Recursively computes downstream cost terms for all sample sites, implementing Rao-Blackwellization.
_compute_elbo | Computes both the ELBO and a surrogate ELBO used for gradient estimation. Handles reparameterized and non-reparameterized sites differently.
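The role of the surrogate ELBO can be written down in the notation of the stochastic computation graph framework. This is a schematic sketch, not a transcription of the module's code: the symbols below are introduced here for illustration, and the actual implementation may treat individual terms (for example entropy terms) differently. Let N be the set of non-reparameterizable guide sites, C_z the downstream cost of site z, b_z its baseline, and sg(.) a stop-gradient operator.

```latex
\widehat{\mathcal{L}}_{\mathrm{surr}}
  = \widehat{\mathcal{L}}
  + \sum_{z \in \mathcal{N}} \log q_\phi(z)\,
    \operatorname{sg}\!\left( C_z - b_z \right),
\qquad
\nabla_\phi \widehat{\mathcal{L}}_{\mathrm{surr}}
  = \nabla_\phi \widehat{\mathcal{L}}
  + \sum_{z \in \mathcal{N}} \nabla_\phi \log q_\phi(z)\,
    \left( C_z - b_z \right)
```

The first gradient term is the pathwise gradient through reparameterized samples; the second is the score-function term for the remaining sites. The surrogate has the same value as the ELBO estimate only in expectation of its gradient, which is why the reported loss and the differentiated quantity are computed separately.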

TraceGraph_ELBO Methods

Method | Description
loss(model, guide, *args, **kwargs) | Returns an estimate of the loss (the negative ELBO) as a float. No gradients are computed.
loss_and_grads(model, guide, *args, **kwargs) | Computes the ELBO and the surrogate loss, then calls backward on the surrogate. Returns the loss as a float.
_get_trace(model, guide, args, kwargs) | Returns paired model and guide traces with provenance tracking via TrackNonReparam.

I/O Contract

TraceGraph_ELBO Constructor

Inputs (inherited from ELBO):

  • num_particles -- Number of samples (particles) used for the gradient estimator (default 1).
  • max_plate_nesting -- Maximum depth of nested plates (default infinity, in which case it is auto-detected).
  • vectorize_particles -- Whether to vectorize the computation across particles (default False).
  • retain_graph -- Whether to retain the computation graph after backward.

loss

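Inputs:

  • model -- The model callable.
  • guide -- The guide callable.
  • *args, **kwargs -- Arguments passed through to both the model and the guide.

Output:

  • float -- Estimate of the negative ELBO. No gradients are computed.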

loss_and_grads

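Inputs:

  • model -- The model callable.
  • guide -- The guide callable.
  • *args, **kwargs -- Arguments passed through to both the model and the guide.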

Output:

  • float -- Negative ELBO estimate. Also populates gradients on model/guide parameters.

Baseline Options (per site)

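Baselines are configured via infer={"baseline": {...}} in guide sample sites:

  • nn_baseline -- A neural network module whose output is used as the baseline.
  • nn_baseline_input -- The input passed to nn_baseline.
  • use_decaying_avg_baseline -- Whether to use a decaying average of past downstream costs as the baseline.
  • baseline_beta -- Decay rate of the decaying average.
  • baseline_value -- A user-provided baseline value.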
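The decaying average baseline can be pictured as an exponentially weighted running mean of the costs seen so far, with baseline_beta controlling how quickly old values are forgotten. The class below is a standalone sketch of that idea, not Pyro's implementation: the class name, the zero initial value, and the choice to return the average from before the update are assumptions made for illustration.

```python
class DecayingAverageBaseline:
    """Exponentially decaying running average of observed costs (illustrative only)."""

    def __init__(self, beta=0.95, initial=0.0):
        self.beta = beta        # decay rate; closer to 1.0 means slower forgetting
        self.average = initial  # assumed to start at zero

    def step(self, downstream_cost):
        """Return the baseline to use now, then fold the new cost into the average."""
        baseline = self.average
        self.average = self.beta * self.average + (1.0 - self.beta) * downstream_cost
        return baseline


baseline = DecayingAverageBaseline(beta=0.95)
for cost in [10.0, 10.0, 10.0]:
    used = baseline.step(cost)
    print("baseline used:", used, "updated average:", baseline.average)
```

A larger beta gives a smoother but slower-adapting baseline; a smaller beta tracks recent costs more closely at the price of more noise.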

Usage Examples

Basic Usage with SVI

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceGraph_ELBO
from pyro.optim import Adam

# Toy data: ten coin flips (illustrative values)
data = torch.tensor([1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0])

def model(data):
    # Uniform prior over the coin bias
    p = pyro.sample("p", dist.Beta(1.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Bernoulli(p), obs=data)

def guide(data):
    # Learnable Beta posterior; both concentrations must stay positive
    alpha_q = pyro.param("alpha_q", torch.tensor(1.0),
                         constraint=dist.constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(1.0),
                        constraint=dist.constraints.positive)
    pyro.sample("p", dist.Beta(alpha_q, beta_q))

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=TraceGraph_ELBO())
for step in range(1000):
    loss = svi.step(data)  # returns the loss estimate for this step

Using Baselines for Variance Reduction

def guide(data):
    # Use a decaying average baseline for a discrete latent variable.
    # Assumes the model has a matching sample site named "z".
    probs = pyro.param("probs", torch.tensor([0.3, 0.7]),
                       constraint=dist.constraints.simplex)
    pyro.sample("z", dist.Categorical(probs),
                infer={"baseline": {"use_decaying_avg_baseline": True,
                                    "baseline_beta": 0.95}})

JIT-Compiled Variant

from pyro.infer import JitTraceGraph_ELBO

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=JitTraceGraph_ELBO())
for step in range(1000):
    loss = svi.step(data)  # First call triggers JIT compilation
