Implementation:Pyro ppl Pyro CSIS
Overview
The csis module (`pyro/infer/csis.py`) implements Compiled Sequential Importance Sampling (CSIS), an inference compilation technique that trains a guide (proposal distribution) to minimize the KL divergence from the model posterior to the guide: KL(model posterior || guide). This contrasts with standard variational inference, which minimizes KL(guide || model posterior).
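The link between this KL objective and an expected negative log-probability loss can be made explicit. The derivation below is a sketch under stated assumptions: x denotes the latent variables, y the observations, p the model, and q_φ(x | y) the guide, where the parameter symbol φ is introduced here only for illustration. Averaging the divergence over observations drawn from the model gives:

```latex
\begin{aligned}
\mathbb{E}_{p(y)}\big[\mathrm{KL}\big(p(x \mid y)\,\|\,q_\phi(x \mid y)\big)\big]
&= \mathbb{E}_{p(x,y)}\big[\log p(x \mid y) - \log q_\phi(x \mid y)\big] \\
&= \mathbb{E}_{p(x,y)}\big[-\log q_\phi(x \mid y)\big] + \text{const}
\end{aligned}
```

The constant does not depend on φ, so only the expectation of -log q needs to be estimated, and that requires nothing more than samples from the model's joint distribution.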
CSIS works in two phases:
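- Training phase -- The guide is trained by sampling from the model's joint distribution (without conditioning on observations) and then evaluating the guide's log-probability of the sampled latents given the sampled observations. The loss is the negative expected log-probability of the guide under the model: -E_{p(x,y)} [log q(x,y)], where x denotes the latent variables, y the observations, and q(x,y) the guide's density for x when it is run with observations y.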
- Inference phase -- After training, importance sampling is performed using the compiled guide as the proposal distribution.
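The training phase can be illustrated without Pyro on a toy problem. The sketch below is not the CSIS implementation: it assumes a hypothetical conjugate model (mu ~ N(0, 1), y ~ N(mu, 1)) and a hypothetical linear guide q(mu | y) = N(a * y + b, 1), and it fits a and b by plain SGD on -log q using samples from the joint. The names `sample_joint` and `train_guide` are made up for this example.

```python
import random


def sample_joint(rng):
    """Draw (mu, y) from the toy joint: mu ~ N(0, 1), y ~ N(mu, 1)."""
    mu = rng.gauss(0.0, 1.0)
    y = rng.gauss(mu, 1.0)
    return mu, y


def train_guide(num_steps=1000, batch_size=32, lr=0.02, seed=0):
    """Fit q(mu | y) = N(a * y + b, 1) by SGD on -log q over joint samples."""
    rng = random.Random(seed)
    a, b = 0.0, 0.0
    for _ in range(num_steps):
        grad_a = 0.0
        grad_b = 0.0
        for _ in range(batch_size):
            mu, y = sample_joint(rng)  # no conditioning on real data
            resid = mu - (a * y + b)
            # -log q = 0.5 * resid**2 + const, so its gradients are:
            grad_a += -resid * y
            grad_b += -resid
        a -= lr * grad_a / batch_size
        b -= lr * grad_b / batch_size
    return a, b


a, b = train_guide()
print(f"learned guide mean: {a:.3f} * y + {b:.3f}")
```

For this toy model the exact posterior mean is y / 2, so the learned slope should settle near 0.5 and the intercept near 0. The trained guide would then serve as the proposal in the inference phase.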
The CSIS class extends Importance (which extends TracePosterior), inheriting the ability to collect weighted traces and compute marginal distributions. CSIS is particularly useful for amortized inference, where the guide is trained once and then used for rapid inference on new observations.
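How weighted samples turn into a posterior summary can be sketched in plain Python. The example below is an illustration rather than Pyro's code: it assumes a toy model (mu ~ N(0, 1), y ~ N(mu, 1)) and a hypothetical already-trained guide q(mu | y) = N(0.5 * y, 1), and the helper names are invented for this page. It computes a self-normalized importance sampling estimate of the posterior mean together with the effective sample size.

```python
import math
import random


def log_normal_pdf(x, loc, scale):
    """Log density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def importance_posterior_mean(y_obs, num_samples=2000, seed=0):
    """Self-normalized importance sampling estimate of E[mu | y_obs]."""
    rng = random.Random(seed)
    samples, log_weights = [], []
    for _ in range(num_samples):
        # Proposal: hypothetical compiled guide q(mu | y) = N(0.5 * y, 1).
        mu = rng.gauss(0.5 * y_obs, 1.0)
        log_p = log_normal_pdf(mu, 0.0, 1.0) + log_normal_pdf(y_obs, mu, 1.0)
        log_q = log_normal_pdf(mu, 0.5 * y_obs, 1.0)
        samples.append(mu)
        log_weights.append(log_p - log_q)
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(log_weights)
    weights = [math.exp(lw - m) for lw in log_weights]
    total = sum(weights)
    mean = sum(w * s for w, s in zip(weights, samples)) / total
    ess = total ** 2 / sum(w * w for w in weights)
    return mean, ess


mean, ess = importance_posterior_mean(y_obs=1.0)
print(f"posterior mean estimate: {mean:.3f}, effective sample size: {ess:.0f}")
```

The closer the guide is to the true posterior, the more uniform the weights become and the closer the effective sample size gets to the number of samples drawn, which is the practical payoff of the training phase.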
The reference paper is "Inference Compilation and Universal Probabilistic Programming" (Le et al., 2017).
Code Reference
File: `pyro/infer/csis.py`
Key Classes
| Class | Parent | Description |
|---|---|---|
| `CSIS` | `Importance` | Compiled Sequential Importance Sampling. Provides a training loop and importance sampling inference. |
CSIS Methods
| Method | Description |
|---|---|
| `__init__(model, guide, optim, num_inference_samples=10, training_batch_size=10, validation_batch_size=20)` | Initialize CSIS with model, guide, optimizer, and batch sizes. |
| `step(*args, **kwargs)` | Take a single gradient step on the CSIS loss. Returns an estimate of the loss as a float. |
| `loss_and_grads(grads, batch=None, *args, **kwargs)` | Compute the loss (and optionally gradients). Can use a provided batch of traces or generate fresh samples. |
| `validation_loss(*args, **kwargs)` | Compute the loss on a held-out validation batch. Samples the validation batch on the first call if none has been set. |
| `set_validation_batch(*args, **kwargs)` | Pre-sample a batch of model traces for validation loss computation. |
| `_get_matched_trace(model_trace, *args, **kwargs)` | Given a model trace, run the guide with observations from that trace to get a matched guide trace. |
| `_sample_from_joint(*args, **kwargs)` | Sample from the unconditioned model to get a joint sample trace. |
| `_differentiable_loss_particle(guide_trace)` | Compute the per-particle loss: -guide_trace.log_prob_sum(). |
I/O Contract
Constructor
Inputs:
- `model` -- Probabilistic model. Must accept a keyword argument `observations` through which observed values are passed as a dict.
- `guide` -- Guide function for the approximate posterior. Must also accept the `observations` keyword argument.
- `optim` -- A Pyro optimizer.
- `num_inference_samples` -- Number of importance-weighted samples during inference (default 10).
- `training_batch_size` -- Number of samples per training gradient step (default 10).
- `validation_batch_size` -- Number of samples for validation loss (default 20).
step
Inputs:
- `*args`, `**kwargs` -- Passed to the model and guide.
Output:
- `float` -- Estimate of the training loss.
loss_and_grads
Inputs:
- `grads` -- Whether to also compute gradients by backpropagating the loss.
- `batch` -- Optional pre-sampled batch of model traces. If `None`, fresh samples are drawn.
- `*args`, `**kwargs` -- Passed to the model and guide.
Output:
- `float` -- Estimate of the loss: E_{p(x,y)} [-log q(x,y)].
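The batch handling in this contract can be mirrored in a few lines of plain Python. The sketch below is a simplified stand-in, not Pyro's implementation: `estimate_loss`, `sample_joint`, and `neg_log_q` are hypothetical names, the toy joint and guide are assumptions, and gradient computation is left out entirely. It shows why a pre-sampled batch gives a repeatable estimate while fresh batches give a noisy one.

```python
import random


def estimate_loss(sample_joint, neg_log_q, batch=None, batch_size=10):
    """Average -log q over a batch; draw a fresh batch when none is given."""
    if batch is None:
        batch = [sample_joint() for _ in range(batch_size)]
    return sum(neg_log_q(x, y) for x, y in batch) / len(batch)


rng = random.Random(0)


def sample_joint():
    # Toy joint: x ~ N(0, 1), y ~ N(x, 1).
    x = rng.gauss(0.0, 1.0)
    return x, rng.gauss(x, 1.0)


def neg_log_q(x, y):
    # Toy guide q(x | y) = N(0.5 * y, 1), up to an additive constant.
    return 0.5 * (x - 0.5 * y) ** 2


# A fixed, pre-sampled batch yields the same estimate on every call.
fixed_batch = [sample_joint() for _ in range(20)]
val_1 = estimate_loss(sample_joint, neg_log_q, batch=fixed_batch)
val_2 = estimate_loss(sample_joint, neg_log_q, batch=fixed_batch)

# Fresh batches yield an estimate that varies from call to call.
train_1 = estimate_loss(sample_joint, neg_log_q)
train_2 = estimate_loss(sample_joint, neg_log_q)
print(val_1 == val_2, train_1 != train_2)
```

Reusing one batch removes sampling noise from the comparison across training steps, which is the reason a held-out validation batch is sampled once and then kept.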
Observation Passing Convention
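Both the model and guide must accept an `observations` keyword argument. During training, `_get_matched_trace` populates this with sampled values from the model trace for sites that were originally observed (marked with `was_observed` in their infer dict). Because only observed sites receive this mark, the model must supply a value (for example a placeholder default) for each observed site when it is run during training.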
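The collection step described above can be sketched with plain dictionaries standing in for trace nodes. This is an illustration, not Pyro's code: `collect_observations` is a hypothetical helper, the node layout is a simplified stand-in for a real trace, and the key name `was_observed` is assumed to be the flag that marks originally observed sites.

```python
def collect_observations(model_trace_nodes):
    """Build the observations dict from nodes flagged as originally observed."""
    observations = {}
    for name, node in model_trace_nodes.items():
        if "was_observed" in node.get("infer", {}):
            observations[name] = node["value"]
    return observations


# Stand-in for a trace sampled from the unconditioned model.
nodes = {
    "mu": {"value": 1.3, "infer": {}},
    "sigma": {"value": 0.7, "infer": {}},
    "obs": {"value": [1.1, 2.0, 0.4], "infer": {"was_observed": True}},
}

observations = collect_observations(nodes)
print(observations)  # {'obs': [1.1, 2.0, 0.4]}
```

Only the site that carried the flag ends up in the dict; the latent sites `mu` and `sigma` are left for the guide to propose.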
Usage Examples
Training and Inference
```python
import torch
import pyro
import pyro.distributions as dist
import pyro.optim
from pyro.infer import CSIS, EmpiricalMarginal

# Placeholder values: they make "obs" an observed site, which CSIS needs in
# order to resample it during training and pass it on to the guide.
DEFAULT_OBS = {"obs": torch.zeros(10)}


def model(observations=DEFAULT_OBS):
    mu = pyro.sample("mu", dist.Normal(0.0, 10.0))
    sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
    with pyro.plate("data", 10):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=observations["obs"])


def guide(observations=DEFAULT_OBS):
    # Amortized guide that conditions on the observations
    obs_mean = observations["obs"].mean()
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("mu", dist.Normal(loc + obs_mean, 1.0))
    pyro.sample("sigma", dist.LogNormal(0.0, 1.0))


optim = pyro.optim.Adam({"lr": 0.001})
csis = CSIS(model, guide, optim,
            num_inference_samples=100,
            training_batch_size=32)

# Training phase
for step in range(1000):
    loss = csis.step()
    if step % 100 == 0:
        print(f"Step {step}, loss: {loss:.4f}")

# Inference phase: run importance sampling
data = torch.randn(10) + 3.0  # observed data
csis.run(observations={"obs": data})

# Extract posterior marginal
marginal = EmpiricalMarginal(csis, sites="mu")
print("Posterior mean of mu:", marginal.mean)
```
Monitoring with Validation Loss
```python
# Pre-sample a fixed validation batch so losses are comparable across steps
csis.set_validation_batch()

for step in range(1000):
    train_loss = csis.step()
    if step % 50 == 0:
        val_loss = csis.validation_loss()
        print(f"Step {step}: train={train_loss:.4f}, val={val_loss:.4f}")
```
Related Pages
- Pyro_ppl_Pyro_Importance -- Parent class providing importance sampling infrastructure
- Pyro_ppl_Pyro_Abstract_Infer -- TracePosterior base class that CSIS ultimately inherits from
- Pyro_ppl_Pyro_ReweightedWakeSleep -- Another inference method that trains guides using model samples (sleep phase)
- Pyro_ppl_Pyro_Infer_Utilities -- Utility functions used by CSIS