

Implementation:Pyro ppl Pyro Resampler

From Leeroopedia


Overview

The resampler module (Template:Code) provides the Resampler class, a computational cache for interactively tuning generative models during prior predictive checks, an early step in the Bayesian workflow.

The core idea is that running a simulation (model) can be expensive, but when the parameters of a prior distribution are only tweaked slightly, most of the previous samples can be reused via importance resampling. The Resampler avoids re-running the simulation by:

  1. Drawing a large batch of samples from a diffuse guide distribution once during initialization.
  2. Optionally extending those samples through a simulator to produce downstream quantities.
  3. On each call to sample, computing importance weights of the new model relative to the original guide at every cached sample, then resampling from the cached samples in proportion to those weights.
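The three steps above can be sketched in plain Python for a one-dimensional toy problem with a Normal guide, a Normal model, and a made-up simulator. This is a hypothetical illustration, not Pyro's code: `ToyResampler`, `normal_logpdf`, and every parameter value are assumptions chosen for the example, and the real class works on traced, vectorized Pyro models. Each cached draw mu_i receives the log-weight log p_model(mu_i) - log q_guide(mu_i) and is resampled with probability proportional to the exponentiated weight.

```python
import math
import random


def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) evaluated at x (hypothetical helper)."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


class ToyResampler:
    """Hypothetical 1-D illustration of the Resampler's caching strategy."""

    def __init__(self, guide_loc, guide_scale, num_guide_samples,
                 simulator=None, seed=0):
        self.rng = random.Random(seed)
        # Step 1: draw a large batch from a diffuse Normal guide, exactly once,
        # and remember each draw's log-density under that guide.
        self.mu = [self.rng.gauss(guide_loc, guide_scale)
                   for _ in range(num_guide_samples)]
        self.old_logp = [normal_logpdf(m, guide_loc, guide_scale) for m in self.mu]
        # Step 2 (optional): push every draw through the "expensive" simulator once.
        self.x = ([simulator(m, self.rng) for m in self.mu]
                  if simulator is not None else None)

    def sample(self, model_loc, model_scale, num_samples):
        # Step 3: log importance weight = log p_model(mu_i) - log q_guide(mu_i).
        logw = [normal_logpdf(m, model_loc, model_scale) - lq
                for m, lq in zip(self.mu, self.old_logp)]
        top = max(logw)  # subtract the max before exponentiating to avoid overflow
        weights = [math.exp(lw - top) for lw in logw]
        # Resample cached indices in proportion to the weights; nothing is re-simulated.
        idx = self.rng.choices(range(len(self.mu)), weights=weights, k=num_samples)
        out = {"mu": [self.mu[i] for i in idx]}
        if self.x is not None:
            out["x"] = [self.x[i] for i in idx]  # same indices keep (mu, x) paired
        return out


cache = ToyResampler(0.0, 100.0, 50_000,
                     simulator=lambda mu, rng: rng.gauss(mu, 1.0))
draws = cache.sample(5.0, 2.0, num_samples=200)
print(sum(draws["mu"]) / len(draws["mu"]))  # should land near the model mean of 5
```

The guide and simulator run only in the constructor; every later call to `sample` just reweights and reindexes the cached lists, which is what makes repeated tuning cheap.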

The Resampler supports stable categorical sampling via the Gumbel-max trick, which yields piecewise-constant resampling suited to visualization: because the Gumbel noise is cached between calls, the same particles are returned when parameters change only slightly. Setting stable=False switches to ordinary multinomial resampling, which is appropriate for Monte Carlo integration.
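The stability described here comes from reusing one table of Gumbel noise across calls. Below is a pure-Python sketch, assuming a flat list of per-particle logits (log importance weights); `gumbel_noise` and `stable_indices` are hypothetical helpers, and Pyro's own version operates on tensors rather than nested lists. Because the noise is drawn once, a small change to the logits only changes a draw when it reorders the top two perturbed scores.

```python
import math
import random


def gumbel_noise(num_samples, num_particles, seed=0):
    """Hypothetical helper: a num_samples x num_particles table of Gumbel(0, 1) noise."""
    rng = random.Random(seed)
    tiny = 1e-300  # keeps log() away from zero
    return [[-math.log(-math.log(max(rng.random(), tiny)))
             for _ in range(num_particles)]
            for _ in range(num_samples)]


def stable_indices(logits, noise):
    """Gumbel-max trick: argmax_j (logits[j] + g[j]) is a draw from softmax(logits)."""
    return [max(range(len(logits)), key=lambda j: logits[j] + g[j]) for g in noise]


weights = [0.1, 0.2, 0.3, 0.4]
logits = [math.log(w) for w in weights]
noise = gumbel_noise(num_samples=1000, num_particles=len(logits))  # cached once

before = stable_indices(logits, noise)
# Nudge the logits slightly, as if a prior parameter had been tweaked.
after = stable_indices([l + 0.01 * j for j, l in enumerate(logits)], noise)
unchanged = sum(a == b for a, b in zip(before, after)) / len(before)
print(f"fraction of draws unchanged: {unchanged:.3f}")
```

With fresh noise on every call, each draw would instead be an independent categorical sample; for these weights two independent draws differ with probability 1 - sum(p_i^2) = 0.7, so most particles would be reshuffled by even a tiny parameter change.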

Code Reference

File: Template:Code

Key Classes

Class | Parent | Description
Resampler | -- | Interactive resampler for prior predictive checks based on importance resampling.

Resampler Methods

Method | Description
Template:Code | Initialize by drawing samples from the guide (and, optionally, extending them through the simulator).
sample | Draw at most num_samples samples from the cache by importance resampling against the model.

Internal Helpers

Function | Description
Template:Code | Computes the vectorized log_prob_sum, batched over the leftmost (particle) dimension.
Template:Code | Automatically detects the maximum plate nesting by tracing the model.
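The two helpers in the table above amount to simple bookkeeping. In the sketch below, which uses hypothetical names and plain lists instead of Pyro traces, the first function sums every site's log-probs into one total per particle, and the second finds the deepest plate so that the particle dimension can sit to its left. The rule `particle_dim = -1 - max_plate_nesting` is an assumption about how the particle plate is placed; the table does not state it.

```python
def batched_log_prob_sum(site_log_probs, batch_size):
    """Hypothetical helper: total log-probability per particle.

    site_log_probs maps a site name to a list with one entry per particle
    (the leftmost, batched dimension); each entry is a flat list of the
    element-wise log-probs that site contributes for that particle.
    """
    totals = [0.0] * batch_size
    for name, per_particle in site_log_probs.items():
        if len(per_particle) != batch_size:
            raise ValueError(f"site {name!r} is not batched over {batch_size} particles")
        for i, values in enumerate(per_particle):
            totals[i] += sum(values)  # collapse all inner (plate/event) elements
    return totals


def guess_max_plate_nesting(plate_dims):
    """Hypothetical helper: deepest nesting among plate dims such as -1, -2 (0 if none)."""
    return max([0] + [-d for d in plate_dims])


def particle_dim(max_plate_nesting):
    """Assumed placement: the particle plate sits just left of every model plate."""
    return -1 - max_plate_nesting


lp = {"mu": [[-1.0], [-2.0]], "obs": [[-0.5, -0.5], [0.0, -1.0]]}
print(batched_log_prob_sum(lp, batch_size=2))  # [-2.0, -3.0]
print(particle_dim(guess_max_plate_nesting([-1, -2])))  # -3
```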

I/O Contract

Constructor

Inputs:

  • guide -- A Pyro model that takes no required arguments. It must be diffuse, covering more space than the model later passed to sample, and must be vectorizable via Template:Code.
  • simulator -- An optional larger model whose latent variables are a superset of the guide's. It is used to extend the guide samples with downstream quantities and must be vectorizable.
  • num_guide_samples -- Number of initial samples to draw from the guide. This should be much larger than the num_samples requested in later calls to sample.
  • Template:Code -- Maximum plate nesting. If absent, it is auto-detected by tracing the model.
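Why num_guide_samples should dwarf num_samples can be made concrete with an effective sample size. The sketch below uses Kish's formula, a standard importance-sampling diagnostic; the source does not say the Resampler computes it, and `effective_sample_size` and `normal_logpdf` are hypothetical helpers. With the guide Normal(0, 100) and model Normal(5, 2) from the simulator example under Usage Examples, 50,000 guide draws are worth only about 1,400 independent model draws (roughly N divided by 35 for these two Normals).

```python
import math
import random


def effective_sample_size(log_weights):
    """Hypothetical helper: Kish's ESS (sum w)^2 / sum(w^2) from log-weights."""
    top = max(log_weights)
    w = [math.exp(lw - top) for lw in log_weights]  # rescaling cancels in the ratio
    return sum(w) ** 2 / sum(x * x for x in w)


def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) at x (hypothetical helper)."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


rng = random.Random(0)
num_guide_samples = 50_000
# Diffuse guide Normal(0, 100); tuned model Normal(5, 2).
mu = [rng.gauss(0.0, 100.0) for _ in range(num_guide_samples)]
logw = [normal_logpdf(m, 5.0, 2.0) - normal_logpdf(m, 0.0, 100.0) for m in mu]
ess = effective_sample_size(logw)
print(f"ESS ~ {ess:.0f} of {num_guide_samples} guide samples")
```

When the effective sample size approaches the num_samples being requested, the resampled draws start repeating the same cached particles, so the guide should be widened or num_guide_samples increased.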

sample

Inputs:

  • model -- A model with the same latent variables as the guide. Must be vectorizable.
  • num_samples -- Number of samples to draw.
  • stable -- Whether to use Gumbel-max stable resampling (default True). Use True for visualization and False for Monte Carlo integration.

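Output:

  • A dictionary of resampled values keyed by site name, containing the guide's sites and, if a simulator was supplied, its downstream sites.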

Internal State

  • Template:Code -- Cached samples from the guide (and simulator), keyed by site name.
  • Template:Code -- Log-probability of the cached samples under the original guide.
  • Template:Code -- Cached Gumbel noise for stable resampling.
  • Template:Code -- The plate dimension used for particle vectorization.
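The cached Gumbel noise is also why stable=True is a poor fit for Monte Carlo integration: repeated calls with the same model replay the same draws, so averaging more calls does not reduce error. The sketch below assumes a flat list of per-particle log importance weights; `multinomial_indices` and `self_normalized_mean` are hypothetical helpers, and the self-normalized estimator is shown only as a reference value, not as something the Resampler is documented to return.

```python
import math
import random


def multinomial_indices(logits, num_samples, rng):
    """Hypothetical helper: fresh categorical draws with replacement, like stable=False."""
    top = max(logits)
    weights = [math.exp(l - top) for l in logits]
    return rng.choices(range(len(logits)), weights=weights, k=num_samples)


def self_normalized_mean(values, logits):
    """Hypothetical helper: importance-weighted mean sum(w * v) / sum(w), no resampling."""
    top = max(logits)
    w = [math.exp(l - top) for l in logits]
    return sum(wi * v for wi, v in zip(w, values)) / sum(w)


values = [0.0, 1.0, 2.0, 3.0]  # a quantity stored with each cached particle
logits = [math.log(p) for p in (0.1, 0.2, 0.3, 0.4)]  # log importance weights
rng = random.Random(0)

exact = self_normalized_mean(values, logits)  # 0.2 + 0.6 + 1.2 = 2.0
first = multinomial_indices(logits, 20_000, rng)
second = multinomial_indices(logits, 20_000, rng)  # a fresh, independent batch
estimate = sum(values[i] for i in first + second) / (len(first) + len(second))
print(exact, estimate)
```

Fresh multinomial batches are independent, so pooling them shrinks the Monte Carlo error roughly as one over the square root of the number of draws, up to the limit set by the fixed pool of cached guide samples.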

Usage Examples

Basic Prior Predictive Tuning

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Resampler

# A diffuse guide covering the parameter space
def guide():
    mu = pyro.sample("mu", dist.Normal(0, 100))
    sigma = pyro.sample("sigma", dist.LogNormal(0, 10))

# The model we want to tune
def model():
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.LogNormal(0, 1))

# Initialize resampler with many guide samples
resampler = Resampler(guide, num_guide_samples=10000)

# Draw model-weighted samples without re-running the guide
samples = resampler.sample(model, num_samples=100)
print("mu samples:", samples["mu"])
print("sigma samples:", samples["sigma"])

With a Simulator

def guide():
    mu = pyro.sample("mu", dist.Normal(0, 100))

def simulator():
    mu = pyro.sample("mu", dist.Normal(0, 100))
    # Stand-in for an expensive downstream computation
    x = pyro.sample("x", dist.Normal(mu, 1))

def model():
    mu = pyro.sample("mu", dist.Normal(5, 2))

resampler = Resampler(guide, simulator, num_guide_samples=50000)
samples = resampler.sample(model, num_samples=200)
# samples contains both "mu" and "x"
print("x samples:", samples["x"])

Stable vs Unstable Resampling

# Stable resampling for interactive visualization
# (same samples returned for similar parameters)
vis_samples = resampler.sample(model, num_samples=50, stable=True)

# Unstable (multinomial) resampling for Monte Carlo estimates
mc_samples = resampler.sample(model, num_samples=1000, stable=False)
