Implementation:Pyro ppl Pyro Predictive MCMC

From Leeroopedia


Metadata

Field | Value
Implementation ID | Pyro_ppl_Pyro_Predictive_MCMC
Title | Predictive (MCMC Usage)
Project | Pyro (pyro-ppl/pyro)
File | pyro/infer/predictive.py, Lines 162-296
Implements | Pyro_ppl_Pyro_MCMC_Posterior_Prediction
Repository | https://github.com/pyro-ppl/pyro

Summary

When given posterior_samples, the Predictive class draws from the posterior predictive distribution by re-running the model with its latent sites conditioned on samples from MCMC chains. Unlike the SVI usage, which draws latents from a trained guide, the MCMC usage passes the samples returned by mcmc.get_samples() directly.

Signature (MCMC Focus)

class Predictive(torch.nn.Module):
    def __init__(
        self,
        model,
        posterior_samples=None,   # <-- MCMC: pass mcmc.get_samples() here
        guide=None,               # <-- MCMC: leave as None
        num_samples=None,
        return_sites=(),
        parallel=False,
    ): ...

Import

from pyro.infer import Predictive

Parameters (MCMC-Specific)

Parameter | Type | Default | Description
model | callable | required | Python callable containing Pyro primitives (the generative model).
posterior_samples | dict | None (required for MCMC usage) | Dictionary of samples from the posterior, typically from mcmc.get_samples(). Keys are site names, values are tensors with shape (num_samples, ...).
guide | callable or None | None | Leave as None for MCMC usage. Passing a guide together with non-empty posterior_samples raises a ValueError.
num_samples | int or None | None | Inferred from the leading dimension of the posterior_samples tensors. If a different value is passed, a warning is issued and the inferred size is used.
return_sites | list, tuple, or set | () | Sites to include in the output. The default () returns all sample sites NOT in posterior_samples.
parallel | bool | False | If True, vectorizes prediction by wrapping the model in an outermost pyro.plate. This requires every batch dimension in the model to be annotated with pyro.plate.

Returns (forward method)

Type | Description
dict | Dictionary mapping site names to tensors of shape (num_samples, ...). Contains posterior predictive samples for the requested sites.

MCMC Usage Pattern

from pyro.infer import MCMC, NUTS, Predictive

# 1. Run MCMC
nuts = NUTS(model)
mcmc = MCMC(nuts, num_samples=1000, warmup_steps=500)
mcmc.run(x_train, y_train)

# 2. Extract posterior samples
posterior_samples = mcmc.get_samples()

# 3. Create Predictive -- note: guide=None, use posterior_samples
predictive = Predictive(
    model,
    posterior_samples=posterior_samples,
    return_sites=("obs",),
)

# 4. Generate posterior predictive samples
preds = predictive(x_test)

Internal Mechanism (MCMC Path)

When posterior_samples is provided and guide is None, forward() (lines 254-289) skips the guide-sampling branch and passes the stored samples straight to _predictive. A simplified view of that path:

def forward(self, *args, **kwargs):
    posterior_samples = self.posterior_samples
    return_sites = self.return_sites
    # guide is None, so skip guide sampling
    # go directly to model prediction
    return _predictive(
        self.model,
        posterior_samples,        # directly from MCMC
        self.num_samples,
        return_sites=return_sites,
        parallel=self.parallel,
        model_args=args,
        model_kwargs=kwargs,
    ).samples

Sequential Mode (parallel=False)

The _predictive_sequential function (lines 50-73) loops over each MCMC sample:

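def _predictive_sequential(model, posterior_samples, model_args, model_kwargs,
                           num_samples, return_site_shapes):
    collected = []
    # Slice the i-th draw out of every site's sample tensor
    samples = [{k: v[i] for k, v in posterior_samples.items()}
               for i in range(num_samples)]
    for i in range(num_samples):
        # Fix the latent sites to the i-th draw and record one model run
        trace = poutine.trace(
            poutine.condition(model, samples[i])
        ).get_trace(*model_args, **model_kwargs)
        # Collect the values of the requested sites from the trace
        collected.append({site: trace.nodes[site]["value"]
                          for site in return_site_shapes})
    # Stack the per-draw values and reshape each site to its target shape
    return {site: torch.stack([s[site] for s in collected]).reshape(shape)
            for site, shape in return_site_shapes.items()}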

Parallel Mode (parallel=True)

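With parallel=True, the _predictive function (lines 79-159) handles all draws in a single vectorized model run:

  1. Reshapes each posterior sample to (num_samples,) + (1,) * plate_padding + sample_shape, which places the sample dimension to the left of the model's plate dimensions
  2. Wraps the model in pyro.plate("_num_predictive_samples", num_samples)
  3. Conditions the model on the reshaped samples
  4. Runs the model once and collects all requested outputs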
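
To see why the reshape in step 1 matters, here is a small sketch in plain PyTorch. It assumes plate_padding = 1 (a model with a single data plate) and uses a random tensor as a stand-in for one MCMC sample site; both are illustrative choices, not values taken from the Pyro source.

```python
import torch

num_samples = 500
plate_padding = 1  # assumption: the model has exactly one plate dimension

weight = torch.randn(num_samples)        # stand-in for one MCMC sample site
sample_shape = tuple(weight.shape[1:])   # () for a scalar site

# Shape formula from step 1
target_shape = (num_samples,) + (1,) * plate_padding + sample_shape
weight_reshaped = weight.reshape(target_shape)

# The padded singleton dimension broadcasts against the data dimension
x = torch.linspace(-2, 2, 200)
mean = weight_reshaped * x

print(tuple(weight_reshaped.shape), tuple(mean.shape))
```

Without the padding, a (500,) tensor multiplied by a (200,) tensor would fail to broadcast, so the singleton dimension is what lets one model run cover every draw.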

Key Difference from SVI Usage

Aspect | SVI Usage | MCMC Usage
guide parameter | Trained guide callable | None
posterior_samples | None (generated from guide) | mcmc.get_samples()
num_samples | Required (user-specified) | Inferred from posterior_samples shape
Internal flow | Guide sampled first, then model conditioned | Model conditioned directly on posterior_samples

num_samples Inference

When posterior_samples is provided, num_samples is taken from the leading dimension of the sample tensors (lines 205-217). A conflicting num_samples argument only triggers a warning and is then overridden:

for name, sample in posterior_samples.items():
    batch_size = sample.shape[0]
    if num_samples is None:
        num_samples = batch_size
    elif num_samples != batch_size:
        warnings.warn(
            "Sample's leading dimension size {} is different from the "
            "provided {} num_samples argument.".format(batch_size, num_samples)
        )
        num_samples = batch_size

Complete Example

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS, Predictive

# Hierarchical model
def model(x, y=None):
    # Hyperpriors
    mu_weight = pyro.sample("mu_weight", dist.Normal(0, 10))
    sigma_weight = pyro.sample("sigma_weight", dist.HalfNormal(5))

    # Group-level parameters
    weight = pyro.sample("weight", dist.Normal(mu_weight, sigma_weight))
    bias = pyro.sample("bias", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.HalfNormal(5))

    mean = weight * x + bias
    with pyro.plate("data", len(x)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=y)

# Run MCMC
x_train = torch.randn(100)
y_train = 3.0 * x_train + 1.5 + 0.5 * torch.randn(100)

nuts = NUTS(model)
mcmc = MCMC(nuts, num_samples=500, warmup_steps=300)
mcmc.run(x_train, y_train)

# Posterior samples from MCMC
posterior_samples = mcmc.get_samples()
print({k: v.shape for k, v in posterior_samples.items()})
# {'mu_weight': torch.Size([500]), 'sigma_weight': torch.Size([500]),
#  'weight': torch.Size([500]), 'bias': torch.Size([500]),
#  'sigma': torch.Size([500])}

# MCMC-based posterior predictive
predictive = Predictive(
    model,
    posterior_samples=posterior_samples,
    return_sites=("obs",),
    parallel=True,
)
x_new = torch.linspace(-2, 2, 200)
preds = predictive(x_new)

# Analyze uncertainty: the 5th and 95th percentiles bound a 90% credible interval
obs_samples = preds["obs"]  # shape: (500, 200)
credible_interval = torch.quantile(obs_samples, torch.tensor([0.05, 0.95]), dim=0)  # shape: (2, 200)
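
The predictive samples can be reduced to per-point summaries along the sample dimension. The sketch below uses a random tensor as a stand-in for preds["obs"] so that it runs without MCMC; the variable names are illustrative.

```python
import torch

torch.manual_seed(0)
obs_samples = torch.randn(500, 200)  # stand-in for preds["obs"]

# Reduce over dim 0 (the posterior draws) to get one value per test point
pred_mean = obs_samples.mean(dim=0)
lower, upper = torch.quantile(obs_samples, torch.tensor([0.05, 0.95]), dim=0)

# Width of the 90% interval at each of the 200 test points
interval_width = upper - lower
```

Each summary has shape (200,), so it can be plotted directly against x_new.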
