Principle:Pyro ppl Pyro MCMC Posterior Prediction
Metadata
| Field | Value |
|---|---|
| Principle ID | Pyro_ppl_Pyro_MCMC_Posterior_Prediction |
| Title | MCMC Posterior Prediction |
| Project | Pyro (pyro-ppl/pyro) |
| Domains | MCMC, Bayesian_Inference, Prediction |
| Implementation | Pyro_ppl_Pyro_Predictive_MCMC |
| Repository | https://github.com/pyro-ppl/pyro |
Summary
MCMC Posterior Prediction is the principle of generating posterior predictive samples using explicit posterior samples obtained from Markov Chain Monte Carlo (MCMC) methods. Unlike SVI (which approximates the posterior with a parametric guide), MCMC provides a collection of samples that are (asymptotically) drawn from the true posterior. These samples are replayed through the model to generate predictions.
Motivation
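MCMC methods such as Hamiltonian Monte Carlo (HMC) and the No-U-Turn Sampler (NUTS) produce samples whose distribution converges to the exact posterior as the chains run longer. Provided the chains have converged and mix well, this means that:
- There is no approximation gap from a restricted variational family (unlike SVI with a parametric guide); the remaining error is Monte Carlo error, which shrinks as the number of effective samples grows
- The posterior samples can represent skewness, heavy tails, and correlations between parameters; well-separated modes remain hard for HMC and NUTS, which may stay within a single mode
- Predictions based on MCMC samples propagate parameter uncertainty into the predictive distribution, so their calibration is limited mainly by how well the model itself is specified
The trade-off is that MCMC is typically slower than SVI, especially for large datasets, because standard HMC and NUTS evaluate the log density on the full dataset at every step. However, for problems where accurate uncertainty estimates are paramount (e.g., medical decision-making, safety-critical systems), MCMC-based prediction is often preferred.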
Core Concepts
MCMC vs. SVI for Prediction
| Aspect | SVI | MCMC |
|---|---|---|
| Posterior source | Parametric guide q(z) | Samples from MCMC chains |
| Input to Predictive | guide=trained_guide, num_samples=N | posterior_samples=mcmc.get_samples() |
| Approximation quality | Limited by guide family | Asymptotically exact |
| Speed | Fast (stochastic optimization, supports mini-batching) | Slower (chains must converge; each draw needs many gradient evaluations) |
| Scalability | Scales to large data | Challenging for large data |
The posterior_samples Dictionary
When using MCMC, the posterior is represented as a dictionary mapping site names to tensors:
```python
posterior_samples = mcmc.get_samples()
# Example:
# {
#     "weight": tensor of shape (num_samples,),
#     "bias": tensor of shape (num_samples,),
#     "sigma": tensor of shape (num_samples,),
# }
```
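Each tensor's leading dimension indexes the MCMC samples (by default, get_samples() concatenates the draws from all chains along this dimension), and any remaining dimensions are the shape of the site itself. The Predictive class infers the number of samples from this leading dimension and conditions the model on each sample in turn.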
Prediction Workflow
- Run MCMC to collect posterior samples: mcmc.run(data)
- Extract the samples: posterior_samples = mcmc.get_samples()
- Create a Predictive with posterior_samples (no guide needed)
- Call the Predictive with new data to get posterior predictive draws
No Guide Required
The key distinction from SVI-based prediction is that MCMC prediction does not need a guide. The posterior samples are explicit -- they are tensors of parameter values collected during MCMC sampling. The Predictive class directly conditions the model on each sample.
How It Works
For each posterior sample z_i in the MCMC output (i = 0, ..., num_samples - 1):
- The model is conditioned on z_i using poutine.condition(model, {site: value[i] for site, value in posterior_samples.items()})
- The conditioned model is run (and traced) with the new input data
- The values of the returned sites are recorded: by default, all sample sites not in posterior_samples (typically the observed/predicted variables), or exactly the sites named in return_sites if it is given
- After the loop, the recorded values are stacked across all MCMC samples to form the posterior predictive distribution
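When parallel=True, Predictive skips this loop: it wraps the model in an extra pyro.plate over the sample dimension and runs it once on all samples together. This is usually faster, but it requires the model to broadcast correctly over that extra leading batch dimension.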
Example
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS, Predictive

# Bayesian linear regression
def model(x, y=None):
    weight = pyro.sample("weight", dist.Normal(0, 10))
    bias = pyro.sample("bias", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.HalfNormal(5))
    mean = weight * x + bias
    with pyro.plate("data", len(x)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=y)

# Generate data
x_train = torch.linspace(0, 1, 50)
y_train = 2.5 * x_train + 1.0 + 0.3 * torch.randn(50)

# Run MCMC
nuts = NUTS(model)
mcmc = MCMC(nuts, num_samples=500, warmup_steps=200)
mcmc.run(x_train, y_train)

# Get posterior samples
posterior_samples = mcmc.get_samples()
# {"weight": shape(500,), "bias": shape(500,), "sigma": shape(500,)}

# Generate posterior predictive samples (MCMC approach: no guide)
predictive = Predictive(
    model,
    posterior_samples=posterior_samples,
    return_sites=("obs",),
)
x_test = torch.linspace(-0.5, 1.5, 100)
preds = predictive(x_test)  # y is omitted, so "obs" is sampled
# preds["obs"]: shape (500, 100)

# Posterior predictive intervals
mean_pred = preds["obs"].mean(dim=0)
lower_95 = preds["obs"].quantile(0.025, dim=0)
upper_95 = preds["obs"].quantile(0.975, dim=0)
```
Relationship to Other Principles
- Pyro_ppl_Pyro_Posterior_Predictive_Analysis -- The SVI-based variant of posterior prediction. Both share the same Predictive class but differ in how posterior samples are obtained.
- Pyro_ppl_Pyro_Amortized_Variational_Inference -- An alternative to MCMC for posterior approximation that uses neural networks for amortized inference, trading accuracy for speed.
Related Pages
Implemented By
- Pyro_ppl_Pyro_Predictive_MCMC
References
- Gelman, A. et al., "Bayesian Data Analysis", 3rd edition, Chapter 6
- Pyro Bayesian regression tutorial: https://pyro.ai/examples/bayesian_regression.html
- Hoffman, M.D. & Gelman, A., "The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo", Journal of Machine Learning Research, 2014 (https://arxiv.org/abs/1111.4246)