Implementation:Pyro ppl Pyro MCMC Sampler
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | Repo (Pyro) |
| Domains | MCMC, Bayesian_Inference |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
Concrete MCMC sampler class in Pyro that manages the full sampling lifecycle -- initialization, warmup, sampling, and post-processing -- using a pluggable kernel (NUTS or HMC).
Description
MCMC is the primary user-facing class for running Markov chain Monte Carlo inference in Pyro. It orchestrates the complete sampling pipeline:
- Initialization: Sets up the kernel, validates parameters, and prepares the initial state. Automatic transforms map constrained parameters to unconstrained space.
- Warmup: Runs the kernel for warmup_steps iterations, during which the kernel adapts its internal parameters (step size, mass matrix). Warmup samples are discarded.
- Sampling: Runs the kernel for num_samples iterations with frozen kernel parameters. Samples are collected and stored.
- Post-processing: Provides methods to retrieve samples, compute summary statistics, and run convergence diagnostics.
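The warmup and sampling phases can be illustrated with a toy sampler. The following is a minimal sketch using only the standard library, not Pyro's implementation: it assumes a random-walk Metropolis kernel on a standard normal target, and the names log_density and run_chain as well as the step-size adaptation rule are hypothetical choices made for illustration.

```python
import math
import random


def log_density(x):
    """Unnormalized log-density of a standard normal target (illustrative)."""
    return -0.5 * x * x


def run_chain(num_samples, warmup_steps=None, seed=0):
    """Toy random-walk Metropolis run: a warmup phase followed by sampling."""
    rng = random.Random(seed)
    # Warmup length falls back to num_samples when it is not specified.
    if warmup_steps is None:
        warmup_steps = num_samples

    x, step_size = 0.0, 1.0
    samples = []
    for i in range(warmup_steps + num_samples):
        proposal = x + rng.gauss(0.0, step_size)
        log_ratio = log_density(proposal) - log_density(x)
        # 1 - random() lies in (0, 1], so the logarithm is always defined.
        accepted = math.log(1.0 - rng.random()) < log_ratio
        if accepted:
            x = proposal
        if i < warmup_steps:
            # Warmup: crude step-size adaptation; these draws are discarded.
            step_size *= 1.1 if accepted else 0.9
        else:
            # Sampling: the step size stays frozen; these draws are kept.
            samples.append(x)
    return samples, step_size


samples, final_step_size = run_chain(num_samples=200, warmup_steps=50)
print(len(samples), final_step_size > 0)
```

Only the draws from the second phase are returned, which mirrors the way warmup samples are discarded while the adapted parameters carry over into sampling.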
The class supports running multiple chains either sequentially or in parallel using Python's multiprocessing module. When num_chains > 1, each chain runs independently in its own CPU process, provided enough CPU cores are available; otherwise the chains are drawn sequentially. When the model uses CUDA tensors, set mp_context="spawn".
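The choice between parallel and sequential execution can be sketched as a small predicate. This is an illustration under stated assumptions rather than Pyro's code: the helper name chains_run_in_parallel is hypothetical, and the rule that one core is reserved for the main process is an assumption that should be checked against the installed Pyro version.

```python
import multiprocessing as mp


def chains_run_in_parallel(num_chains, cpu_count=None):
    """Hypothetical helper: would the requested chains run in parallel?"""
    if cpu_count is None:
        cpu_count = mp.cpu_count()
    # Assumption: one core is reserved for the main process.
    available_cpu = max(cpu_count - 1, 1)
    # A single chain never needs worker processes.
    return 1 < num_chains <= available_cpu


print(chains_run_in_parallel(4, cpu_count=8))
print(chains_run_in_parallel(4, cpu_count=4))
```

Under this rule, requesting more chains than there are spare cores falls back to drawing the chains one after another.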
Key design decisions:
- Default warmup: If warmup_steps is not specified, it defaults to num_samples.
- Parameter saving: The save_params argument allows selective saving of only specific latent variable sites, which is useful for memory efficiency in models with many parameters.
- Hook functions: An optional hook_fn can be provided to execute custom logic after each sample (e.g., logging, early stopping).
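A hook can be any callable. The sketch below assumes the hook is invoked as hook_fn(kernel, samples, stage, i), with stage naming the current phase; that calling convention should be verified against the installed Pyro version. The StageCounter class is a hypothetical example, and it is exercised here by calling it directly instead of through a sampler.

```python
from collections import Counter


class StageCounter:
    """Hypothetical hook that counts how often it is called in each stage."""

    def __init__(self):
        self.calls = Counter()

    def __call__(self, kernel, samples, stage, i):
        # Normalize case so "Warmup" and "warmup" land in the same bucket.
        self.calls[stage.lower()] += 1


counter = StageCounter()
# Simulate what a sampler would do: 3 warmup iterations, then 5 samples.
for i in range(3):
    counter(None, {}, "Warmup", i)
for i in range(5):
    counter(None, {}, "Sample", i)
print(dict(counter.calls))
```

In a real run the instance would be passed as hook_fn=counter when constructing the sampler, and the counts could be inspected after run() returns.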
Code Reference
Source Location
Pyro repo, file: pyro/infer/mcmc/api.py, lines L405-650.
Class Hierarchy
MCMC inherits from AbstractMCMC.
Signature
class MCMC(AbstractMCMC):
    def __init__(
        self,
        kernel,
        num_samples,
        warmup_steps=None,
        initial_params=None,
        num_chains=1,
        hook_fn=None,
        mp_context=None,
        disable_progbar=False,
        disable_validation=True,
        transforms=None,
        save_params=None,
    ):
Import
from pyro.infer.mcmc import MCMC
I/O Contract
Constructor Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| kernel | MCMCKernel | Yes | An MCMC kernel instance (e.g., NUTS or HMC) that defines the transition operator. |
| num_samples | int | Yes | Number of posterior samples to draw (after warmup). |
| warmup_steps | int | No | Number of warmup iterations for kernel adaptation. Defaults to num_samples if not specified. |
| initial_params | dict | No | Dictionary mapping sample site names to initial parameter values. If None, the kernel's init_strategy is used. |
| num_chains | int | No | Number of independent chains to run. Defaults to 1. |
| hook_fn | callable | No | Function called after each iteration with (kernel, samples, stage, i) as arguments, where stage is either warmup or sample and i is the iteration index within that stage. |
| mp_context | str | No | Multiprocessing context for parallel chains (e.g., "spawn", "fork"). |
| disable_progbar | bool | No | Whether to disable the progress bar. Defaults to False. |
| disable_validation | bool | No | Whether to disable distribution validation. Defaults to True for performance. |
| transforms | dict | No | Transforms for reparameterizing constrained parameters. Overrides kernel transforms if provided. |
| save_params | list | No | List of site names to save. If None, all latent sites are saved. |
Key Methods
| Method | Signature | Description |
|---|---|---|
| run | run(*args, **kwargs) | Runs warmup and sampling. Positional and keyword arguments are passed to the model. |
| get_samples | get_samples(num_samples=None, group_by_chain=False) | Returns collected samples as a dictionary of tensors. If group_by_chain=True, tensors have shape (num_chains, num_samples, ...). |
| summary | summary(prob=0.9) | Prints a summary table with mean, std, median, credible interval, effective sample size, and R-hat. |
| diagnostics | diagnostics() | Returns a dictionary of convergence diagnostics including ESS, R-hat, and divergence information. |
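To make the R-hat diagnostic concrete, here is a minimal standard-library sketch of the classic Gelman-Rubin statistic computed from samples grouped by chain. It is not necessarily the exact estimator Pyro uses (split-chain variants exist), and the function name gelman_rubin and the toy data are assumptions made for illustration.

```python
from statistics import fmean, variance


def gelman_rubin(chains):
    """Classic potential scale reduction factor for equal-length chains."""
    num_chains = len(chains)
    n = len(chains[0])
    if num_chains < 2 or n < 2 or any(len(c) != n for c in chains):
        raise ValueError("need at least 2 chains of equal length >= 2")
    chain_means = [fmean(c) for c in chains]
    # W: average of the within-chain sample variances.
    within = fmean([variance(c) for c in chains])
    # B: between-chain variance, scaled by the chain length.
    between = n * variance(chain_means)
    # Pooled estimate of the posterior variance.
    var_estimate = (n - 1) / n * within + between / n
    return (var_estimate / within) ** 0.5


# Two chains stuck in different regions give a value far above 1.
stuck = [[0.0, 1.0, 2.0, 3.0], [10.0, 11.0, 12.0, 13.0]]
# Two chains covering the same region give a value near (here below) 1.
mixed = [[0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0]]
print(gelman_rubin(stuck) > 1.1, gelman_rubin(mixed) < 1.1)
```

The input layout matches what group_by_chain=True describes: one sequence of draws per chain, so that within-chain and between-chain variability can be compared.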
Outputs
| Output | Type | Description |
|---|---|---|
| MCMC sampler instance | MCMC | After calling run(), provides access to samples via get_samples(), summary statistics via summary(), and diagnostics via diagnostics(). |
Usage Examples
Standard MCMC Workflow
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import NUTS, MCMC

def model(data):
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.HalfNormal(10))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

data = torch.randn(100) * 2 + 5

# Configure and run MCMC
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(data)

# Retrieve and analyze samples
samples = mcmc.get_samples()
print("mu mean:", samples["mu"].mean().item())
print("sigma mean:", samples["sigma"].mean().item())

# Print summary table
mcmc.summary(prob=0.9)
Multi-Chain MCMC with Diagnostics
from pyro.infer.mcmc import NUTS, MCMC

nuts_kernel = NUTS(model, target_accept_prob=0.8)
mcmc = MCMC(
    nuts_kernel,
    num_samples=1000,
    warmup_steps=500,
    num_chains=4,
)
mcmc.run(data)

# Get samples grouped by chain
samples = mcmc.get_samples(group_by_chain=True)
# samples["mu"].shape == (4, 1000)

# Print summary with diagnostics
mcmc.summary(prob=0.95)

# Access detailed diagnostics (the returned dictionary also holds ESS and R-hat per site)
diag = mcmc.diagnostics()
print("Divergences:", diag["divergences"])
Selective Parameter Saving
from pyro.infer.mcmc import NUTS, MCMC

# Only save specific parameters to conserve memory
mcmc = MCMC(
    NUTS(model),
    num_samples=2000,
    warmup_steps=1000,
    save_params=["mu", "sigma"],
)
mcmc.run(data)
samples = mcmc.get_samples()