Implementation:Pyro ppl Pyro TraceTMC ELBO

From Leeroopedia


Overview

The tracetmc_elbo module (Template:Code) implements TraceTMC_ELBO, a trace-based implementation of Tensor Monte Carlo (TMC) as described by Aitchison (2018), built on top of Tensor Variable Elimination (TVE) from Obermeyer et al. (2019). It combines two strategies:

  • Local parallel sampling -- Multiple samples drawn in parallel at any sample site in the model or guide.
  • Exhaustive enumeration -- Complete enumeration over discrete sample sites.

TMC generalizes standard importance-weighted estimators by treating samples as tensor dimensions and using einsum-based contraction (via Template:Code) to compute the marginal likelihood estimate. Because tensor variable elimination exploits conditional independence structure, the contraction can sum over every combination of per-site samples without materializing the full joint tensor.
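As a hedged illustration of what "treating samples as tensor dimensions" means, the estimator can be written out for the simplest setting. The notation here is an assumption made for this sketch and is not taken from the module: n latent sites, K samples per site, and an independent proposal q for each site.

```latex
\hat{p}_{\mathrm{TMC}}(x)
  = \frac{1}{K^{n}} \sum_{k_1=1}^{K} \cdots \sum_{k_n=1}^{K}
    \frac{p\left(x,\, z_1^{k_1}, \ldots, z_n^{k_n}\right)}
         {\prod_{i=1}^{n} q\left(z_i^{k_i}\right)},
  \qquad z_i^{k} \sim q(z_i)
```

With n = 1 this reduces to the usual K-sample importance-weighted estimate. For n > 1 the sum has K^n terms, which is why the factorization of p into per-site terms matters: each index k_i can be summed out as soon as every factor mentioning it has been combined.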

To enable parallel sampling at a site, mark it with infer={"enumerate": "parallel", "num_samples": N}, where N is the number of samples. To enumerate, mark it with infer={"enumerate": "parallel"} or Template:Code. The config_enumerate decorator can configure all sites at once.

Structural restriction: Variables outside of a plate can never depend on variables inside that plate.

The module also handles DiCE factors for non-reparameterizable proposal sites, enabling correct gradient estimation for discrete variables.
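For readers unfamiliar with DiCE, the general idea can be stated in one line. This is a sketch of the standard DiCE operator, not a transcription of the module's code; the symbols (a proposal log-density written as a function of parameters phi, and a stop-gradient operator sg) are notation assumed for this illustration.

```latex
\ell = \log q_{\phi}(z), \qquad
\operatorname{dice}(\ell) = \exp\bigl(\ell - \operatorname{sg}(\ell)\bigr), \qquad
\operatorname{dice}(\ell) = 1 \ \text{in value}, \qquad
\nabla_{\phi} \operatorname{dice}(\ell) = \operatorname{dice}(\ell)\, \nabla_{\phi} \ell
```

In log space the corresponding factor is \ell - sg(\ell): it adds zero to the estimate, but differentiating through it contributes the score-function term that a non-reparameterizable site would otherwise be missing.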

Code Reference

File: Template:Code

Key Classes

Class | Parent | Description
TraceTMC_ELBO | ELBO | Tensor Monte Carlo ELBO using parallel sampling and exhaustive enumeration with tensor variable elimination.

TraceTMC_ELBO Methods

Method | Description
differentiable_loss | Compute a differentiable loss: the negative TMC estimate of the marginal log-likelihood. Returns a tensor.
Template:Code | Compute the TMC loss as a float (under Template:Code).
Template:Code | Compute the loss, perform backward, and return the loss as a float.
Template:Code | Returns paired model/guide traces with packed tensors and plate-to-symbol mappings.
Template:Code | Generator yielding traces with parallel enumeration enabled. Uses Template:Code for allocation of enumeration dimensions.

Internal Functions

Function | Description
Template:Code | Computes per-site log-factors: log(p/q) for unobserved sites, log(p) for observed sites, with corrections for sites sampled from the prior.
Template:Code | Computes DiCE log-factors for non-reparameterizable proposal sites, enabling correct gradient estimation.
Template:Code | Uses Template:Code with the Template:Code backend to perform tensor variable elimination and compute the TMC marginal likelihood estimate.
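The per-site log-factors described above can be illustrated on a single-latent model using only the Python standard library. This is a minimal sketch, not the module's code: the helper names (normal_logpdf, logmeanexp, single_site_estimate) are hypothetical, and the corrections for sites sampled from the prior are not covered. The model is z ~ Normal(0, 1), x ~ Normal(z, 1), for which the exact marginal is Normal(0, sqrt(2)).

```python
import math
import random


def normal_logpdf(x, loc, scale):
    """Log-density of Normal(loc, scale) at x."""
    return (-0.5 * ((x - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2 * math.pi))


def logmeanexp(values):
    """Numerically stable log of the mean of exp(values)."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def single_site_estimate(x, q_loc, q_scale, num_samples, rng):
    """K-sample estimate of log p(x) from per-site log-factors."""
    log_weights = []
    for _ in range(num_samples):
        z = rng.gauss(q_loc, q_scale)
        # Unobserved site "z": log(p / q)
        log_pq = normal_logpdf(z, 0.0, 1.0) - normal_logpdf(z, q_loc, q_scale)
        # Observed site "x": log p
        log_p_obs = normal_logpdf(x, z, 1.0)
        log_weights.append(log_pq + log_p_obs)
    return logmeanexp(log_weights)


x = 0.5
exact = normal_logpdf(x, 0.0, math.sqrt(2.0))
# With the exact posterior Normal(x / 2, sqrt(1 / 2)) as the proposal,
# every weight equals p(x), so the estimate has zero variance.
estimate = single_site_estimate(x, x / 2, math.sqrt(0.5), 10, random.Random(0))
```

Using the exact posterior as the proposal is a deliberate choice for this sketch: it makes the result independent of the random draws, so the estimate can be compared directly against the closed-form marginal.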

I/O Contract

Constructor (inherited from ELBO)

Inputs:

  • num_particles -- Number of particles/samples. Default is 1.
  • Template:Code -- Max nested plate depth. Default is infinity (auto-detected).
  • Template:Code -- Whether to vectorize particles. Default is False.
  • Template:Code -- Warn if no enumerated sites found. Default is True.

differentiable_loss

Inputs:

  • model, guide -- The model and guide callables; any further arguments are passed through to both.

Output:

  • Template:Code -- A differentiable negative TMC estimate of the marginal log-likelihood.

Raises:

  • Template:Code if the result is identically zero and not differentiable.

Site Configuration

To take multiple samples at a site:

pyro.sample("z", dist, infer={"enumerate": "parallel", "num_samples": 10})

To enumerate a site:

pyro.sample("z", dist, infer={"enumerate": "parallel"})

Internal TMC Computation

  1. Collect log-factors from model and guide (log p/q for latent sites, log p for observed sites).
  2. Collect DiCE factors for non-reparameterizable sites.
  3. Construct an einsum equation from the Template:Code annotations on each factor.
  4. Use Template:Code with Template:Code backend to contract, respecting plate structure.
  5. The result is a scalar TMC estimate.
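Steps 3 to 5 above can be sketched with the Python standard library alone. This is an illustrative toy, not the module's implementation: the factor representation (a string of one-letter dimension symbols plus a dictionary of log-values), the helper names, and the numbers are all assumptions, and plates are ignored. It builds the equation string from the per-factor symbols, then checks that summing out one symbol at a time gives the same scalar as the brute-force sum over all index combinations.

```python
import itertools
import math


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def build_equation(factors):
    # One comma-separated term per factor, empty output means "contract to a scalar".
    return ",".join(dims for dims, _ in factors) + "->"


def contract_naive(factors, sizes):
    """Log of the sum, over every joint index assignment, of the product of factors."""
    symbols = sorted(sizes)
    terms = []
    for assignment in itertools.product(*(range(sizes[s]) for s in symbols)):
        index = dict(zip(symbols, assignment))
        terms.append(sum(table[tuple(index[d] for d in dims)]
                         for dims, table in factors))
    return logsumexp(terms)


def eliminate(factors, sizes, symbol):
    """Sum out one symbol, combining only the factors that mention it."""
    touching = [f for f in factors if symbol in f[0]]
    rest = [f for f in factors if symbol not in f[0]]
    out_dims = "".join(sorted({d for dims, _ in touching for d in dims} - {symbol}))
    table = {}
    for assignment in itertools.product(*(range(sizes[d]) for d in out_dims)):
        index = dict(zip(out_dims, assignment))
        terms = []
        for k in range(sizes[symbol]):
            index[symbol] = k
            terms.append(sum(t[tuple(index[d] for d in dims)] for dims, t in touching))
        table[assignment] = logsumexp(terms)
    return rest + [(out_dims, table)]


# Hypothetical two-site chain with K = 3 samples per site.
# The -log(K) terms stand in for the 1/K sample weights.
K = 3
sizes = {"a": K, "b": K}
factors = [
    ("a", {(i,): -0.1 * i - math.log(K) for i in range(K)}),
    ("ab", {(i, j): -0.2 * (i - j) ** 2 for i in range(K) for j in range(K)}),
    ("b", {(j,): -0.3 * j - math.log(K) for j in range(K)}),
]

equation = build_equation(factors)
naive = contract_naive(factors, sizes)
reduced = eliminate(eliminate(factors, sizes, "b"), sizes, "a")
eliminated = sum(table[()] for _, table in reduced)
```

The brute-force version touches K^n index combinations, while elimination only ever builds tables over the symbols shared by the factors being combined, which is the saving that tensor variable elimination provides on sparse dependency structures.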

Usage Examples

Basic TMC with Parallel Sampling

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceTMC_ELBO, config_enumerate
from pyro.optim import Adam

# Example observation (placeholder value; substitute real data)
data = torch.tensor(0.5)

@config_enumerate(num_samples=10)
def model(data):
    z = pyro.sample("z", dist.Normal(0, 1))
    pyro.sample("obs", dist.Normal(z, 1), obs=data)

# "z" is drawn in the guide, so the guide must also be configured
# for the 10 parallel samples to take effect at that site.
@config_enumerate(num_samples=10)
def guide(data):
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0),
                       constraint=dist.constraints.positive)
    pyro.sample("z", dist.Normal(loc, scale))

tmc = TraceTMC_ELBO(num_particles=1)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=tmc)

for step in range(1000):
    loss = svi.step(data)

Enumeration over Discrete Variables

@config_enumerate
def model(data):
    # Discrete variable will be fully enumerated
    k = pyro.sample("k", dist.Categorical(torch.ones(3) / 3))
    locs = torch.tensor([-1.0, 0.0, 1.0])
    pyro.sample("obs", dist.Normal(locs[k], 1), obs=data)

@config_enumerate  # "k" is drawn in the guide, so enumeration must be enabled there
def guide(data):
    probs = pyro.param("probs", torch.ones(3) / 3,
                       constraint=dist.constraints.simplex)
    pyro.sample("k", dist.Categorical(probs))

tmc = TraceTMC_ELBO()
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=tmc)

Mixed Enumeration and Sampling

def model(data):
    # Enumerate this discrete site
    k = pyro.sample("k", dist.Categorical(torch.ones(3) / 3),
                     infer={"enumerate": "parallel"})
    # Take multiple samples from this continuous site
    z = pyro.sample("z", dist.Normal(0, 1),
                     infer={"enumerate": "parallel", "num_samples": 5})
    pyro.sample("obs", dist.Normal(z + k, 1), obs=data)
