Implementation:Pyro ppl Pyro Predictive SVI
Metadata
| Field | Value |
|---|---|
| Implementation ID | Pyro_ppl_Pyro_Predictive_SVI |
| Title | Predictive (SVI Usage) |
| Project | Pyro (pyro-ppl/pyro) |
| File | pyro/infer/predictive.py, Lines 162-296 |
| Implements | Pyro_ppl_Pyro_Posterior_Predictive_Analysis |
| Repository | https://github.com/pyro-ppl/pyro |
Summary
The Predictive class (used with a guide for SVI) constructs the posterior predictive distribution by sampling latent variables from a trained guide and running the model conditioned on those samples. This is the primary tool for generating predictions after SVI training in Pyro.
Signature
class Predictive(torch.nn.Module):
    def __init__(
        self,
        model,
        posterior_samples=None,
        guide=None,
        num_samples=None,
        return_sites=(),
        parallel=False,
    )
Import
from pyro.infer import Predictive
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| model | callable | required | Python callable containing Pyro primitives (the generative model). |
| posterior_samples | dict or None | None | Dictionary of posterior samples. For SVI usage, this is typically None (samples come from the guide instead). |
| guide | callable or None | None | A trained guide (variational distribution). For SVI usage, pass the trained guide here. |
| num_samples | int or None | None | Number of samples to draw from the posterior predictive distribution. Required when using a guide. |
| return_sites | list, tuple, or set | () | Sites to include in the output. Default () returns all sites not in posterior_samples. Use None for all sites. When a guide is provided, an empty return_sites is treated as None. Include "_RETURN" to capture the model's return value. |
| parallel | bool | False | If True, wraps the model in an outer pyro.plate for vectorized prediction. Requires the model to have all batch dims annotated via pyro.plate. |
Returns (forward method)
| Type | Description |
|---|---|
| dict | A dictionary mapping site names to tensors of shape (num_samples, ...). Each tensor contains num_samples draws from the posterior predictive distribution for that site. |
SVI Usage Pattern
The typical SVI workflow for posterior prediction:
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

# 1. Define model and guide
def model(x, y=None):
    ...

guide = AutoNormal(model)

# 2. Train via SVI (x_train, y_train, and num_steps are supplied by the caller)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(num_steps):
    svi.step(x_train, y_train)

# 3. Generate posterior predictive samples
predictive = Predictive(model, guide=guide, num_samples=1000)
samples = predictive(x_test)
# samples is a dict: {"obs": tensor(1000, ...), "weight": tensor(1000, ...), ...}
Internal Mechanism (SVI Path)
When a guide is provided, forward() (lines 254-289) follows this path:
Step 1: Sample from Guide
posterior_samples = _predictive(
    self.guide,
    {},  # empty posterior_samples
    self.num_samples,
    return_sites=None,  # return ALL guide sites
    parallel=self.parallel,
    model_args=args,
    model_kwargs=kwargs,
).samples
This runs the guide num_samples times (or once, vectorized over num_samples, when parallel=True) and collects all latent variable samples.
Step 2: Run Model Conditioned on Guide Samples
return _predictive(
    self.model,
    posterior_samples,
    self.num_samples,
    return_sites=return_sites,
    parallel=self.parallel,
    model_args=args,
    model_kwargs=kwargs,
).samples
The model is run conditioned on the guide's samples via poutine.condition(model, posterior_samples).
The _predictive Helper (lines 79-159)
This internal function handles the actual prediction mechanics:
- Guesses max_plate_nesting by running the model once
- Reshapes posterior samples to align with plate structure
- Determines which sites to return based on return_sites
- If parallel=True: wraps model in pyro.plate("_num_predictive_samples", num_samples) and runs once
- If parallel=False: calls _predictive_sequential which loops num_samples times
Controlling Return Sites
| return_sites | Behavior |
|---|---|
| () (default) | Returns all sample sites NOT in posterior_samples (typically observed sites) |
| None | Returns ALL sample sites (both latent and observed). Automatically used when a guide is provided and return_sites is empty. |
| ("obs", "weight") | Returns only the specified sites |
| ("_RETURN",) | Includes the model's return value in the output |
Complete Example
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

# Bayesian neural network (simplified): one hidden layer of 10 tanh units
def model(x, y=None):
    w1 = pyro.sample("w1", dist.Normal(torch.zeros(1, 10), torch.ones(1, 10)).to_event(2))
    b1 = pyro.sample("b1", dist.Normal(torch.zeros(10), torch.ones(10)).to_event(1))
    w2 = pyro.sample("w2", dist.Normal(torch.zeros(10, 1), torch.ones(10, 1)).to_event(2))
    b2 = pyro.sample("b2", dist.Normal(torch.zeros(1), torch.ones(1)).to_event(1))
    sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
    hidden = torch.tanh(x @ w1 + b1)
    mean = (hidden @ w2 + b2).squeeze(-1)
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
    return mean

# Train
pyro.clear_param_store()
guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, Adam({"lr": 0.005}), Trace_ELBO())
x_train = torch.randn(200, 1)
y_train = torch.sin(x_train).squeeze(-1) + 0.1 * torch.randn(200)
for step in range(2000):
    svi.step(x_train, y_train)

# Predict -- SVI-specific: pass guide, no posterior_samples
predictive = Predictive(
    model,
    guide=guide,
    num_samples=500,
    return_sites=("obs", "_RETURN"),
)
x_test = torch.linspace(-3, 3, 100).unsqueeze(-1)
preds = predictive(x_test)
# preds["obs"]: shape (500, 100) -- posterior predictive samples
# preds["_RETURN"]: shape (500, 100) -- mean predictions

# Summarize the predictive distribution at each test input
mean = preds["obs"].mean(0)
std = preds["obs"].std(0)
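Beyond the mean and standard deviation, the sample dimension can also be reduced to a credible interval. The sketch below uses a random tensor as a stand-in for preds["obs"] so that it runs without training; the 90% level and the variable names are assumptions for illustration.

```python
import torch

torch.manual_seed(0)

# Synthetic stand-in for preds["obs"]: 500 draws at 100 test inputs.
draws = torch.randn(500, 100)

# Reduce over the sample dimension (dim 0) to get per-input summaries.
mean = draws.mean(dim=0)
lower, upper = torch.quantile(draws, torch.tensor([0.05, 0.95]), dim=0)

print(mean.shape, lower.shape, upper.shape)
```

Quantile-based intervals make no normality assumption about the predictive distribution, which can matter when the likelihood or the posterior is skewed.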
Validation
The __init__ method (lines 188-237) performs the following validation:
- Either posterior_samples or num_samples must be provided
- If both guide and posterior_samples are provided, raises ValueError
- If the leading dimension of a tensor in posterior_samples differs from num_samples, warns and uses the samples' batch size instead
Related Pages
Implements Principle
- Pyro_ppl_Pyro_Posterior_Predictive_Analysis
Related Implementations
- Pyro_ppl_Pyro_Predictive_MCMC -- The same Predictive class used with posterior_samples from MCMC instead of a guide.
- Pyro_ppl_Pyro_VAE_Encoder_Decoder_Pattern -- In VAE models, the guide is the encoder and Predictive generates reconstructions through the decoder.
- Environment:Pyro_ppl_Pyro_Visualization_Tools