Implementation:Pyro ppl Pyro Importance
Overview
The importance module (Template:Code) implements importance sampling for posterior inference in Pyro. It provides the Importance class, which performs basic importance sampling by drawing and weighting one trace at a time. It also provides the LogWeightsMixin, which computes summary statistics of the importance weights. The vectorized_importance_weights function computes the weights for many particles at once. The psis_diagnostic function evaluates guide quality using Pareto Smoothed Importance Sampling (PSIS).
Importance extends TracePosterior and mixes in LogWeightsMixin. It performs posterior inference by drawing samples from the guide, which serves as the proposal distribution. Each sample is weighted by the ratio of model to guide probabilities, so its log-weight is log p(x, z) - log q(z). If no guide is provided, Importance proposes from the model's prior.
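The weighting step can be illustrated without Pyro. The plain-Python sketch below uses the same conjugate model as the basic usage example: prior z ~ Normal(0, 1), likelihood x ~ Normal(z, 0.5), and observed x = 2.0. It also uses the same untrained guide, q(z) = Normal(0, 1). The helper names are hypothetical, and the code mirrors the estimator rather than Pyro's internals. Because the model is conjugate, its result can be checked against the exact posterior mean of 1.6.

```python
import math
import random


def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) evaluated at x."""
    return -0.5 * math.log(2.0 * math.pi * scale ** 2) - (x - loc) ** 2 / (2.0 * scale ** 2)


def importance_sample(x_obs, num_samples, q_loc=0.0, q_scale=1.0, seed=0):
    """Propose z ~ q and return the samples with log-weights log p(x, z) - log q(z)."""
    rng = random.Random(seed)
    zs, log_weights = [], []
    for _ in range(num_samples):
        z = rng.gauss(q_loc, q_scale)  # draw from the proposal (guide)
        # log p(x, z) = log p(z) + log p(x | z)
        log_p = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x_obs, z, 0.5)
        log_q = normal_logpdf(z, q_loc, q_scale)  # log q(z)
        zs.append(z)
        log_weights.append(log_p - log_q)
    return zs, log_weights


def self_normalized_mean(zs, log_weights):
    """Posterior mean estimate: sum_i w_i * z_i / sum_i w_i."""
    m = max(log_weights)  # subtract the max before exp() for numerical stability
    ws = [math.exp(lw - m) for lw in log_weights]
    return sum(w * z for w, z in zip(ws, zs)) / sum(ws)


zs, log_weights = importance_sample(x_obs=2.0, num_samples=20000)
estimate = self_normalized_mean(zs, log_weights)
print(f"estimated posterior mean: {estimate:.3f} (exact: 1.6)")
```

In Pyro, Importance._traces performs the equivalent step on whole traces. It yields each model trace together with its log p - log q.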
LogWeightsMixin adds convenience methods for computing the log normalizing constant, normalized importance weights, and Effective Sample Size (ESS) from a list of log-weights.
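vectorized_importance_weights is an efficient vectorized alternative for models with static structure, meaning the same sample sites appear on every execution. It wraps the model and guide in an extra, outermost Pyro plate over particles, so the importance weights for all particles are computed in a single forward pass. That particle plate needs a dimension beyond the model's own plates, which is why a bound on plate nesting must be supplied.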
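psis_diagnostic computes the Pareto tail index k, which is the shape parameter of a generalized Pareto distribution fitted to the largest importance weights. A larger k means a heavier right tail of weights. The index assesses how well the guide approximates the posterior:
- k < 0.5 indicates a good fit.
- 0.5 <= k <= 0.7 is acceptable: suboptimal, but often still usable.
- k > 0.7 suggests a poor fit.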
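The plain-Python sketch below gives intuition for what k measures. It estimates the tail index of a set of log-weights and maps the estimate to the categories above. It is deliberately simplified, because PSIS proper fits a generalized Pareto distribution to the largest weights. This sketch swaps in the cruder Hill estimator, which assumes a heavy tail (k > 0). It applies that estimator to a tail whose length follows the min(0.2 S, 3 sqrt(S)) rule from the PSIS literature. The function names are hypothetical; for real model/guide pairs, use psis_diagnostic.

```python
import math
import random


def tail_size(num_weights):
    """Number of largest weights treated as the tail: ceil(min(0.2 * S, 3 * sqrt(S)))."""
    return int(math.ceil(min(0.2 * num_weights, 3.0 * math.sqrt(num_weights))))


def hill_tail_index(log_weights):
    """Hill estimate of the tail index k from the M largest log-weights.

    A stand-in for the generalized Pareto fit used by PSIS; valid only for k > 0.
    """
    lw = sorted(log_weights, reverse=True)
    m = tail_size(len(lw))
    threshold = lw[m]  # the (M+1)-th largest log-weight
    return sum(x - threshold for x in lw[:m]) / m


def classify_k(k):
    """Map k to the fit categories listed in the overview."""
    if k < 0.5:
        return "good"
    if k <= 0.7:
        return "acceptable"
    return "poor"


rng = random.Random(0)
# Pareto(alpha=2) weights have tail index k = 1 / alpha = 0.5
log_w = [math.log(rng.paretovariate(2.0)) for _ in range(20000)]
k_hat = hill_tail_index(log_w)
print(f"k_hat = {k_hat:.3f} ({classify_k(k_hat)})")
```

The synthetic weights sit at the boundary between "good" and "acceptable", so the printed category can fall on either side depending on sampling noise.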
Code Reference
File: Template:Code
Key Classes
| Class | Parent(s) | Description |
|---|---|---|
| Template:Code | Template:Code, Template:Code | Basic importance sampling, one trace at a time, using the guide as the proposal distribution. |
| Template:Code | -- | Mixin providing analytics from a Template:Code attribute. |
Importance Methods
| Method | Description |
|---|---|
| Template:Code | Initialize with model, optional guide (defaults to prior), and number of samples (default 10). |
| Template:Code | Generator yielding Template:Code pairs for each sample. |
LogWeightsMixin Methods
| Method | Description |
|---|---|
| Template:Code | Returns the log of the mean unnormalized importance weight, estimating log Z. |
| Template:Code | Returns normalized importance weights. If Template:Code, returns log-normalized weights. |
| Template:Code | Returns the Effective Sample Size, 1 / sum_i w_i^2, computed from the normalized weights w_i. |
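The three quantities in the table reduce to short formulas over the log-weights. The plain-Python sketch below follows the descriptions above, using hypothetical helper names and the standard library in place of torch:
- log Z is estimated as the log of the mean weight.
- Normalization divides each weight by the sum of all weights.
- ESS is 1 / sum_i w_i^2 over the normalized weights.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_normalizer(log_weights):
    """Log of the mean unnormalized weight, an estimate of log Z."""
    return logsumexp(log_weights) - math.log(len(log_weights))


def normalized_weights(log_weights, log_scale=False):
    """Weights divided by their sum; optionally returned as log-weights."""
    total = logsumexp(log_weights)
    log_normed = [lw - total for lw in log_weights]
    return log_normed if log_scale else [math.exp(lw) for lw in log_normed]


def effective_sample_size(log_weights):
    """ESS = 1 / sum_i w_i^2 over the normalized weights w_i."""
    return 1.0 / sum(w * w for w in normalized_weights(log_weights))


# two samples with unnormalized weights 1 and 3
lw = [math.log(1.0), math.log(3.0)]
print(log_normalizer(lw))         # ≈ 0.693, i.e. log(2), the log of the mean weight
print(normalized_weights(lw))     # ≈ [0.25, 0.75]
print(effective_sample_size(lw))  # ≈ 1.6
```

With equal log-weights, the ESS equals the number of samples. As a single weight comes to dominate, the ESS falls toward 1.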
Standalone Functions
| Function | Description |
|---|---|
| Template:Code | Vectorized computation of importance weights for static-structure models. Keyword args: Template:Code, Template:Code, Template:Code. |
| Template:Code | Computes the Pareto tail index k diagnostic for a model/guide pair. Keyword args: Template:Code, Template:Code, Template:Code. |
I/O Contract
Importance Constructor
Inputs:
- Template:Code -- Probabilistic model.
- Template:Code -- Proposal distribution. If Template:Code, uses the model's prior (by blocking observe sites).
- Template:Code -- Number of importance samples (defaults to 10, with a warning, if not given).
Importance._traces
Inputs:
- Template:Code -- Passed to model and guide.
Yields:
- Template:Code -- Each model trace paired with its log importance weight, log p - log q: the model trace's total log-probability minus the guide trace's.
vectorized_importance_weights
Required Keyword Args:
- Template:Code -- Must be provided; raises ValueError otherwise.
Optional Keyword Args:
- Template:Code -- Number of samples (default 1).
- Template:Code -- Whether to normalize weights (default False).
Output:
- Tuple of Template:Code where Template:Code is a Template:Code-shaped tensor.
psis_diagnostic
Optional Keyword Args:
- Template:Code -- Total number of model/guide evaluations (default 1000).
- Template:Code -- Max batch size (default Template:Code). Template:Code must be divisible by this.
- Template:Code -- Max nested plate depth (default 7).
Output:
- Template:Code -- The Pareto tail index k.
Usage Examples
Basic Importance Sampling
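```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Importance, EmpiricalMarginal


def model():
    z = pyro.sample("z", dist.Normal(0, 1))  # prior p(z)
    # likelihood p(x | z) with x observed at 2.0
    pyro.sample("x", dist.Normal(z, 0.5), obs=torch.tensor(2.0))


def guide():
    # proposal q(z); "loc" is a learnable parameter, left untrained here
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("z", dist.Normal(loc, 1.0))


posterior = Importance(model, guide, num_samples=1000)
posterior.run()

# EmpiricalMarginal weights each sampled "z" by its importance weight.
# For this conjugate model the exact posterior is Normal(1.6, sqrt(0.2))
# and log Z = log Normal(2.0; 0, sqrt(1.25)) ≈ -2.63.
marginal = EmpiricalMarginal(posterior, "z")
print("Mean:", marginal.mean)
print("ESS:", posterior.get_ESS())
print("Log normalizer:", posterior.get_log_normalizer())
```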
Prior-Based Importance Sampling
```python
# When no guide is provided, Importance proposes from the prior
# (model is the one defined in the basic example above)
posterior = Importance(model, num_samples=5000)
posterior.run()
```
Vectorized Importance Weights
```python
from pyro.infer.importance import vectorized_importance_weights

# model and guide as defined in the basic example above
log_weights, model_trace, guide_trace = vectorized_importance_weights(
    model, guide,
    num_samples=1000,
    max_plate_nesting=4,
    normalized=False,
)
print("Log weights shape:", log_weights.shape)  # torch.Size([1000])
```
PSIS Diagnostic
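```python
from pyro.infer.importance import psis_diagnostic

# model and guide as defined in the basic example above
k = psis_diagnostic(model, guide,
                    num_particles=5000,
                    max_plate_nesting=4)

# thresholds from the Overview: k < 0.5 good, 0.5 <= k <= 0.7 acceptable, k > 0.7 poor
if k < 0.5:
    print(f"Good guide fit (k={k:.3f})")
elif k <= 0.7:
    print(f"Acceptable guide fit (k={k:.3f})")
else:
    print(f"Poor guide fit (k={k:.3f})")
```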
Related Pages
- Pyro_ppl_Pyro_Abstract_Infer -- TracePosterior base class that Importance inherits from
- Pyro_ppl_Pyro_CSIS -- Compiled Sequential Importance Sampling, extends Importance
- Pyro_ppl_Pyro_SMCFilter -- Sequential Monte Carlo using importance resampling
- Pyro_ppl_Pyro_ReweightedWakeSleep -- Uses importance weights for wake-theta and wake-phi losses
- Pyro_ppl_Pyro_Infer_Utilities -- Utility functions used during importance weight computation