Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | Repo (Pyro) |
| Domains | MCMC, Bayesian_Inference, Statistics |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
Concrete diagnostic tools for assessing MCMC convergence in Pyro, including the summary() and diagnostics() methods on the MCMC class, backed by effective_sample_size() and split_gelman_rubin() from pyro.ops.stats.
Description
Pyro provides MCMC convergence diagnostics through two mechanisms:
MCMC.summary()
The summary() method on the MCMC class prints a formatted table containing, for each sampled parameter:
- mean: The posterior mean across all samples and chains.
- std: The posterior standard deviation.
- median: The posterior median.
- Credible interval: The central credible interval at the specified probability level (e.g., 5th and 95th percentiles for prob=0.9).
- n_eff: The effective sample size (ESS), computed by effective_sample_size().
- r_hat: The split Gelman-Rubin R-hat statistic, computed by split_gelman_rubin().
MCMC.diagnostics()
The diagnostics() method returns a dictionary containing:
- Per-parameter ESS and R-hat values.
- Divergence information from the MCMC kernel (number of divergent transitions encountered during sampling).
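A returned dictionary like this is usually checked against thresholds before the samples are trusted. The following is a minimal sketch of such a check, not part of Pyro: the helper name `flag_convergence_issues`, the threshold defaults, and the assumed layout (per-parameter entries holding `n_eff` and `r_hat` for scalar parameters, plus a `divergences` entry mapping chain labels to lists of divergent transition indices) are illustrative assumptions and should be verified against the dictionary your Pyro version actually returns.

```python
def flag_convergence_issues(diag, rhat_max=1.01, ess_min=400.0):
    """Return human-readable warnings for a diagnostics-style dictionary.

    Assumes scalar parameters; thresholds are rule-of-thumb defaults,
    not values prescribed by Pyro.
    """
    warnings = []
    for name, stats in diag.items():
        # Kernel-level entries (e.g. divergences) have no n_eff / r_hat keys.
        if "r_hat" not in stats or "n_eff" not in stats:
            continue
        r_hat = float(stats["r_hat"])
        n_eff = float(stats["n_eff"])
        if r_hat > rhat_max:
            warnings.append(f"{name}: r_hat={r_hat:.3f} exceeds {rhat_max}")
        if n_eff < ess_min:
            warnings.append(f"{name}: n_eff={n_eff:.0f} is below {ess_min:.0f}")
    # Assumed layout: {"chain 0": [indices...], "chain 1": [...], ...}
    n_divergent = sum(len(v) for v in diag.get("divergences", {}).values())
    if n_divergent > 0:
        warnings.append(f"{n_divergent} divergent transitions")
    return warnings


# Hand-written stand-in for the output of mcmc.diagnostics().
example_diag = {
    "mu": {"n_eff": 800.0, "r_hat": 1.00},
    "sigma": {"n_eff": 120.0, "r_hat": 1.08},
    "divergences": {"chain 0": [], "chain 1": [17, 342]},
}
issues = flag_convergence_issues(example_diag)
for line in issues:
    print(line)
```

With a fitted sampler, the same helper could be called as `flag_convergence_issues(mcmc.diagnostics())`, provided the assumed layout holds.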
Supporting Functions
The statistical computations are implemented in pyro/ops/stats.py:
effective_sample_size(input, chain_dim=0, sample_dim=1): Computes ESS for a tensor of MCMC samples. Uses the autocorrelation-based estimator with the initial positive sequence truncation method. The input tensor should have shape (num_chains, num_samples, ...).
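To make the autocorrelation-based estimator concrete, here is a simplified pure-Python sketch for a single chain. It is an illustration under stated assumptions, not Pyro's code: the helper names `autocorrelation` and `ess_single_chain` are hypothetical, only one chain is handled, and Pyro's tensor implementation pools chains and may differ in detail. The sketch sums adjacent pairs of autocorrelations while they stay positive (the initial positive sequence rule) and divides the chain length by the resulting autocorrelation time.

```python
import random


def autocorrelation(chain):
    """Biased sample autocorrelation at lags 0..n-1 for one chain."""
    n = len(chain)
    mean = sum(chain) / n
    centered = [x - mean for x in chain]
    variance = sum(c * c for c in centered) / n
    if variance == 0:
        raise ValueError("chain is constant; autocorrelation is undefined")
    return [
        sum(centered[i] * centered[i + lag] for i in range(n - lag)) / (n * variance)
        for lag in range(n)
    ]


def ess_single_chain(chain):
    """ESS = n / tau, with the lag sum truncated by the initial positive sequence."""
    n = len(chain)
    rho = autocorrelation(chain)
    pair_sum = 0.0
    for lag in range(0, n - 1, 2):
        pair = rho[lag] + rho[lag + 1]
        if pair <= 0:
            break  # stop at the first non-positive pair
        pair_sum += pair
    # tau = 1 + 2 * sum_{k>=1} rho_k = -1 + 2 * sum_{k>=0} rho_k
    tau = -1.0 + 2.0 * pair_sum
    if tau <= 0:
        raise ValueError("degenerate autocorrelation estimate")
    return n / tau


rng = random.Random(0)
independent = [rng.gauss(0.0, 1.0) for _ in range(400)]  # roughly uncorrelated draws
sticky = [0.0] * 50 + [1.0] * 50  # chain stuck at one value, then another

print(f"independent draws: ESS ~ {ess_single_chain(independent):.0f} of {len(independent)}")
print(f"sticky chain:      ESS ~ {ess_single_chain(sticky):.1f} of {len(sticky)}")
```

The sticky chain has 100 draws but an effective sample size of only a few, which is the situation a low n_eff value is meant to flag.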
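split_gelman_rubin(input, chain_dim=0, sample_dim=1): Computes the split R-hat diagnostic. Splits each chain in half and treats the halves as separate chains, then computes the square root of the ratio of the pooled posterior variance estimate to the within-chain variance. Values close to 1.0 are consistent with convergence; values noticeably above 1.0 indicate that the chains (or the two halves of a chain) disagree.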
Both functions operate on tensors with explicit chain and sample dimensions, making them compatible with single-chain (split in half) and multi-chain analyses.
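The split R-hat computation described above can be sketched in pure Python on nested lists. This is a minimal illustration, not Pyro's tensor code: the helper name `split_rhat` is hypothetical, and the input is assumed to be a list of equal-length chains with at least four samples each.

```python
from math import sqrt
from statistics import mean, variance


def split_rhat(chains):
    """Split R-hat for a list of equal-length chains (lists of floats)."""
    half = len(chains[0]) // 2
    if half < 2:
        raise ValueError("each chain needs at least 4 samples")
    # Split every chain into two halves, dropping a trailing odd sample.
    halves = []
    for chain in chains:
        halves.append(chain[:half])
        halves.append(chain[half:2 * half])
    within = mean(variance(h) for h in halves)      # mean within-half variance
    between = variance([mean(h) for h in halves])   # variance of the half means
    if within == 0:
        raise ValueError("within-chain variance is zero")
    pooled = (half - 1) / half * within + between   # pooled variance estimate
    return sqrt(pooled / within)


# Two chains exploring the same region.
well_mixed = [
    [0.1, -0.4, 0.3, 0.0, -0.2, 0.5, -0.1, 0.2],
    [0.2, -0.3, 0.1, 0.4, -0.5, 0.0, 0.3, -0.2],
]
# A single chain whose second half has drifted away from its first half.
drifting = [[0.0, 1.0, 0.0, 1.0, 10.0, 11.0, 10.0, 11.0]]

print(f"well mixed: {split_rhat(well_mixed):.3f}")
print(f"drifting:   {split_rhat(drifting):.3f}")
```

The single drifting chain shows why splitting matters: with only one chain there is no between-chain comparison, but comparing its two halves still exposes the drift.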
Code Reference
Source Locations
- MCMC.summary(): pyro/infer/mcmc/api.py, lines 617-640.
- MCMC.diagnostics(): pyro/infer/mcmc/api.py, lines 641-650.
- effective_sample_size(): pyro/ops/stats.py.
- split_gelman_rubin(): pyro/ops/stats.py.
Signatures
# On MCMC class (pyro/infer/mcmc/api.py)
class MCMC:
    def summary(self, prob=0.9):
        """Print summary statistics for each parameter."""
        ...

    def diagnostics(self):
        """Return a dictionary of convergence diagnostics."""
        ...

# Supporting functions (pyro/ops/stats.py)
def effective_sample_size(input, chain_dim=0, sample_dim=1):
    """
    Compute effective sample size of MCMC samples.

    :param input: Tensor of shape (num_chains, num_samples, ...).
    :param chain_dim: Dimension indexing chains.
    :param sample_dim: Dimension indexing samples within each chain.
    :return: Tensor of ESS values.
    """
    ...

def split_gelman_rubin(input, chain_dim=0, sample_dim=1):
    """
    Compute split Gelman-Rubin R-hat diagnostic.

    :param input: Tensor of shape (num_chains, num_samples, ...).
    :param chain_dim: Dimension indexing chains.
    :param sample_dim: Dimension indexing samples within each chain.
    :return: Tensor of R-hat values.
    """
    ...
Import
# summary() and diagnostics() are accessed via the MCMC instance
from pyro.infer.mcmc import MCMC
# Supporting functions can be imported directly
from pyro.ops.stats import effective_sample_size, split_gelman_rubin
I/O Contract
MCMC.summary()
| Parameter | Type | Required | Description |
|---|---|---|---|
| prob | float | No | The probability mass for the credible interval. Defaults to 0.9 (i.e., 5th to 95th percentile). |

| Output | Type | Description |
|---|---|---|
| Printed table | None | Prints a formatted table to stdout with mean, std, median, credible interval bounds, n_eff, and r_hat for each parameter. Returns None. |
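Because summary() prints to stdout and returns None, the table has to be captured if it should end up in a log file or report. A minimal sketch using only the standard library follows; `capture_printed` and `fake_summary` are hypothetical names, and `fake_summary` is a stand-in so the example runs without a fitted model. It assumes the table is written with ordinary `print` calls to `sys.stdout`.

```python
import contextlib
import io


def capture_printed(fn, *args, **kwargs):
    """Run fn and return whatever it printed to stdout as a string."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        fn(*args, **kwargs)
    return buffer.getvalue()


# Stand-in for mcmc.summary; it only mimics "prints and returns None".
def fake_summary(prob=0.9):
    print(f"summary table at prob={prob}")


text = capture_printed(fake_summary, prob=0.95)
print(repr(text))
```

With a fitted sampler the call would look like `capture_printed(mcmc.summary, prob=0.9)`.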
MCMC.diagnostics()
| Output | Type | Description |
|---|---|---|
| Diagnostics dict | dict | Dictionary containing per-parameter ESS and R-hat values, and divergence count from the kernel. |
effective_sample_size()
| Parameter | Type | Required | Description |
|---|---|---|---|
| input | torch.Tensor | Yes | MCMC samples with shape (num_chains, num_samples, ...). |
| chain_dim | int | No | Dimension indexing chains. Defaults to 0. |
| sample_dim | int | No | Dimension indexing samples. Defaults to 1. |

| Output | Type | Description |
|---|---|---|
| ESS | torch.Tensor | Effective sample size for each parameter dimension. Shape matches the input with chain and sample dimensions removed. |
split_gelman_rubin()
| Parameter | Type | Required | Description |
|---|---|---|---|
| input | torch.Tensor | Yes | MCMC samples with shape (num_chains, num_samples, ...). |
| chain_dim | int | No | Dimension indexing chains. Defaults to 0. |
| sample_dim | int | No | Dimension indexing samples. Defaults to 1. |

| Output | Type | Description |
|---|---|---|
| R-hat | torch.Tensor | Split Gelman-Rubin R-hat values. Shape matches input with chain and sample dimensions removed. Values near 1.0 indicate convergence. |
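The output-shape rule shared by both supporting functions (the input shape with the chain and sample dimensions removed) can be written down as a small shape calculation. This is a sketch of the documented contract only: `diagnostic_output_shape` is a hypothetical helper that works on plain tuples and does not call Pyro.

```python
def diagnostic_output_shape(input_shape, chain_dim=0, sample_dim=1):
    """Shape of the ESS / R-hat result for samples of the given shape."""
    ndim = len(input_shape)
    if ndim < 2:
        raise ValueError("input needs at least a chain and a sample dimension")
    # Normalize negative dimension indices.
    chain_dim %= ndim
    sample_dim %= ndim
    if chain_dim == sample_dim:
        raise ValueError("chain_dim and sample_dim must differ")
    return tuple(
        size for i, size in enumerate(input_shape)
        if i not in (chain_dim, sample_dim)
    )


# Scalar parameter: 4 chains x 1000 samples -> one value (empty shape).
print(diagnostic_output_shape((4, 1000)))
# Vector parameter of length 3 -> one value per component.
print(diagnostic_output_shape((4, 1000, 3)))
# Samples stored sample-first: pass the dimensions explicitly.
print(diagnostic_output_shape((1000, 4, 3, 2), chain_dim=1, sample_dim=0))
```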
Usage Examples
Printing Summary After MCMC
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import NUTS, MCMC
def model(data):
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.HalfNormal(10))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

data = torch.randn(100) * 2 + 5
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500, num_chains=4)
mcmc.run(data)
# Print summary table with 90% credible intervals
mcmc.summary(prob=0.9)
# Example output (illustrative; exact values depend on the random data and seed):
# mean std median 5.0% 95.0% n_eff r_hat
# mu 5.01 0.20 5.01 4.68 5.34 800.0 1.00
# sigma 2.03 0.14 2.02 1.80 2.27 750.0 1.00
Accessing Diagnostics Programmatically
# Get diagnostics dictionary
diag = mcmc.diagnostics()
print(diag)
# Contains ESS, R-hat, and divergence counts
# Direct use of supporting functions
from pyro.ops.stats import effective_sample_size, split_gelman_rubin
samples = mcmc.get_samples(group_by_chain=True)
mu_samples = samples["mu"] # shape: (4, 1000)
ess = effective_sample_size(mu_samples.unsqueeze(-1))
rhat = split_gelman_rubin(mu_samples.unsqueeze(-1))
print(f"ESS for mu: {ess.item():.0f}")
print(f"R-hat for mu: {rhat.item():.3f}")
Related Pages