Implementation:Pyro ppl Pyro TraceTMC ELBO
Overview
The tracetmc_elbo module (Template:Code) implements TraceTMC_ELBO, a trace-based implementation of Tensor Monte Carlo (TMC) as described by Aitchison (2018), built on Tensor Variable Elimination (TVE) from Obermeyer et al. (2019). It combines:
- Local parallel sampling -- Multiple samples drawn in parallel at any sample site in the model or guide.
- Exhaustive enumeration -- Complete enumeration over discrete sample sites.
TMC generalizes standard importance-weighted estimators by treating the samples at each site as a tensor dimension and computing the marginal likelihood estimate by einsum-based contraction (via Template:Code). Averaging over every joint combination of samples would cost K^n terms for K samples at each of n sites; tensor variable elimination instead sums out one variable at a time, exploiting the conditional independence structure of the model.
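A plain-Python sketch of that saving (not Pyro's implementation; the chain model, the Normal(0, 2) proposal and all function names are illustrative assumptions): for a chain z1 -> z2 -> z3 -> x with K independently proposed samples per site, averaging the importance weight over all K³ joint combinations and summing out z3, then z2, then z1 give the same TMC-style estimate.

```python
import itertools
import math
import random


def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def chain_factors(z1, z2, z3, x):
    """Log-factors for z1 -> z2 -> z3 -> x with unit-variance Gaussians.

    Every latent site is proposed independently from q = Normal(0, 2), so each
    factor carries log p(...) - log q(z) for the site it introduces.
    """
    f1 = [normal_logpdf(a, 0.0, 1.0) - normal_logpdf(a, 0.0, 2.0) for a in z1]
    f12 = [[normal_logpdf(b, a, 1.0) - normal_logpdf(b, 0.0, 2.0) for b in z2] for a in z1]
    f23 = [[normal_logpdf(c, b, 1.0) - normal_logpdf(c, 0.0, 2.0) for c in z3] for b in z2]
    f3 = [normal_logpdf(x, c, 1.0) for c in z3]
    return f1, f12, f23, f3


def tmc_naive(f1, f12, f23, f3):
    """Average the weight over all K**3 joint sample combinations."""
    k = len(f1)
    terms = [f1[i] + f12[i][j] + f23[j][n] + f3[n]
             for i, j, n in itertools.product(range(k), repeat=3)]
    return logsumexp(terms) - 3 * math.log(k)


def tmc_eliminate(f1, f12, f23, f3):
    """Same estimate, summing out z3, then z2, then z1: O(K**2) work per step."""
    k = len(f1)
    m2 = [logsumexp([f23[j][n] + f3[n] for n in range(k)]) - math.log(k) for j in range(k)]
    m1 = [logsumexp([f12[i][j] + m2[j] for j in range(k)]) - math.log(k) for i in range(k)]
    return logsumexp([f1[i] + m1[i] for i in range(k)]) - math.log(k)


rng = random.Random(0)
K = 8
z1, z2, z3 = ([rng.gauss(0.0, 2.0) for _ in range(K)] for _ in range(3))
factors = chain_factors(z1, z2, z3, x=0.5)
print(tmc_naive(*factors), tmc_eliminate(*factors))
```

The naive version touches K³ terms, the eliminated one 2K² + K, and the gap widens as the chain grows.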
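To enable parallel sampling at a site, mark it with Template:Code. To enumerate a site, mark it with Template:Code or Template:Code. The Template:Code decorator configures every site of a model or guide at once. To enumerate or sample a site in the model, mark it there and make sure the site does not also appear in the guide.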
Structural restriction: Variables outside of a Template:Code can never depend on variables inside that Template:Code.
The module also computes DiCE factors for non-reparameterizable proposal sites (for example, discrete variables that are sampled rather than enumerated), so that gradients through those sites are estimated correctly.
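The DiCE idea can be illustrated without autodiff. In this plain-Python sketch (not Pyro's code; the Bernoulli proposal and the names `log_q`, `magic_box` and `dice_gradient` are illustrative), central finite differences stand in for autograd, and a value computed once and then held fixed stands in for `.detach()`. It checks the two properties a DiCE factor relies on: its forward value is exactly 1, so the loss value is unchanged, while its derivative reproduces the score-function gradient cost * d/dtheta log q(k).

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def log_q(theta, k):
    """Hypothetical Bernoulli proposal with q(k=1) = sigmoid(theta)."""
    p1 = sigmoid(theta)
    return math.log(p1 if k == 1 else 1.0 - p1)


def magic_box(logp, logp_stopped):
    """DiCE operator exp(logp - stop_gradient(logp)); equals 1 at the sampled point."""
    return math.exp(logp - logp_stopped)


def dice_gradient(theta, k, cost, eps=1e-6):
    """d/dtheta of cost * magic_box(log q), by central differences.

    The stopped term is evaluated once at theta and then held fixed, which is
    what stop_gradient / .detach() does under automatic differentiation.
    """
    stopped = log_q(theta, k)

    def surrogate(t):
        return cost * magic_box(log_q(t, k), stopped)

    return (surrogate(theta + eps) - surrogate(theta - eps)) / (2 * eps)


theta, k, cost = 0.3, 1, 2.5
print(magic_box(log_q(theta, k), log_q(theta, k)))  # 1.0: forward value unchanged
# score-function gradient: cost * d/dtheta log q(k=1) = cost * (1 - sigmoid(theta))
print(dice_gradient(theta, k, cost), cost * (1.0 - sigmoid(theta)))
```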
Code Reference
File: Template:Code
Key Classes
| Class | Parent | Description |
|---|---|---|
| Template:Code | Template:Code | Tensor Monte Carlo ELBO using parallel sampling and exhaustive enumeration with tensor variable elimination. |
TraceTMC_ELBO Methods
| Method | Description |
|---|---|
| Template:Code | Compute a differentiable loss: the negative TMC estimate of the marginal log-likelihood, averaged over particles. Returns a tensor. |
| Template:Code | Compute the TMC loss as a float (under Template:Code). |
| Template:Code | Compute the loss, perform backward, and return the loss as a float. |
| Template:Code | Returns paired model/guide traces with packed tensors and plate-to-symbol mappings. |
| Template:Code | Generator yielding traces with parallel enumeration enabled. Uses Template:Code for allocation of enumeration dimensions. |
Internal Functions
| Function | Description |
|---|---|
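| Template:Code | Computes per-site log-factors: log(p/q) for unobserved sites and log(p) for observed sites. A model site that is absent from the guide and was multiply sampled from its prior gets an extra -log(p) term, because the prior served as its proposal. |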
| Template:Code | Computes DiCE log-factors for non-reparameterizable proposal sites, enabling correct gradient estimation. |
| Template:Code | Uses Template:Code with the Template:Code backend to perform tensor variable elimination and compute the TMC marginal likelihood estimate. |
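The log-factor rule in the first row can be restated on plain floats. This is a hypothetical re-expression, not Pyro's code: real traces carry packed tensors with dimension annotations rather than floats, and `tmc_log_factors` and the dict layout are made up for illustration. A guide site contributes -log q, every model site contributes log p, and a model site that is absent from the guide and was multiply sampled contributes an extra -log p, so its prior terms cancel.

```python
def tmc_log_factors(model_sites, guide_sites):
    """Hypothetical restatement of the per-site log-factor rule on plain floats.

    Each site is a dict with "log_prob" (float) and "is_observed" (bool);
    model sites may also carry "num_samples" (int or None).
    Returns (site name, description, value) triples.
    """
    factors = []
    for name, site in guide_sites.items():
        if not site["is_observed"]:
            factors.append((name, "-log q", -site["log_prob"]))
    for name, site in model_sites.items():
        sampled_from_prior = (
            name not in guide_sites
            and not site["is_observed"]
            and (site.get("num_samples") or 0) > 0
        )
        if sampled_from_prior:
            # the prior served as the proposal, so it also enters with a minus sign
            factors.append((name, "-log p (prior proposal)", -site["log_prob"]))
        factors.append((name, "log p", site["log_prob"]))
    return factors


guide_sites = {"z": {"log_prob": -1.2, "is_observed": False}}
model_sites = {
    "z": {"log_prob": -0.9, "is_observed": False},
    "obs": {"log_prob": -2.0, "is_observed": True},
    "w": {"log_prob": -0.5, "is_observed": False, "num_samples": 4},
}
factors = tmc_log_factors(model_sites, guide_sites)
# the two factors for "w" cancel: total is log p(z) + log p(obs | z, w) - log q(z)
print(sum(value for _, _, value in factors))
```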
I/O Contract
Constructor (inherited from ELBO)
Inputs:
- Template:Code -- Number of independent TMC estimates (particles) averaged into the loss; this is separate from the per-site num_samples setting. Default is 1.
- Template:Code -- Max nested plate depth. Default is infinity (auto-detected).
- Template:Code -- Whether to vectorize particles. Default is False.
- Template:Code -- Warn if no sample site is configured for enumeration. Default is True.
differentiable_loss
Inputs:
- Template:Code -- A Pyro model callable.
- Template:Code -- A Pyro guide callable.
- Template:Code -- Passed to model and guide.
Output:
- Template:Code -- A differentiable negative TMC estimate of the marginal log-likelihood.
Raises:
- Template:Code if the loss is not differentiable (for example, because it is identically zero).
Site Configuration
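To take multiple samples at a site (here, 10 draws of a continuous variable):
```python
z = pyro.sample("z", dist.Normal(0.0, 1.0),
                infer={"enumerate": "parallel", "num_samples": 10})
```
To enumerate a site exhaustively (its distribution must have finite, enumerable support, such as a Categorical):
```python
k = pyro.sample("k", dist.Categorical(torch.ones(3) / 3),
                infer={"enumerate": "parallel"})
```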
Internal TMC Computation
- Collect log-factors from model and guide (log p/q for latent sites, log p for observed sites).
- Collect DiCE factors for non-reparameterizable sites.
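- Construct an einsum equation from the Template:Code annotations on each factor, with an empty output so that every dimension is reduced.
- Use Template:Code with Template:Code backend to contract, respecting plate structure: in log space, plate dimensions are combined as a product (a sum of log-factors) and sample or enumeration dimensions are summed out (log-sum-exp).
- The result is a scalar TMC estimate of the log marginal likelihood.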
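For a global variable and one plate, these contraction steps can be written out by hand. The sketch below is plain Python, not Pyro's einsum; the model (a global z drawn K times from its prior, and an enumerated local y_i in {0, 1} inside a data plate) and the names `build_equation` and `contract_plated` are assumptions for illustration. The local variable is summed out inside the plate, the plate then multiplies its per-datum terms, and the global sample dimension is summed out last. That ordering is only valid because, per the structural restriction above, nothing outside the plate depends on anything inside it.

```python
import math
import random


def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def build_equation(factor_dims):
    """Join per-factor dimension strings into an einsum equation with scalar output."""
    return ",".join(factor_dims) + "->"


def contract_plated(data, z_samples):
    """Hand-written log-space contraction of "a,bi,abi->" with "i" as a plate.

    a: K samples of a global z drawn from its prior (outside the plate)
    b: a local y_i in {0, 1} with p(y_i) = 1/2, enumerated (inside the plate)
    i: the plate over data points; x_i ~ Normal(z + y_i, 1)
    """
    k = len(z_samples)
    log_py = math.log(0.5)
    per_sample = []
    for z in z_samples:
        # z came from its prior, so log p(z) - log p(z) cancels; only -log K is left
        total = -math.log(k)
        for x in data:
            # inside the plate: sum out the enumerated local variable first,
            # then the plate product over i becomes a sum of log-terms
            total += logsumexp([log_py + normal_logpdf(x, z + y, 1.0) for y in (0, 1)])
        per_sample.append(total)
    # the global sample dimension a is summed out last
    return logsumexp(per_sample)


rng = random.Random(0)
z_samples = [rng.gauss(0.0, 1.0) for _ in range(16)]
data = [0.3, -0.8, 1.5]
print(build_equation(["a", "bi", "abi"]))  # a,bi,abi->
print(contract_plated(data, z_samples))
```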
Usage Examples
Basic TMC with Parallel Sampling
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceTMC_ELBO, config_enumerate
from pyro.optim import Adam

data = torch.tensor(0.5)  # a single scalar observation


def model(data):
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    pyro.sample("obs", dist.Normal(z, 1.0), obs=data)


# "z" is sampled in the guide, so the guide site is the one marked for
# parallel sampling: 10 proposals of z are drawn and scored under the model.
@config_enumerate(num_samples=10)
def guide(data):
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0),
                       constraint=dist.constraints.positive)
    pyro.sample("z", dist.Normal(loc, scale))


tmc = TraceTMC_ELBO(num_particles=1)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=tmc)
for step in range(1000):
    loss = svi.step(data)
```
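Because this model has a single latent site, the TMC estimate it optimizes should reduce to the familiar importance-weighted estimate log (1/K) Σ_k p(z_k) p(x | z_k) / q(z_k) with z_k ~ q. A plain-Python sketch of that formula (not Pyro code; `tmc_single_site` is a hypothetical name) also gives a handy sanity check: when q is the exact posterior Normal(x/2, sqrt(1/2)), every weight equals p(x), so the estimate matches the exact log marginal log Normal(x; 0, sqrt(2)) for any sample.

```python
import math
import random


def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def tmc_single_site(x, loc, scale, num_samples, seed=0):
    """log (1/K) sum_k p(z_k) p(x | z_k) / q(z_k), with z_k ~ q = Normal(loc, scale).

    Model: z ~ Normal(0, 1), x | z ~ Normal(z, 1).
    """
    rng = random.Random(seed)
    log_w = []
    for _ in range(num_samples):
        z = rng.gauss(loc, scale)
        log_w.append(normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)
                     - normal_logpdf(z, loc, scale))
    m = max(log_w)
    return m + math.log(sum(math.exp(w - m) for w in log_w)) - math.log(num_samples)


x = 0.5
exact = normal_logpdf(x, 0.0, math.sqrt(2.0))  # p(x) = Normal(x; 0, sqrt(2))
# exact posterior as proposal: every weight equals p(x), so the estimate is exact
print(exact, tmc_single_site(x, x / 2, math.sqrt(0.5), num_samples=10))
# a poorer proposal gives a noisy estimate that is a lower bound in expectation
print(tmc_single_site(x, 0.0, 1.0, num_samples=10))
```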
Enumeration over Discrete Variables
```python
# Reuses the imports and `data` from the previous example.
@config_enumerate
def model(data):
    # "k" does not appear in the guide, so it is fully enumerated in the model
    k = pyro.sample("k", dist.Categorical(torch.ones(3) / 3))
    locs = pyro.param("locs", torch.tensor([-1.0, 0.0, 1.0]))
    pyro.sample("obs", dist.Normal(locs[k], 1.0), obs=data)


def guide(data):
    pass  # every latent site is enumerated in the model


tmc = TraceTMC_ELBO()
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=tmc)
loss = svi.step(data)
```
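When every latent site is enumerated and none is sampled, no Monte Carlo error remains, so the quantity maximized in this example should be the exact log marginal log Σ_k p(k) Normal(obs; locs[k], 1). A plain-Python sketch of that sum in log space for the initial `locs` (not Pyro code; `enumerated_log_marginal` is a hypothetical name):

```python
import math


def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def enumerated_log_marginal(obs, locs, log_weights):
    """log sum_k p(k) Normal(obs; locs[k], 1), computed stably in log space."""
    terms = [lw + normal_logpdf(obs, loc, 1.0) for lw, loc in zip(log_weights, locs)]
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))


locs = [-1.0, 0.0, 1.0]
log_weights = [math.log(1.0 / 3.0)] * 3
print(enumerated_log_marginal(0.5, locs, log_weights))
```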
Mixed Enumeration and Sampling
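```python
def model(data):
    # Enumerate this discrete site exhaustively
    k = pyro.sample("k", dist.Categorical(torch.ones(3) / 3),
                    infer={"enumerate": "parallel"})
    # Take 5 samples of this continuous site (drawn from its prior)
    z = pyro.sample("z", dist.Normal(0.0, 1.0),
                    infer={"enumerate": "parallel", "num_samples": 5})
    pyro.sample("obs", dist.Normal(z + k, 1.0), obs=data)


def guide(data):
    pass  # "k" and "z" must not appear in the guide


# Negative TMC estimate of log p(data); no parameters here, so this only evaluates it
loss = TraceTMC_ELBO().differentiable_loss(model, guide, data)
```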
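Assuming the usual TMC semantics, where enumerated sites are summed over their support weighted by their probabilities and multiply sampled prior sites are averaged, the mixed model above estimates p(obs) ≈ (1/S) Σ_s Σ_k p(k) Normal(obs; z_s + k, 1) with z_s ~ Normal(0, 1). The plain-Python sketch below (not Pyro code; function names are hypothetical) computes that estimate and compares it with the closed form: given k, z plus unit noise is Normal(k, sqrt(2)), so p(obs) = Σ_k (1/3) Normal(obs; k, sqrt(2)).

```python
import math
import random


def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def mixed_estimate(obs, num_samples, seed=0):
    """log of (1/S) sum_s sum_k p(k) Normal(obs; z_s + k, 1), with z_s ~ Normal(0, 1)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)  # sampled site: one prior draw per sample
        # enumerated site: exact sum over k in {0, 1, 2} with p(k) = 1/3
        total += sum(math.exp(normal_logpdf(obs, z + k, 1.0)) / 3.0 for k in range(3))
    return math.log(total / num_samples)


def exact_log_marginal(obs):
    """Closed form: p(obs) = sum_k (1/3) Normal(obs; k, sqrt(2))."""
    return math.log(sum(math.exp(normal_logpdf(obs, k, math.sqrt(2.0))) / 3.0 for k in range(3)))


obs = 0.7
print(mixed_estimate(obs, num_samples=5))      # noisy with only 5 samples
print(mixed_estimate(obs, num_samples=20000))  # close to the exact value
print(exact_log_marginal(obs))
```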
Related Pages
- Pyro_ppl_Pyro_TraceGraph_ELBO -- Graph-based ELBO for models without enumeration
- Pyro_ppl_Pyro_RenyiELBO -- Renyi divergence ELBO
- Pyro_ppl_Pyro_Infer_Utilities -- Contains Template:Code and Template:Code used by TMC
- Pyro_ppl_Pyro_TraceTailAdaptive_ELBO -- Tail-adaptive f-divergence objective