Implementation:Pyro ppl Pyro MCMC Summary

From Leeroopedia


Metadata

Field | Value
Page Type | Implementation (API Doc)
Knowledge Sources | Repo (Pyro)
Domains | MCMC, Bayesian_Inference, Statistics
Last Updated | 2026-02-09 12:00 GMT

Overview

Concrete diagnostic tools for assessing MCMC convergence in Pyro, including the summary() and diagnostics() methods on the MCMC class, backed by effective_sample_size() and split_gelman_rubin() from pyro.ops.stats.

Description

Pyro provides MCMC convergence diagnostics through two mechanisms:

MCMC.summary()

The summary() method on the MCMC class prints a formatted table containing, for each sampled parameter:

  • mean: The posterior mean across all samples and chains.
  • std: The posterior standard deviation.
  • median: The posterior median.
  • Credible interval: The lower and upper bounds of the credible interval at the specified probability level; for prob=0.9 the two columns are labeled 5.0% and 95.0%.
  • n_eff: The effective sample size (ESS), computed by effective_sample_size().
  • r_hat: The split Gelman-Rubin R-hat statistic, computed by split_gelman_rubin().

MCMC.diagnostics()

The diagnostics() method returns a dictionary containing:

  • Per-parameter ESS and R-hat values.
  • Divergence information from the MCMC kernel (the divergent transitions encountered during sampling), when the kernel reports it.
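
A dictionary like this lends itself to automated checks. The sketch below is illustrative only: it assumes each parameter entry is a mapping with "n_eff" and "r_hat" keys holding scalar values, and the helper name flag_unconverged, the thresholds, and the example numbers are all hypothetical rather than part of Pyro.

```python
def flag_unconverged(diag, rhat_max=1.01, ess_min=100.0):
    """Return the names of entries whose R-hat or ESS misses the thresholds.

    Assumes scalar statistics; both thresholds are illustrative defaults.
    """
    flagged = []
    for name, stats in diag.items():
        # Skip kernel-level entries that carry no per-parameter statistics.
        if not isinstance(stats, dict) or "r_hat" not in stats or "n_eff" not in stats:
            continue
        if float(stats["r_hat"]) > rhat_max or float(stats["n_eff"]) < ess_min:
            flagged.append(name)
    return flagged


# Hypothetical diagnostics dictionary, hand-written for this example.
example_diag = {
    "mu": {"n_eff": 812.0, "r_hat": 1.001},
    "sigma": {"n_eff": 45.0, "r_hat": 1.08},
    "divergences": {"chain 0": []},
}
flagged = flag_unconverged(example_diag)
print(flagged)
```

For parameters with more than one element, the float() conversion would need to be replaced by a reduction such as the maximum R-hat and minimum ESS.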

Supporting Functions

The statistical computations are implemented in pyro/ops/stats.py:

  • effective_sample_size(input, chain_dim=0, sample_dim=1): Computes ESS for a tensor of MCMC samples. Uses the autocorrelation-based estimator with the initial positive sequence truncation method. The input tensor should have shape (num_chains, num_samples, ...).
  • split_gelman_rubin(input, chain_dim=0, sample_dim=1): Computes the split R-hat diagnostic. Splits each chain in half, then computes the square root of the ratio of the pooled posterior variance estimate to the within-chain variance. Values close to 1.0 indicate that the chains agree; values noticeably above 1.0 suggest they have not converged.
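
The split R-hat computation described above can be sketched in plain Python. This is a minimal illustration of the formula on lists of floats, not Pyro's tensor implementation; the function name split_rhat and the synthetic chains are assumptions made for the example.

```python
import math
import random
import statistics


def split_rhat(chains):
    """Split R-hat for a list of equal-length chains (lists of floats)."""
    half = len(chains[0]) // 2
    halves = []
    for chain in chains:
        halves.append(chain[:half])                # first half of the chain
        halves.append(chain[len(chain) - half:])   # last half of the chain
    # Within-chain variance: mean of the per-half sample variances.
    within = statistics.mean(statistics.variance(h) for h in halves)
    # Between-chain variance: sample variance of the per-half means.
    between = statistics.variance([statistics.mean(h) for h in halves])
    # Pooled variance estimate, then square root of its ratio to `within`.
    pooled = within * (half - 1) / half + between
    return math.sqrt(pooled / within)


random.seed(0)
# Four chains drawn from the same distribution.
mixed = [[random.gauss(0.0, 1.0) for _ in range(500)] for _ in range(4)]
# Four chains where two are centered at a different location.
stuck = [[random.gauss(shift, 1.0) for _ in range(500)]
         for shift in (0.0, 0.0, 3.0, 3.0)]

rhat_mixed = split_rhat(mixed)
rhat_stuck = split_rhat(stuck)
print(f"mixed: {rhat_mixed:.3f}  stuck: {rhat_stuck:.3f}")
```

Chains that share a distribution give a value close to 1, while chains centered at different locations inflate the between-chain term and push the value well above 1.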

Both functions take explicit chain and sample dimensions, so they apply to multi-chain runs as well as to a single chain (for split R-hat, the two halves of the chain are compared against each other).
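
To make the autocorrelation-based ESS estimator more concrete, here is a simplified single-chain sketch in plain Python. It sums autocorrelations in adjacent pairs and stops at the first non-positive pair sum, which is one common reading of initial positive sequence truncation. It is an assumption-laden illustration, not Pyro's implementation: the function names and the AR(1) test chain are hypothetical, and Pyro's estimator works on tensors and combines multiple chains.

```python
import random


def autocorr(x, lag):
    """Biased sample autocorrelation of the sequence x at the given lag."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    cov = sum((x[i] - mean) * (x[i + lag] - mean) for i in range(n - lag)) / n
    return cov / var


def ess_single_chain(x):
    """ESS of one chain, truncating at the first non-positive pair sum."""
    n = len(x)
    tau = -1.0  # autocorrelation time: -1 + 2 * (sum of kept pair sums)
    for lag in range(0, n - 1, 2):
        pair = autocorr(x, lag) + autocorr(x, lag + 1)
        if pair <= 0.0:
            break
        tau += 2.0 * pair
    return n / tau


random.seed(0)
n = 2000
# Independent draws: ESS should be on the order of n.
iid = [random.gauss(0.0, 1.0) for _ in range(n)]
# Strongly autocorrelated AR(1) chain: ESS should be far smaller than n.
ar = [0.0]
for _ in range(n - 1):
    ar.append(0.9 * ar[-1] + random.gauss(0.0, 1.0))

ess_iid = ess_single_chain(iid)
ess_ar = ess_single_chain(ar)
print(f"iid: {ess_iid:.0f}  AR(1): {ess_ar:.0f}")
```

The sketch does not guard against strongly negative lag-1 autocorrelation, where the estimated autocorrelation time can become very small or negative.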

Code Reference

Source Locations

  • MCMC.summary(): pyro/infer/mcmc/api.py, lines L617-640.
  • MCMC.diagnostics(): pyro/infer/mcmc/api.py, lines L641-650.
  • effective_sample_size(): pyro/ops/stats.py.
  • split_gelman_rubin(): pyro/ops/stats.py.

Signatures

# On MCMC class (pyro/infer/mcmc/api.py)
class MCMC:
    def summary(self, prob=0.9):
        """Print summary statistics for each parameter."""
        ...

    def diagnostics(self):
        """Return a dictionary of convergence diagnostics."""
        ...

# Supporting functions (pyro/ops/stats.py)
def effective_sample_size(input, chain_dim=0, sample_dim=1):
    """
    Compute effective sample size of MCMC samples.

    :param input: Tensor of shape (num_chains, num_samples, ...).
    :param chain_dim: Dimension indexing chains.
    :param sample_dim: Dimension indexing samples within each chain.
    :return: Tensor of ESS values.
    """
    ...

def split_gelman_rubin(input, chain_dim=0, sample_dim=1):
    """
    Compute split Gelman-Rubin R-hat diagnostic.

    :param input: Tensor of shape (num_chains, num_samples, ...).
    :param chain_dim: Dimension indexing chains.
    :param sample_dim: Dimension indexing samples within each chain.
    :return: Tensor of R-hat values.
    """
    ...

Import

# summary() and diagnostics() are accessed via the MCMC instance
from pyro.infer.mcmc import MCMC

# Supporting functions can be imported directly
from pyro.ops.stats import effective_sample_size, split_gelman_rubin

I/O Contract

MCMC.summary()

Parameter | Type | Required | Description
prob | float | No | The probability mass for the credible interval. Defaults to 0.9 (bounds shown in the 5.0% and 95.0% columns).

Output | Type | Description
Printed table | None | Prints a formatted table to stdout with mean, std, median, credible interval bounds, n_eff, and r_hat for each parameter. Returns None.

MCMC.diagnostics()

Output | Type | Description
Diagnostics dict | dict | Dictionary containing per-parameter ESS and R-hat values, plus divergence information reported by the kernel.

effective_sample_size()

Parameter | Type | Required | Description
input | torch.Tensor | Yes | MCMC samples with shape (num_chains, num_samples, ...).
chain_dim | int | No | Dimension indexing chains. Defaults to 0.
sample_dim | int | No | Dimension indexing samples. Defaults to 1.

Output | Type | Description
ESS | torch.Tensor | Effective sample size for each parameter dimension. Shape matches the input with chain and sample dimensions removed.

split_gelman_rubin()

Parameter | Type | Required | Description
input | torch.Tensor | Yes | MCMC samples with shape (num_chains, num_samples, ...).
chain_dim | int | No | Dimension indexing chains. Defaults to 0.
sample_dim | int | No | Dimension indexing samples. Defaults to 1.

Output | Type | Description
R-hat | torch.Tensor | Split Gelman-Rubin R-hat values. Shape matches input with chain and sample dimensions removed. Values near 1.0 indicate convergence.

Usage Examples

Printing Summary After MCMC

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import NUTS, MCMC

def model(data):
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.HalfNormal(10))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

data = torch.randn(100) * 2 + 5

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500, num_chains=4)
mcmc.run(data)

# Print summary table with 90% credible intervals
mcmc.summary(prob=0.9)
# Example output (illustrative; exact numbers vary from run to run):
#          mean    std  median   5.0%  95.0%  n_eff  r_hat
#    mu    5.01   0.20    5.01   4.68   5.34  800.0   1.00
# sigma    2.03   0.14    2.02   1.80   2.27  750.0   1.00

Accessing Diagnostics Programmatically

# Get diagnostics dictionary
diag = mcmc.diagnostics()
print(diag)
# Contains per-parameter n_eff and r_hat, plus divergence information from the kernel

# Direct use of supporting functions
from pyro.ops.stats import effective_sample_size, split_gelman_rubin

samples = mcmc.get_samples(group_by_chain=True)
mu_samples = samples["mu"]  # shape: (4, 1000)

ess = effective_sample_size(mu_samples.unsqueeze(-1))
rhat = split_gelman_rubin(mu_samples.unsqueeze(-1))
print(f"ESS for mu: {ess.item():.0f}")
print(f"R-hat for mu: {rhat.item():.3f}")
