Implementation:Pyro ppl Pyro Predictive SVI

From Leeroopedia


Metadata

| Field | Value |
|---|---|
| Implementation ID | Pyro_ppl_Pyro_Predictive_SVI |
| Title | Predictive (SVI Usage) |
| Project | Pyro (pyro-ppl/pyro) |
| File | pyro/infer/predictive.py, Lines 162-296 |
| Implements | Pyro_ppl_Pyro_Posterior_Predictive_Analysis |
| Repository | https://github.com/pyro-ppl/pyro |

Summary

The Predictive class (used with a guide for SVI) constructs the posterior predictive distribution by sampling latent variables from a trained guide and running the model conditioned on those samples. This is the primary tool for generating predictions after SVI training in Pyro.
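
In symbols, the sampling scheme described above can be written as a Monte Carlo approximation. The notation below is an assumption introduced for illustration and does not appear in the Pyro source: z stands for the latent variables, q_φ for the trained guide, D for the training data, x* for the new inputs, and S for num_samples.

```latex
p(y^\ast \mid x^\ast, \mathcal{D})
  = \int p(y^\ast \mid x^\ast, z)\, p(z \mid \mathcal{D})\, dz
  \approx \int p(y^\ast \mid x^\ast, z)\, q_\phi(z)\, dz,
\qquad
z^{(s)} \sim q_\phi(z), \quad
y^{(s)} \sim p(y \mid x^\ast, z^{(s)}), \quad s = 1, \dots, S
```

The draws y^(1), ..., y^(S) are what the returned dictionary holds for an observed site.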

Signature

class Predictive(torch.nn.Module):
    def __init__(
        self,
        model,
        posterior_samples=None,
        guide=None,
        num_samples=None,
        return_sites=(),
        parallel=False,
    )

Import

from pyro.infer import Predictive

Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| model | callable | required | Python callable containing Pyro primitives (the generative model). |
| posterior_samples | dict or None | None | Dictionary of posterior samples. For SVI usage, this is typically None (samples come from the guide instead). |
| guide | callable or None | None | A trained guide (variational distribution). For SVI usage, pass the trained guide here. |
| num_samples | int or None | None | Number of samples to draw from the posterior predictive distribution. Required when using a guide. |
| return_sites | list, tuple, set, or None | () | Sites to include in the output. Without a guide, the default () returns all sites not in posterior_samples; with a guide, an empty value is treated as None. Use None for all sample sites. Include "_RETURN" to capture the model's return value. |
| parallel | bool | False | If True, wraps the model in an outer pyro.plate for vectorized prediction. Requires the model to have all batch dims annotated via pyro.plate. |

Returns (forward method)

| Type | Description |
|---|---|
| dict | A dictionary mapping site names to tensors of shape (num_samples, ...). Each tensor contains num_samples draws from the posterior predictive distribution for that site. |

SVI Usage Pattern

The typical SVI workflow for posterior prediction:

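# Imports used by this pattern
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

# 1. Define model and guide
def model(x, y=None):
    ...  # model body elided
guide = AutoNormal(model)

# 2. Train via SVI (x_train, y_train, and num_steps are supplied by the caller)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(num_steps):
    svi.step(x_train, y_train)

# 3. Generate posterior predictive samples (x_test is supplied by the caller)
predictive = Predictive(model, guide=guide, num_samples=1000)
samples = predictive(x_test)
# samples is a dict keyed by site name; every value has leading dimension 1000,
# e.g. {"obs": tensor(1000, ...), "weight": tensor(1000, ...), ...} if the model defines those sites.
# With a guide and the default return_sites, latent sites are returned as well.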

Internal Mechanism (SVI Path)

When a guide is provided, forward() (lines 254-289) follows this path:

Step 1: Sample from Guide

posterior_samples = _predictive(
    self.guide,
    {},  # empty posterior_samples
    self.num_samples,
    return_sites=None,  # return ALL guide sites
    parallel=self.parallel,
    model_args=args,
    model_kwargs=kwargs,
).samples

This runs the guide num_samples times sequentially, or once in vectorized form when parallel=True, and collects the samples from every guide site.

Step 2: Run Model Conditioned on Guide Samples

return _predictive(
    self.model,
    posterior_samples,
    self.num_samples,
    return_sites=return_sites,
    parallel=self.parallel,
    model_args=args,
    model_kwargs=kwargs,
).samples

The model is run conditioned on the guide's samples via poutine.condition(model, posterior_samples).
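
The two calls above amount to sampling in two stages. The following is a minimal plain-Python sketch of that idea under stated assumptions, not Pyro's implementation: toy_guide, toy_model, and two_stage_predictive are hypothetical names, the Gaussian parameters are made up, and conditioning is imitated by passing the latent values straight into the model rather than through poutine.condition.

```python
import random


def toy_guide(rng):
    """Hypothetical stand-in for a trained guide: one draw of every latent site."""
    return {
        "weight": rng.gauss(2.0, 0.1),
        "sigma": abs(rng.gauss(0.5, 0.05)),
    }


def toy_model(x, latents, rng):
    """Hypothetical model with its latent sites fixed to the guide's values."""
    mean = latents["weight"] * x
    return {"obs": rng.gauss(mean, latents["sigma"])}


def two_stage_predictive(x, num_samples, seed=0):
    rng = random.Random(seed)
    # Step 1: draw every latent site from the guide, num_samples times.
    posterior_samples = [toy_guide(rng) for _ in range(num_samples)]
    # Step 2: run the model once per draw with the latents held fixed.
    predictions = [toy_model(x, z, rng) for z in posterior_samples]
    # Collect per-site lists; their length plays the role of the leading dimension.
    samples = {"obs": [p["obs"] for p in predictions]}
    for name in ("weight", "sigma"):
        samples[name] = [z[name] for z in posterior_samples]
    return samples


samples = two_stage_predictive(x=3.0, num_samples=1000)
```

Each entry of samples has num_samples elements, mirroring the leading dimension of the tensors that Predictive returns.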

The _predictive Helper (lines 79-159)

This internal function handles the actual prediction mechanics:

  1. Guesses max_plate_nesting by running the model once
  2. Reshapes posterior samples to align with plate structure
  3. Determines which sites to return based on return_sites
  4. If parallel=True: wraps model in pyro.plate("_num_predictive_samples", num_samples) and runs once
  5. If parallel=False: calls _predictive_sequential which loops num_samples times

Controlling Return Sites

| return_sites | Behavior |
|---|---|
| () (default) | Returns all sample sites NOT in posterior_samples (typically observed sites) |
| None | Returns ALL sample sites (both latent and observed). Automatically used when a guide is provided and return_sites is empty. |
| ("obs", "weight") | Returns only the specified sites |
| ("_RETURN",) | Includes the model's return value in the output |
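
The selection rules in the table can be paraphrased as one small function. This is a simplified sketch for illustration, not the library's code: select_return_sites and its arguments are hypothetical, sites are plain strings, and it assumes "_RETURN" is added only when named explicitly.

```python
def select_return_sites(model_sites, posterior_sample_names,
                        return_sites=(), guide_given=False):
    """Simplified paraphrase of the return_sites table; not Pyro's source."""
    # With a guide and an empty return_sites, behave as if None was passed.
    if guide_given and not return_sites:
        return_sites = None
    # None: every sample site, latent and observed.
    if return_sites is None:
        return list(model_sites)
    # Non-empty: only the requested sites.
    if return_sites:
        wanted = set(return_sites)
        selected = [site for site in model_sites if site in wanted]
        # "_RETURN" is not a sample site, so it is added only on request.
        if "_RETURN" in wanted:
            selected.append("_RETURN")
        return selected
    # Default (): everything that was not supplied as a posterior sample.
    return [site for site in model_sites if site not in posterior_sample_names]


sites = ["weight", "sigma", "obs"]
default_sites = select_return_sites(sites, {"weight", "sigma"})
guide_sites = select_return_sites(sites, set(), guide_given=True)
```

Here default_sites contains only "obs", while guide_sites contains all three sites, which matches the first two rows of the table.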

Complete Example

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

# Bayesian neural network (simplified): one hidden layer of 10 tanh units
def model(x, y=None):
    w1 = pyro.sample("w1", dist.Normal(torch.zeros(1, 10), torch.ones(1, 10)).to_event(2))
    b1 = pyro.sample("b1", dist.Normal(torch.zeros(10), torch.ones(10)).to_event(1))
    w2 = pyro.sample("w2", dist.Normal(torch.zeros(10, 1), torch.ones(10, 1)).to_event(2))
    b2 = pyro.sample("b2", dist.Normal(torch.zeros(1), torch.ones(1)).to_event(1))
    sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))

    hidden = torch.tanh(x @ w1 + b1)        # shape (N, 10)
    mean = (hidden @ w2 + b2).squeeze(-1)   # shape (N,)

    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
    return mean

# Train
guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, Adam({"lr": 0.005}), Trace_ELBO())
x_train = torch.randn(200, 1)
y_train = torch.sin(x_train).squeeze(-1) + 0.1 * torch.randn(200)
for step in range(2000):
    svi.step(x_train, y_train)

# Predict -- SVI-specific: pass guide, no posterior_samples
predictive = Predictive(
    model,
    guide=guide,
    num_samples=500,
    return_sites=("obs", "_RETURN"),
)
x_test = torch.linspace(-3, 3, 100).unsqueeze(-1)
preds = predictive(x_test)

# preds["obs"]: shape (500, 100) -- posterior predictive samples
# preds["_RETURN"]: shape (500, 100) -- mean predictions
mean = preds["obs"].mean(0)
std = preds["obs"].std(0)
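
The last two lines reduce the draws along the sample dimension. A percentile band is another common summary; the sketch below shows one way to compute it using only the standard library. The helper name summarize_draws is hypothetical, and nested lists stand in for a tensor of shape (num_samples, num_points).

```python
import statistics


def summarize_draws(draws, lower_pct=5, upper_pct=95):
    """Summarize draws laid out as rows of shape (num_samples, num_points)."""
    summary = []
    # zip(*draws) transposes the rows so each column holds one point's draws.
    for column in zip(*draws):
        # n=100 yields 99 cut points; cuts[k - 1] is the k-th percentile.
        cuts = statistics.quantiles(column, n=100, method="inclusive")
        summary.append({
            "mean": statistics.fmean(column),
            "lower": cuts[lower_pct - 1],
            "upper": cuts[upper_pct - 1],
        })
    return summary


# Toy input: 101 draws for 2 test points (nested lists stand in for a tensor).
toy_draws = [[float(i), 10.0] for i in range(101)]
bands = summarize_draws(toy_draws)
```

With real output, preds["obs"].tolist() yields the nested-list layout assumed here.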

Validation

The __init__ method (lines 188-237) performs the following validation:

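  • Either posterior_samples or num_samples must be provided; if both are None, a ValueError is raised
  • If both guide and posterior_samples are provided, raises ValueError
  • If num_samples differs from the leading dimension of the tensors in posterior_samples, a warning is issued and that leading dimension is used as num_samples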
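
These checks can be paraphrased in a few lines. The sketch below is an illustration under stated assumptions rather than the library's code: resolve_num_samples is a hypothetical name, leading_sizes maps each site name to the leading dimension of its sample tensor (standing in for tensor.shape[0]), and the order of the checks is assumed.

```python
import warnings


def resolve_num_samples(leading_sizes=None, guide_given=False, num_samples=None):
    """Hypothetical paraphrase of the constructor checks; not Pyro's source."""
    if leading_sizes is None:
        if num_samples is None:
            raise ValueError("Either posterior_samples or num_samples must be specified.")
        leading_sizes = {}
    if guide_given and leading_sizes:
        raise ValueError("posterior_samples cannot be combined with a guide.")
    for name, batch_size in leading_sizes.items():
        if num_samples is None:
            # Infer num_samples from the samples themselves.
            num_samples = batch_size
        elif num_samples != batch_size:
            warnings.warn(
                f"Site {name!r} has leading dimension {batch_size}, not "
                f"{num_samples}; using {batch_size}.",
                UserWarning,
            )
            num_samples = batch_size
    return num_samples


# SVI usage: no posterior_samples, so num_samples must be given explicitly.
svi_num_samples = resolve_num_samples(guide_given=True, num_samples=500)
```

In the SVI path only the first check can fail, which is why num_samples is effectively required whenever a guide is passed.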
