Implementation:Pyro ppl Pyro Resampler
Overview
The resampler module (Template:Code) provides the Resampler class, a computational cache designed for interactive tuning of generative models during prior predictive checks, an early step in a Bayesian workflow.
The core idea is that running a simulation (model) can be expensive, but when slightly tweaking parameters of a prior distribution, most of the previous samples can be reused via importance resampling. The Resampler avoids re-running the simulation by:
- Drawing a large batch of samples from a diffuse guide distribution once during initialization.
- Optionally extending those samples through a simulator to produce downstream quantities.
- On each call to `sample`, computing importance weights between the new model and the original guide, then resampling from the cached samples in proportion to those weights.
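The three steps above can be sketched without Pyro, using only the Python standard library. This is an illustrative toy under stated assumptions, not the library's implementation: the names `normal_logpdf`, `cached`, `guide_logp`, and `resample` are hypothetical, a single scalar latent stands in for a full model, and the simulator step is omitted.

```python
import math
import random

def normal_logpdf(x, loc, scale):
    """Log-density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

rng = random.Random(0)

# Step 1: draw a large batch once from a diffuse guide, Normal(0, 100),
# and cache each draw's log-density under that guide.
GUIDE_LOC, GUIDE_SCALE = 0.0, 100.0
cached = [rng.gauss(GUIDE_LOC, GUIDE_SCALE) for _ in range(20000)]
guide_logp = [normal_logpdf(x, GUIDE_LOC, GUIDE_SCALE) for x in cached]

def resample(model_loc, model_scale, num_samples):
    """Step 3: reweight the cached draws toward Normal(model_loc, model_scale)."""
    # Log importance weight = log p_model(x) - log p_guide(x).
    log_w = [normal_logpdf(x, model_loc, model_scale) - g
             for x, g in zip(cached, guide_logp)]
    # Subtract the max before exponentiating for numerical stability.
    m = max(log_w)
    weights = [math.exp(lw - m) for lw in log_w]
    # Multinomial resampling with replacement, proportional to the weights.
    return rng.choices(cached, weights=weights, k=num_samples)

# Tweaking the prior only recomputes weights; no new draws are needed.
narrow = resample(5.0, 2.0, num_samples=500)
shifted = resample(-20.0, 5.0, num_samples=500)
```

Both calls reuse the same 20000 cached draws; only the weights change, which is why a parameter tweak is cheap compared with re-running a simulation.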
The Resampler supports stable categorical sampling via the Gumbel-max trick, which provides piecewise-constant resampling suitable for visualization (the same particles are returned when parameters change only slightly). This can be toggled off for standard Monte Carlo integration.
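To illustrate why the Gumbel-max trick gives piecewise-constant results, here is a minimal standard-library sketch. It is an assumption-laden toy rather than the Resampler's code: `make_gumbels` and `stable_indices` are hypothetical helper names, and the four logits are arbitrary example values.

```python
import math
import random

def make_gumbels(num_samples, num_particles, seed=0):
    """Draw a fixed table of standard Gumbel noise, -log(-log(U))."""
    rng = random.Random(seed)
    tiny = 1e-300  # guard against log(0)
    return [[-math.log(-math.log(max(rng.random(), tiny)))
             for _ in range(num_particles)]
            for _ in range(num_samples)]

def stable_indices(logits, gumbels):
    """Gumbel-max trick: argmax_i(logits[i] + g[i]) is a categorical draw
    with probabilities softmax(logits). Reusing the same noise table makes
    the result a deterministic, piecewise-constant function of the logits."""
    return [max(range(len(logits)), key=lambda i: logits[i] + row[i])
            for row in gumbels]

logits = [0.0, 1.0, 2.0, 0.5]
gumbels = make_gumbels(num_samples=2000, num_particles=len(logits))

first = stable_indices(logits, gumbels)
again = stable_indices(logits, gumbels)  # same noise, same logits
nudged = stable_indices([0.0, 1.0, 2.01, 0.5], gumbels)  # tiny change

# Count how many selected particles differ after the small nudge.
changed = sum(a != b for a, b in zip(first, nudged))
```

Because the noise table is cached, repeated calls with identical weights return identical particles, and a small change in the weights moves only the few draws that sit near a decision boundary. Fresh multinomial draws would instead reshuffle every particle on each call.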
Code Reference
File: Template:Code
Key Classes
| Class | Parent | Description |
|---|---|---|
| `Resampler` | -- | Interactive resampler for prior predictive checks with importance resampling. |
Resampler Methods
| Method | Description |
|---|---|
| `__init__` | Initialize by drawing samples from the guide (and optionally the simulator). |
| `sample` | Draw at most `num_samples` samples from the cached samples by importance resampling against the model. |
Internal Helpers
| Function | Description |
|---|---|
| Template:Code | Computes vectorized log_prob_sum batched over the leftmost dimension. |
| Template:Code | Automatically detects max plate nesting by tracing the model. |
I/O Contract
Constructor
Inputs:
- `guide` -- A Pyro model with no required arguments. Must be diffuse, covering more space than any model later passed to `sample`. Must be vectorizable via `pyro.plate`.
- `simulator` -- An optional larger model with a superset of the guide's latent variables. Used to extend guide samples with downstream quantities. Must be vectorizable.
- `num_guide_samples` -- Number of initial samples from the guide. Should be much larger than later `num_samples` requests.
- `max_plate_nesting` -- Maximum plate nesting. If absent, auto-detected.
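The requirements that the guide be diffuse and that the number of guide samples be much larger than the number requested later can be made concrete with the effective sample size (ESS) of the importance weights. The sketch below is a supplementary diagnostic under stated assumptions, not something this page documents the Resampler as computing: `effective_sample_size` and `ess_for_guide` are hypothetical helpers, and the Normal(50, 2) target is an arbitrary example.

```python
import math
import random

def normal_logpdf(x, loc, scale):
    """Log-density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

def effective_sample_size(log_weights):
    """Kish effective sample size, (sum w)^2 / sum(w^2), computed stably."""
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    return sum(w) ** 2 / sum(wi * wi for wi in w)

def ess_for_guide(guide_scale, num_guide_samples=10000, seed=0):
    """ESS when Normal(0, guide_scale) draws are reweighted to Normal(50, 2)."""
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, guide_scale) for _ in range(num_guide_samples)]
    log_w = [normal_logpdf(x, 50.0, 2.0) - normal_logpdf(x, 0.0, guide_scale)
             for x in xs]
    return effective_sample_size(log_w)

ess_diffuse = ess_for_guide(guide_scale=100.0)   # guide covers the model
ess_mismatched = ess_for_guide(guide_scale=1.0)  # guide never reaches the model
```

With the diffuse guide, only a small fraction of the 10000 cached draws carry meaningful weight, so requesting far fewer samples than were cached is what keeps resampled draws from being heavily duplicated. With the mismatched guide, nearly all the weight lands on a handful of extreme draws and the resampled set is essentially degenerate.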
sample
Inputs:
- `model` -- A model with the same latent variables as the guide. Must be vectorizable.
- `num_samples` -- Number of samples to draw.
- `stable` -- Whether to use Gumbel-max stable resampling (default True). Use True for visualization, False for Monte Carlo integration.
Output:
- Template:Code -- Dictionary mapping sample site names to tensors of shape Template:Code.
Internal State
- Template:Code -- Cached samples from the guide (and simulator), keyed by site name.
- Template:Code -- Log-probability of the cached samples under the original guide.
- Template:Code -- Cached Gumbel noise for stable resampling.
- Template:Code -- The plate dimension used for particle vectorization.
Usage Examples
Basic Prior Predictive Tuning
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Resampler

# A diffuse guide covering the parameter space
def guide():
    mu = pyro.sample("mu", dist.Normal(0, 100))
    sigma = pyro.sample("sigma", dist.LogNormal(0, 10))

# The model we want to tune
def model():
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.LogNormal(0, 1))

# Initialize resampler with many guide samples
resampler = Resampler(guide, num_guide_samples=10000)

# Quickly get model-weighted samples without re-running the simulation
samples = resampler.sample(model, num_samples=100)
print("mu samples:", samples["mu"])
print("sigma samples:", samples["sigma"])
With a Simulator
def guide():
    mu = pyro.sample("mu", dist.Normal(0, 100))

def simulator():
    mu = pyro.sample("mu", dist.Normal(0, 100))
    # Expensive downstream computation
    x = pyro.sample("x", dist.Normal(mu, 1))

def model():
    mu = pyro.sample("mu", dist.Normal(5, 2))

resampler = Resampler(guide, simulator, num_guide_samples=50000)
samples = resampler.sample(model, num_samples=200)

# samples contains both "mu" and "x"
print("x samples:", samples["x"])
Stable vs Unstable Resampling
# Stable resampling for interactive visualization
# (same samples returned for similar parameters)
vis_samples = resampler.sample(model, num_samples=50, stable=True)
# Unstable (multinomial) resampling for Monte Carlo estimates
mc_samples = resampler.sample(model, num_samples=1000, stable=False)
Related Pages
- Pyro_ppl_Pyro_Importance -- Importance sampling, the reweighting idea that Resampler builds on
- Pyro_ppl_Pyro_SMCFilter -- Sequential Monte Carlo with importance resampling for time series
- Pyro_ppl_Pyro_Abstract_Infer -- Base classes for posterior inference