
Implementation:Pyro ppl Pyro MCMC Sampler

From Leeroopedia


Metadata

Field Value
Page Type Implementation (API Doc)
Knowledge Sources Repo (Pyro)
Domains MCMC, Bayesian_Inference
Last Updated 2026-02-09 12:00 GMT

Overview

Concrete MCMC sampler class in Pyro that manages the full sampling lifecycle -- initialization, warmup, sampling, and post-processing -- using a pluggable kernel (NUTS or HMC).

Description

MCMC is the primary user-facing class for running Markov chain Monte Carlo inference in Pyro. It orchestrates the complete sampling pipeline:

  1. Initialization: Sets up the kernel, validates parameters, and prepares the initial state. Automatic transforms are applied to map constrained parameters into unconstrained space.
  2. Warmup: Runs the kernel for warmup_steps iterations, during which the kernel adapts its internal parameters (step size, mass matrix). Warmup samples are discarded.
  3. Sampling: Runs the kernel for num_samples iterations with frozen kernel parameters. Samples are collected and stored.
  4. Post-processing: Provides methods to retrieve samples, compute summary statistics, and run convergence diagnostics.

The class supports running multiple chains either sequentially or in parallel using Python's multiprocessing module. When num_chains > 1, each chain runs in its own process and executes independently; when the model uses CUDA tensors, the "spawn" multiprocessing context is typically required (see the mp_context argument).

Key design decisions:

  • Default warmup: If warmup_steps is not specified, it defaults to num_samples.
  • Parameter saving: The save_params argument allows selective saving of only specific latent variable sites, which is useful for memory efficiency in models with many parameters.
  • Hook functions: An optional hook_fn can be provided to execute custom logic after each sample (e.g., logging, early stopping).

Code Reference

Source Location

Pyro repo, file: pyro/infer/mcmc/api.py, lines 405-650.

Class Hierarchy

MCMC inherits from AbstractMCMC.

Signature

class MCMC(AbstractMCMC):
    def __init__(
        self,
        kernel,
        num_samples,
        warmup_steps=None,
        initial_params=None,
        num_chains=1,
        hook_fn=None,
        mp_context=None,
        disable_progbar=False,
        disable_validation=True,
        transforms=None,
        save_params=None,
    ):

Import

from pyro.infer.mcmc import MCMC

I/O Contract

Constructor Inputs

kernel (MCMCKernel, required): An MCMC kernel instance (e.g., NUTS or HMC) that defines the transition operator.
num_samples (int, required): Number of posterior samples to draw (after warmup).
warmup_steps (int, optional): Number of warmup iterations for kernel adaptation. Defaults to num_samples if not specified.
initial_params (dict, optional): Dictionary mapping sample site names to initial parameter values. If None, the kernel's init_strategy is used.
num_chains (int, optional): Number of independent chains to run. Defaults to 1.
hook_fn (callable, optional): Function called after each sample; receives the kernel, the current samples, the stage, and the iteration index.
mp_context (str, optional): Multiprocessing context for parallel chains (e.g., "spawn", "fork").
disable_progbar (bool, optional): Whether to disable the progress bar. Defaults to False.
disable_validation (bool, optional): Whether to disable distribution validation. Defaults to True for performance.
transforms (dict, optional): Transforms for reparameterizing constrained parameters. Overrides kernel transforms if provided.
save_params (list, optional): List of site names to save. If None, all latent sites are saved.

Key Methods

run(*args, **kwargs): Runs warmup and sampling. Positional and keyword arguments are passed to the model.
get_samples(num_samples=None, group_by_chain=False): Returns collected samples as a dictionary of tensors. If group_by_chain=True, tensors have shape (num_chains, num_samples, ...).
summary(prob=0.9): Prints a summary table with mean, std, median, credible interval, effective sample size, and R-hat.
diagnostics(): Returns a dictionary of convergence diagnostics including ESS, R-hat, and divergence information.

Outputs

The MCMC instance itself is the output: after calling run(), it provides access to samples via get_samples(), summary statistics via summary(), and convergence diagnostics via diagnostics().

Usage Examples

Standard MCMC Workflow

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import NUTS, MCMC

def model(data):
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.HalfNormal(10))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

data = torch.randn(100) * 2 + 5

# Configure and run MCMC
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(data)

# Retrieve and analyze samples
samples = mcmc.get_samples()
print("mu mean:", samples["mu"].mean().item())
print("sigma mean:", samples["sigma"].mean().item())

# Print summary table
mcmc.summary(prob=0.9)

Multi-Chain MCMC with Diagnostics

from pyro.infer.mcmc import NUTS, MCMC

nuts_kernel = NUTS(model, target_accept_prob=0.8)
mcmc = MCMC(
    nuts_kernel,
    num_samples=1000,
    warmup_steps=500,
    num_chains=4,
)
mcmc.run(data)

# Get samples grouped by chain
samples = mcmc.get_samples(group_by_chain=True)
# samples["mu"].shape == (4, 1000)

# Print summary with diagnostics
mcmc.summary(prob=0.95)

# Access detailed diagnostics
diag = mcmc.diagnostics()
print("Divergences:", diag)

Selective Parameter Saving

from pyro.infer.mcmc import NUTS, MCMC

# Only save specific parameters to conserve memory
mcmc = MCMC(
    NUTS(model),
    num_samples=2000,
    warmup_steps=1000,
    save_params=["mu", "sigma"],
)
mcmc.run(data)
samples = mcmc.get_samples()
