Implementation:Pyro ppl Pyro Predictive MCMC
Metadata
| Field | Value |
|---|---|
| Implementation ID | Pyro_ppl_Pyro_Predictive_MCMC |
| Title | Predictive (MCMC Usage) |
| Project | Pyro (pyro-ppl/pyro) |
| File | pyro/infer/predictive.py, Lines 162-296 |
| Implements | Pyro_ppl_Pyro_MCMC_Posterior_Prediction |
| Repository | https://github.com/pyro-ppl/pyro |
Summary
When given posterior_samples from MCMC, the Predictive class constructs the posterior predictive distribution by running the model conditioned on explicit posterior samples from the MCMC chains. Unlike the SVI usage, which requires a guide, the MCMC usage passes in the samples returned by MCMC.get_samples() directly.
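Stated as a formula: with $S$ posterior draws $\theta^{(1)}, \dots, \theta^{(S)}$ from MCMC, the posterior predictive is approximated by a Monte Carlo mixture, and Predictive returns one draw $\tilde{y}^{(s)}$ of the new observations per posterior draw. The symbols $\theta$, $\tilde{y}$, $x_{\text{new}}$, $\mathcal{D}$ and $S$ are notation introduced here for illustration:

$$
p(\tilde{y} \mid x_{\text{new}}, \mathcal{D}) = \int p(\tilde{y} \mid x_{\text{new}}, \theta)\, p(\theta \mid \mathcal{D})\, d\theta \approx \frac{1}{S} \sum_{s=1}^{S} p(\tilde{y} \mid x_{\text{new}}, \theta^{(s)}), \qquad \tilde{y}^{(s)} \sim p(\tilde{y} \mid x_{\text{new}}, \theta^{(s)})
$$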
Signature (MCMC Focus)
```python
class Predictive(torch.nn.Module):
    def __init__(
        self,
        model,
        posterior_samples=None,  # <-- MCMC: pass mcmc.get_samples() here
        guide=None,  # <-- MCMC: leave as None
        num_samples=None,
        return_sites=(),
        parallel=False,
    ): ...
```
Import
```python
from pyro.infer import Predictive
```
Parameters (MCMC-Specific)
| Parameter | Type | Default | Description |
|---|---|---|---|
| model | callable | required | Python callable containing Pyro primitives (the generative model). |
| posterior_samples | dict | None (required for MCMC usage) | Dictionary of samples from the posterior, typically from mcmc.get_samples(). Keys are site names; values are tensors with shape (num_samples, ...). |
| guide | None | None | Must be None for MCMC usage. Providing both guide and posterior_samples is an error. |
| num_samples | int or None | None | Inferred automatically from the leading dimension of the posterior_samples tensors. If a different value is passed, a warning is issued and the leading dimension is used instead. |
| return_sites | list, tuple, or set | () | Sites to include in the output. The default () returns all sample sites NOT in posterior_samples. |
| parallel | bool | False | If True, vectorizes prediction by wrapping the model in an outermost pyro.plate. This requires every batch dimension of the model to be declared with pyro.plate. |
Returns (forward method)
| Type | Description |
|---|---|
| dict | Dictionary mapping site names to tensors of shape (num_samples, ...), containing posterior predictive samples for the requested sites. |
MCMC Usage Pattern
```python
from pyro.infer import MCMC, NUTS, Predictive

# Assumes model, x_train, y_train and x_test are already defined
# (see the Complete Example below).

# 1. Run MCMC
nuts = NUTS(model)
mcmc = MCMC(nuts, num_samples=1000, warmup_steps=500)
mcmc.run(x_train, y_train)

# 2. Extract posterior samples
posterior_samples = mcmc.get_samples()

# 3. Create Predictive -- note: guide=None, use posterior_samples
predictive = Predictive(
    model,
    posterior_samples=posterior_samples,
    return_sites=("obs",),
)

# 4. Generate posterior predictive samples
preds = predictive(x_test)
```
Internal Mechanism (MCMC Path)
When posterior_samples is provided and guide is None, the forward() method (lines 254-289) takes the direct path:
```python
def forward(self, *args, **kwargs):
    posterior_samples = self.posterior_samples
    return_sites = self.return_sites
    # guide is None, so skip guide sampling
    # and go directly to model prediction
    return _predictive(
        self.model,
        posterior_samples,  # directly from MCMC
        self.num_samples,
        return_sites=return_sites,
        parallel=self.parallel,
        model_args=args,
        model_kwargs=kwargs,
    ).samples
```
Sequential Mode (parallel=False)
The _predictive_sequential function (lines 50-73) loops over each MCMC sample:
```python
def _predictive_sequential(model, posterior_samples, model_args, model_kwargs,
                           num_samples, return_site_shapes):
    # Split the batched posterior samples into one dict per draw
    samples = [{k: v[i] for k, v in posterior_samples.items()}
               for i in range(num_samples)]
    for i in range(num_samples):
        # Run the model once with its latent sites fixed to draw i
        trace = poutine.trace(
            poutine.condition(model, samples[i])
        ).get_trace(*model_args, **model_kwargs)
        # ... collect site values from trace
    # ... stack and reshape collected samples
```
Parallel Mode (parallel=True)
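The _predictive function (lines 79-159):
- Reshapes each posterior sample tensor to (num_samples,) + (1,) * plate_padding + sample_shape, where plate_padding = max_plate_nesting - len(sample_shape)
- Wraps the model in pyro.plate("_num_predictive_samples", num_samples), placed at dim -max_plate_nesting - 1, i.e. to the left of every plate the model declares
- Conditions the model on the reshaped samples
- Runs the vectorized model once and collects all outputs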
Key Difference from SVI Usage
| Aspect | SVI Usage | MCMC Usage |
|---|---|---|
| guide parameter | Trained guide callable | None |
| posterior_samples | None (generated from guide) | mcmc.get_samples() |
| num_samples | Required (user-specified) | Inferred from posterior_samples shape |
| Internal flow | Guide sampled first, then model conditioned | Model conditioned directly on posterior_samples |
num_samples Inference
When posterior_samples is provided, num_samples is automatically determined from the leading dimension of the sample tensors (lines 205-217):
```python
for name, sample in posterior_samples.items():
    batch_size = sample.shape[0]
    if num_samples is None:
        num_samples = batch_size
    elif num_samples != batch_size:
        # A mismatch only warns; the leading dimension overrides num_samples
        warnings.warn(
            "Sample's leading dimension size {} is different from the "
            "provided {} num_samples argument.".format(batch_size, num_samples)
        )
        num_samples = batch_size
```
Complete Example
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS, Predictive


# Hierarchical regression model
def model(x, y=None):
    # Hyperpriors
    mu_weight = pyro.sample("mu_weight", dist.Normal(0, 10))
    sigma_weight = pyro.sample("sigma_weight", dist.HalfNormal(5))
    # Regression parameters (the weight is drawn from the hyperprior)
    weight = pyro.sample("weight", dist.Normal(mu_weight, sigma_weight))
    bias = pyro.sample("bias", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.HalfNormal(5))
    mean = weight * x + bias
    with pyro.plate("data", len(x)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=y)


# Run MCMC on synthetic data (true weight 3.0, bias 1.5, noise scale 0.5)
x_train = torch.randn(100)
y_train = 3.0 * x_train + 1.5 + 0.5 * torch.randn(100)
nuts = NUTS(model)
mcmc = MCMC(nuts, num_samples=500, warmup_steps=300)
mcmc.run(x_train, y_train)

# Posterior samples from MCMC: one entry per latent site
# (mu_weight, sigma_weight, weight, bias, sigma), each of shape torch.Size([500])
posterior_samples = mcmc.get_samples()
print({k: v.shape for k, v in posterior_samples.items()})

# MCMC-based posterior predictive
predictive = Predictive(
    model,
    posterior_samples=posterior_samples,
    return_sites=("obs",),
    parallel=True,
)
x_new = torch.linspace(-2, 2, 200)
preds = predictive(x_new)

# Analyze uncertainty
obs_samples = preds["obs"]  # shape: (500, 200)
# 90% posterior predictive interval at each point of x_new, shape: (2, 200)
credible_interval = torch.quantile(obs_samples, torch.tensor([0.05, 0.95]), dim=0)
```
Related Pages
Implements Principle: Pyro_ppl_Pyro_MCMC_Posterior_Prediction
Related Implementations
- Pyro_ppl_Pyro_Predictive_SVI -- The same Predictive class used with a guide instead of explicit posterior samples.
- Pyro_ppl_Pyro_Infer_Discrete -- For models with discrete latents, infer_discrete can be used to decode discrete states before or after MCMC prediction.
- Environment:Pyro_ppl_Pyro_Visualization_Tools