Implementation:Pyro ppl Pyro CSIS

From Leeroopedia


Overview

The csis module (pyro/infer/csis.py) implements Compiled Sequential Importance Sampling (CSIS), an inference compilation technique that trains a guide (proposal distribution) to minimize the KL divergence from the model posterior to the guide, KL(model posterior || guide). Standard variational inference, by contrast, minimizes the reverse divergence, KL(guide || model posterior).
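
The link between this KL direction and a sampling-based training loss can be written out in a few lines. The notation is an assumption made for illustration: x denotes latent variables, y denotes observations, p is the model, and q(x | y) is the guide's density over latents given observations.

```latex
\begin{aligned}
\mathbb{E}_{p(y)}\Big[\mathrm{KL}\big(p(x \mid y)\,\|\,q(x \mid y)\big)\Big]
  &= \mathbb{E}_{p(x,y)}\big[\log p(x \mid y) - \log q(x \mid y)\big] \\
  &= -\,\mathbb{E}_{p(x,y)}\big[\log q(x \mid y)\big] + \text{const}
\end{aligned}
```

The constant does not depend on the guide's parameters, so minimizing the expected negative guide log-probability over joint samples from the model minimizes this KL divergence averaged over possible observations.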

CSIS works in two phases:

  1. Training phase -- The guide is trained on samples drawn from the model's joint distribution, with observed sites sampled rather than conditioned on. For each sample, the guide's log-probability of the latent values given the sampled observations is computed. The loss is the negative expectation of that log-probability under the model: -E_{p(x,y)} [log q(x | y)], where x denotes the latent variables and y the observations.
  2. Inference phase -- After training, importance sampling is performed using the compiled guide as the proposal distribution.
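
The training phase can be sketched without Pyro on a toy problem. The following is a minimal illustration under stated assumptions, not the library's implementation: the model is x ~ N(0, 1), y ~ N(x, 1), the guide is q(x | y) = N(a * y, 1) with a single learnable slope a, and the names sample_joint and train_guide are hypothetical. The exact posterior mean for this model is y / 2, so the slope should approach 0.5.

```python
import random


def sample_joint(rng):
    """Draw (x, y) from the toy model x ~ N(0, 1), y ~ N(x, 1)."""
    x = rng.gauss(0.0, 1.0)
    y = rng.gauss(x, 1.0)
    return x, y


def train_guide(num_steps=5000, batch_size=16, lr=0.02, seed=0):
    """Fit the slope a of q(x | y) = N(a * y, 1) by SGD on -log q(x | y)."""
    rng = random.Random(seed)
    a = 0.0
    for _step in range(num_steps):
        grad = 0.0
        for _particle in range(batch_size):
            # Joint sample from the model: y is sampled, not conditioned on.
            x, y = sample_joint(rng)
            # -log q(x | y) = 0.5 * (x - a * y) ** 2 + const, so its
            # derivative with respect to a is -(x - a * y) * y.
            grad += -(x - a * y) * y
        a -= lr * grad / batch_size
    return a


slope = train_guide()
print(f"learned slope: {slope:.3f}")
```

No observed data enters this loop: the guide learns how to map any y to a proposal over x, which is what makes the later inference step amortized.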

The CSIS class extends Importance (which extends TracePosterior), inheriting the ability to collect weighted traces and compute marginal distributions. CSIS is particularly useful for amortized inference, where the guide is trained once and then used for rapid inference on new observations.
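
The inference phase amounts to self-normalized importance sampling with the trained guide as the proposal. A minimal stdlib sketch, assuming a toy model x ~ N(0, 1), y ~ N(x, 1) and an already-trained guide q(x | y) = N(0.5 * y, 1); the helper names log_normal_pdf and posterior_mean are hypothetical and are not part of Pyro:

```python
import math
import random


def log_normal_pdf(value, mean, std):
    """Log density of N(mean, std**2) evaluated at value."""
    return (-0.5 * ((value - mean) / std) ** 2
            - math.log(std) - 0.5 * math.log(2.0 * math.pi))


def posterior_mean(y_obs, slope=0.5, num_samples=5000, seed=0):
    """Estimate E[x | y_obs] by importance sampling from the guide."""
    rng = random.Random(seed)
    xs, log_ws = [], []
    for _ in range(num_samples):
        x = rng.gauss(slope * y_obs, 1.0)  # propose from q(x | y)
        log_joint = log_normal_pdf(x, 0.0, 1.0) + log_normal_pdf(y_obs, x, 1.0)
        log_q = log_normal_pdf(x, slope * y_obs, 1.0)
        xs.append(x)
        log_ws.append(log_joint - log_q)  # log importance weight
    # Subtract the max log-weight before exponentiating for stability.
    max_lw = max(log_ws)
    ws = [math.exp(lw - max_lw) for lw in log_ws]
    return sum(w * x for w, x in zip(ws, xs)) / sum(ws)


estimate = posterior_mean(y_obs=2.0)
print(f"estimated posterior mean: {estimate:.3f}")
```

For this toy model the exact posterior mean is y / 2, so the estimate for y_obs = 2.0 should land close to 1.0. The better the guide matches the posterior, the more even the weights and the fewer samples are needed.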

The reference paper is "Inference Compilation and Universal Probabilistic Programming" (Le et al., 2017).

Code Reference

File: pyro/infer/csis.py

Key Classes

Class | Parent | Description
CSIS | Importance | Compiled Sequential Importance Sampling. Provides a training loop and importance sampling inference.

CSIS Methods

Method | Description
__init__ | Initialize CSIS with model, guide, optimizer, and batch sizes.
step | Take a single gradient step on the CSIS loss. Returns an estimate of the loss as a float.
loss_and_grads | Compute the loss (and optionally gradients). Can use a provided batch of traces or generate fresh samples.
validation_loss | Compute loss on a held-out validation batch. Sets the validation batch on first call.
set_validation_batch | Pre-sample a batch of model traces for validation loss computation.
_get_matched_trace | Given a model trace, run the guide with observations from that trace to get a matched guide trace.
_sample_from_joint | Sample from the unconditioned model to get a joint sample trace.
_differentiable_loss_particle | Compute the per-particle loss: -guide_trace.log_prob_sum().

I/O Contract

Constructor

Inputs:

step

Inputs:

Output:

loss_and_grads

Inputs:

Output:

Observation Passing Convention

Both the model and the guide must accept an observations keyword argument. During training, _get_matched_trace populates it with the values sampled in the model trace for sites that were originally observed (marked with was_observed in their infer dict).
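
The convention can be illustrated with plain dictionaries standing in for a model trace. This is a hedged sketch only: the dict layout, the flag name was_observed as used here, and the helper collect_observations are assumptions for illustration and do not reproduce Pyro's trace objects.

```python
def collect_observations(trace_nodes):
    """Build the observations dict that would be handed to the guide.

    trace_nodes is a hypothetical stand-in for a model trace: it maps each
    site name to {"value": ..., "infer": {...}}.
    """
    return {
        name: node["value"]
        for name, node in trace_nodes.items()
        if node["infer"].get("was_observed", False)
    }


# One latent site and one site that was observed in the original model
# but sampled from the joint during training.
toy_trace = {
    "mu": {"value": 0.3, "infer": {}},
    "obs": {"value": [1.2, 0.7], "infer": {"was_observed": True}},
}

observations = collect_observations(toy_trace)
print(observations)
```

Only the originally observed site ends up in the dict, so the guide receives the sampled "obs" values while "mu" remains a latent it must propose.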

Usage Examples

Training and Inference

import torch
import pyro
import pyro.distributions as dist
import pyro.optim
from pyro.infer import CSIS, EmpiricalMarginal

def model(observations={"obs": torch.zeros(10)}):
    # "obs" always receives a value, so the site is observed. During training,
    # CSIS unconditions the model and samples this site from the joint instead.
    mu = pyro.sample("mu", dist.Normal(0.0, 10.0))
    sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
    with pyro.plate("data", 10):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=observations["obs"])

def guide(observations={"obs": torch.zeros(10)}):
    # Amortized guide: the proposal for mu depends on the observations
    obs_mean = observations["obs"].mean()
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("mu", dist.Normal(loc + obs_mean, 1.0))
    pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
optim = pyro.optim.Adam({"lr": 0.001})
csis = CSIS(model, guide, optim,
            num_inference_samples=100,
            training_batch_size=32)

# Training phase
for step in range(1000):
    loss = csis.step()
    if step % 100 == 0:
        print(f"Step {step}, loss: {loss:.4f}")

# Inference phase: run importance sampling
data = torch.randn(10) + 3.0  # observed data
csis.run(observations={"obs": data})

# Extract posterior marginal
marginal = EmpiricalMarginal(csis, sites="mu")
print("Posterior mean of mu:", marginal.mean)

Monitoring with Validation Loss

csis.set_validation_batch()
for step in range(1000):
    train_loss = csis.step()
    if step % 50 == 0:
        val_loss = csis.validation_loss()
        print(f"Step {step}: train={train_loss:.4f}, val={val_loss:.4f}")
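
The reason for pre-sampling a validation batch is that a fixed batch removes sampling noise from the comparison, so changes in the validation loss reflect changes in the guide alone. A minimal stdlib sketch of that idea, assuming a toy model x ~ N(0, 1), y ~ N(x, 1) and a guide q(x | y) = N(a * y, 1); all names here are hypothetical and independent of Pyro:

```python
import math
import random


def sample_joint(rng):
    """Draw (x, y) from the toy model x ~ N(0, 1), y ~ N(x, 1)."""
    x = rng.gauss(0.0, 1.0)
    y = rng.gauss(x, 1.0)
    return x, y


def neg_log_q(a, x, y):
    """-log N(x; a * y, 1): the per-sample loss for guide slope a."""
    return 0.5 * (x - a * y) ** 2 + 0.5 * math.log(2.0 * math.pi)


def mean_loss(a, batch):
    """Average loss of the guide with slope a over a batch of joint samples."""
    return sum(neg_log_q(a, x, y) for x, y in batch) / len(batch)


rng = random.Random(0)
# Sampled once and reused for every evaluation.
validation_batch = [sample_joint(rng) for _ in range(1000)]

loss_untrained = mean_loss(0.0, validation_batch)  # guide ignores y
loss_trained = mean_loss(0.5, validation_batch)    # slope of the exact posterior mean
print(f"untrained: {loss_untrained:.3f}, trained: {loss_trained:.3f}")
```

Evaluating the same slope twice on validation_batch gives the same number, whereas a loss computed on fresh samples each time would fluctuate even if the guide did not change.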
