Implementation:Pyro ppl Pyro TraceGraph ELBO
Overview
The tracegraph_elbo module (Template:Code) implements TraceGraph_ELBO, an ELBO-based objective for stochastic variational inference. It uses the fine-grained conditional dependency information recorded in Pyro execution traces to reduce the variance of its gradient estimator. The estimator is built along the lines of "Gradient Estimation Using Stochastic Computation Graphs" (Schulman et al.), specialized to the ELBO.
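The key difference from the basic Template:Code objective is Rao-Blackwellization. For each non-reparameterizable sample site, only the downstream cost terms that actually depend on that site contribute to its REINFORCE (score-function) gradient. Cost terms that do not depend on a site contribute zero to that site's gradient in expectation, but they still add noise to every sample. Dropping them therefore lowers variance without introducing bias, compared with using the full ELBO as the cost for every site.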
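The following plain-Python sketch (not Pyro code) illustrates why restricting each site to its downstream costs helps. It estimates a score-function gradient for a toy objective f1(z1) + f2(z2), where only f1 depends on the Bernoulli site z1 with logit theta1. All names and constants (sigmoid, reinforce_samples, f1, f2, theta1, p2) are hypothetical. Both estimators are unbiased for the exact gradient 3 * sigmoid(theta1) * (1 - sigmoid(theta1)), but the one that drops the non-downstream term f2 is far less noisy.

```python
import math
import random
import statistics


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def reinforce_samples(theta1: float, p2: float, num_samples: int, seed: int = 0):
    """Per-sample score-function gradients with respect to theta1.

    Toy objective: E[f1(z1) + f2(z2)] with z1 ~ Bernoulli(sigmoid(theta1)) and
    z2 ~ Bernoulli(p2) independent of theta1, so f2(z2) is NOT downstream of z1.
    """
    rng = random.Random(seed)
    p1 = sigmoid(theta1)
    full, downstream = [], []
    for _ in range(num_samples):
        z1 = 1.0 if rng.random() < p1 else 0.0
        z2 = 1.0 if rng.random() < p2 else 0.0
        f1 = 3.0 * z1    # cost term that depends on z1
        f2 = 10.0 * z2   # cost term that does not depend on z1
        score = z1 - p1  # d/dtheta1 of log Bernoulli(z1 | sigmoid(theta1))
        full.append(score * (f1 + f2))  # whole objective used as the cost
        downstream.append(score * f1)   # only the cost downstream of z1
    return full, downstream


full, downstream = reinforce_samples(theta1=0.3, p2=0.5, num_samples=20_000)
exact = 3.0 * sigmoid(0.3) * (1.0 - sigmoid(0.3))  # d/dtheta1 of E[f1] = 3 * sigmoid(theta1)
for label, samples in [("full cost", full), ("downstream cost", downstream)]:
    mean, var = statistics.mean(samples), statistics.variance(samples)
    print(f"{label}: mean={mean:.3f} var={var:.3f} (exact grad {exact:.3f})")
```

Analytically, the per-sample variance is about 0.40 for the downstream estimator and about 15.7 for the full-cost estimator. Both sample means should land near the exact value of about 0.733.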
The module provides:
- TraceGraph_ELBO -- The main ELBO implementation with graph-based variance reduction.
- JitTraceGraph_ELBO -- A JIT-compiled variant using Template:Code.
- TrackNonReparam -- A Messenger that tracks non-reparameterizable sample sites using provenance tracking.
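Provenance tracking can be pictured as values that carry the set of non-reparameterizable sites they were computed from. Arithmetic on such values takes the union of those sets. The sketch below illustrates the idea with plain floats and is only a sketch. The real handler operates on PyTorch tensors, and Tracked and track_non_reparam are hypothetical names, not Pyro APIs.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Tracked:
    """A float tagged with the non-reparameterizable sites it was computed from."""
    value: float
    provenance: frozenset = field(default_factory=frozenset)

    @staticmethod
    def lift(x):
        # Plain numbers carry no provenance
        return x if isinstance(x, Tracked) else Tracked(float(x))

    def __add__(self, other):
        other = Tracked.lift(other)
        return Tracked(self.value + other.value, self.provenance | other.provenance)

    def __mul__(self, other):
        other = Tracked.lift(other)
        return Tracked(self.value * other.value, self.provenance | other.provenance)

    # Addition and multiplication are commutative, so the reflected forms reuse them
    __radd__ = __add__
    __rmul__ = __mul__


def track_non_reparam(name, value):
    """Hypothetical stand-in for tagging a value sampled at a non-reparameterizable site."""
    return Tracked(float(value), frozenset({name}))


z = track_non_reparam("z", 1.0)   # discrete guide sample: tagged with {"z"}
w = Tracked(0.5)                  # e.g. a reparameterized sample: untagged
loc = 2.0 * z + w                 # inherits {"z"}
y = track_non_reparam("y", 3.0)
both = loc * y                    # inherits {"y", "z"}
print(loc.value, sorted(loc.provenance))    # 2.5 ['z']
print(both.value, sorted(both.provenance))  # 7.5 ['y', 'z']
```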
The module also contains internal functions for baseline construction (neural network baselines, decaying average baselines, and user-provided baseline values) and downstream cost computation.
Code Reference
File: Template:Code
Key Classes
| Class | Parent | Description |
|---|---|---|
| Template:Code | Template:Code | ELBO implementation using dependency-graph-based variance reduction for non-reparameterizable sites. |
| Template:Code | Template:Code | JIT-compiled version. Requires a static model structure, and all tensor inputs must be passed via Template:Code. |
| Template:Code | Template:Code | Effect handler that annotates values from non-reparameterizable sample sites with provenance information. |
Key Internal Functions
| Function | Description |
|---|---|
| Template:Code | Extracts baseline configuration from Template:Code. |
| Template:Code | Constructs baselines (neural network, decaying average, or user-provided) to reduce REINFORCE variance. |
| Template:Code | Recursively computes downstream cost terms for all sample sites, implementing Rao-Blackwellization. |
| Template:Code | Computes both the ELBO and a surrogate ELBO used for gradient estimation. Handles reparameterized and non-reparameterized sites differently. |
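The grouping step behind the downstream-cost and surrogate computation can be sketched with plain floats. The sketch assumes each log-probability term already carries its provenance set. Model terms enter with a plus sign and guide terms with a minus sign, and each term is credited to every non-reparameterizable site in its provenance. Subtracting a baseline then gives the coefficient that multiplies that site's score function in the surrogate. The function and site names (accumulate_downstream_costs, a, b, x_a, x_b) and all numbers are hypothetical. The sketch also ignores plates, tensor shapes, and gradient detaching.

```python
import math
from collections import defaultdict


def accumulate_downstream_costs(model_terms, guide_terms):
    """Return (elbo, costs), where costs[site] sums every ELBO term that depends on site.

    Each term is (site_name, log_prob, provenance); provenance is the set of
    non-reparameterizable guide sites the term was computed from.
    """
    elbo = 0.0
    costs = defaultdict(float)
    for _, log_p, provenance in model_terms:
        elbo += log_p                 # + log p(...) enters the ELBO
        for site in provenance:
            costs[site] += log_p
    for _, log_q, provenance in guide_terms:
        elbo -= log_q                 # - log q(...) enters the ELBO
        for site in provenance:
            costs[site] -= log_q
    return elbo, dict(costs)


# Two independent discrete sites a and b, each with its own observation
model_terms = [("a", -0.7, {"a"}), ("b", -0.7, {"b"}),
               ("x_a", -1.2, {"a"}), ("x_b", -3.0, {"b"})]
guide_terms = [("a", -0.5, {"a"}), ("b", -0.9, {"b"})]
elbo, costs = accumulate_downstream_costs(model_terms, guide_terms)

# Coefficient multiplying each site's score function in the surrogate (cost minus baseline)
baselines = {"a": -1.0}
weights = {site: cost - baselines.get(site, 0.0) for site, cost in costs.items()}

# Here every term has exactly one upstream site, so the costs partition the ELBO
assert math.isclose(sum(costs.values()), elbo)
print({site: round(c, 6) for site, c in costs.items()})  # {'a': -1.4, 'b': -2.8}
```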
TraceGraph_ELBO Methods
| Method | Description |
|---|---|
| Template:Code | Returns an estimate of the loss, i.e. the negative ELBO, as a float. No gradients are computed. |
| Template:Code | Computes the ELBO and the surrogate loss, then calls backward on the surrogate loss to populate gradients. Returns the loss (the negative ELBO) as a float. |
| Template:Code | Returns paired model and guide traces with provenance tracking via Template:Code. |
I/O Contract
TraceGraph_ELBO Constructor
Inputs (inherited from ELBO):
- Template:Code -- Number of samples for the gradient estimator (default 1).
- Template:Code -- Maximum plate nesting depth (default infinity, in which case it is auto-detected).
- Template:Code -- Whether to vectorize particles (default False).
- Template:Code -- Whether to retain the computation graph after backward.
loss
Inputs:
- Template:Code -- A Pyro model callable.
- Template:Code -- A Pyro guide callable.
- Template:Code -- Passed to model and guide.
Output:
- Template:Code -- Negative ELBO estimate.
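With several particles, the loss is the negated average of log p minus log q over particles. Below is a minimal plain-float sketch of that averaging. Each particle is a hypothetical (model log-prob sum, guide log-prob sum) pair rather than a real trace, and negative_elbo_estimate is an illustrative name.

```python
def negative_elbo_estimate(particles):
    """Average log p - log q over particles and negate the result.

    particles: list of (model_log_prob_sum, guide_log_prob_sum) pairs, one per
    guide sample and the model run against it (plain-float stand-ins).
    """
    num_particles = len(particles)
    elbo = 0.0
    for log_p, log_q in particles:
        elbo += (log_p - log_q) / num_particles
    return -elbo


loss = negative_elbo_estimate([(-3.0, -1.0), (-5.0, -2.0)])
print(loss)  # 2.5
```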
loss_and_grads
Inputs:
- Same as Template:Code.
Output:
- Template:Code -- Negative ELBO estimate. Also populates gradients on model/guide parameters.
Baseline Options (per site)
Baselines are configured via Template:Code in guide sample sites:
- Template:Code -- A neural network module.
- Template:Code -- Input tensor for the neural network baseline.
- Template:Code -- Use an exponential moving average baseline.
- Template:Code -- Decay rate for the moving average (default 0.90).
- Template:Code -- A user-provided baseline tensor.
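The decaying average option can be sketched as an exponential moving average of a site's downstream cost. This sketch rests on two assumptions: the running average starts at zero, and the value from before the current step serves as the baseline. Both are our reading of the option, not a verified copy of Pyro's code. DecayingAvgBaseline is a hypothetical name.

```python
class DecayingAvgBaseline:
    """Exponential moving average of one site's downstream cost (plain floats)."""

    def __init__(self, baseline_beta: float = 0.90):
        self.baseline_beta = baseline_beta
        self.avg = 0.0  # assumption: the running average starts at zero

    def __call__(self, downstream_cost: float) -> float:
        baseline = self.avg  # the pre-update average serves as this step's baseline
        self.avg = (1.0 - self.baseline_beta) * downstream_cost + self.baseline_beta * self.avg
        return baseline


tracker = DecayingAvgBaseline(baseline_beta=0.90)
costs = [10.0, 10.0, 10.0]
baselines = [tracker(cost) for cost in costs]
weights = [cost - b for cost, b in zip(costs, baselines)]  # REINFORCE weights
print([round(b, 6) for b in baselines])  # [0.0, 1.0, 1.9]
```

Because the baseline is computed before the current cost is seen, subtracting it leaves the gradient unbiased: the expected score function is zero. As the average approaches the typical cost, the weight cost - baseline shrinks, and so does the variance of the REINFORCE term.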
Usage Examples
Basic Usage with SVI
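```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceGraph_ELBO
from pyro.optim import Adam

# Illustrative coin-flip data (any 1-D float tensor of 0s and 1s works)
data = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0])

def model(data):
    # Beta(1, 1) prior on the coin's bias
    p = pyro.sample("p", dist.Beta(1.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Bernoulli(p), obs=data)

def guide(data):
    # Variational Beta posterior with learnable, positive concentrations
    alpha_q = pyro.param("alpha_q", torch.tensor(1.0),
                         constraint=dist.constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(1.0),
                        constraint=dist.constraints.positive)
    pyro.sample("p", dist.Beta(alpha_q, beta_q))

pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=TraceGraph_ELBO())
for step in range(1000):
    loss = svi.step(data)
```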
Using Baselines for Variance Reduction
```python
import torch
import pyro
import pyro.distributions as dist

def guide(data):
    probs = pyro.param("probs", torch.tensor([0.3, 0.7]),
                       constraint=dist.constraints.simplex)
    # Discrete z is not reparameterizable; use a decaying average baseline
    pyro.sample("z", dist.Categorical(probs),
                infer={"baseline": {"use_decaying_avg_baseline": True,
                                    "baseline_beta": 0.95}})
```
JIT-Compiled Variant
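```python
from pyro.infer import SVI, JitTraceGraph_ELBO
from pyro.optim import Adam

# Reuses model, guide, and data from the basic example; tensors are passed via *args
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=JitTraceGraph_ELBO())
for step in range(1000):
    loss = svi.step(data)  # First call triggers JIT compilation
```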
Related Pages
- Pyro_ppl_Pyro_Infer_Utilities -- Contains Template:Code, Template:Code, and other utilities used by TraceGraph_ELBO
- Pyro_ppl_Pyro_RenyiELBO -- Alternative ELBO objective using Renyi divergence
- Pyro_ppl_Pyro_TraceTailAdaptive_ELBO -- Tail-adaptive f-divergence variant
- Pyro_ppl_Pyro_TraceTMC_ELBO -- Tensor Monte Carlo ELBO implementation
- Pyro_ppl_Pyro_Model_Inspect -- Tools for analyzing model dependency structure