Implementation:Pyro ppl Pyro Stats
| Property | Value |
|---|---|
| Module | pyro.ops.stats |
| Source | pyro/ops/stats.py |
| Lines | 605 |
| Functions | gelman_rubin, split_gelman_rubin, autocorrelation, autocovariance, effective_sample_size, resample, quantile, weighed_quantile, pi, hpdi, waic, fit_generalized_pareto, crps_empirical, energy_score_empirical |
| Dependencies | torch, pyro.ops.tensor_utils |
Overview
This module provides statistical functions for MCMC convergence diagnostics, posterior summarization, model comparison, and scoring of predictions. All functions operate on PyTorch tensors and support batched computation.
The module includes:
- MCMC diagnostics: Gelman-Rubin R-hat, split R-hat, effective sample size, autocorrelation
- Posterior summaries: Quantiles, weighted quantiles, percentile intervals, highest posterior density intervals (HPDI)
- Model comparison: WAIC (Widely Applicable Information Criterion)
- Distribution fitting: Generalized Pareto distribution fitting (for Pareto-smoothed importance sampling)
- Scoring rules: CRPS (Continuous Ranked Probability Score), Energy Score
Code Reference
MCMC Diagnostics
- gelman_rubin(input, chain_dim, sample_dim): Computes the R-hat convergence diagnostic over multiple chains. Requires at least 2 chains and 2 samples.
- split_gelman_rubin(input, chain_dim, sample_dim): Computes split R-hat, which splits each chain in half. Requires at least 4 samples.
- autocorrelation(input, dim): FFT-based autocorrelation using the Stan algorithm. Uses next_fast_len for efficient FFT sizes.
- autocovariance(input, dim): Computes autocovariance as autocorrelation * variance.
- effective_sample_size(input, chain_dim, sample_dim): Estimates the effective number of independent samples using the initial positive sequence estimator from Geyer (2011) with monotone sequence correction.
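To make the R-hat and split R-hat definitions above concrete, here is a minimal plain-Python sketch of the textbook formula. It is an illustration only, not Pyro's implementation: the helper names rhat and split_rhat are hypothetical, chains are plain lists rather than tensors, and the use of unbiased variances is an assumption.

```python
import math
import statistics


def rhat(chains):
    """R-hat for a list of equal-length chains (each a list of floats)."""
    n_chains, n_samples = len(chains), len(chains[0])
    if n_chains < 2 or n_samples < 2:
        raise ValueError("need at least 2 chains and 2 samples per chain")
    if any(len(c) != n_samples for c in chains):
        raise ValueError("all chains must have the same length")
    chain_means = [statistics.fmean(c) for c in chains]
    # Between-chain variance of the chain means (unbiased; an assumption).
    var_between = statistics.variance(chain_means)
    # Average within-chain variance (unbiased per chain; an assumption).
    var_within = statistics.fmean([statistics.variance(c) for c in chains])
    # Pooled estimate of the posterior variance.
    var_pooled = var_within * (n_samples - 1) / n_samples + var_between
    return math.sqrt(var_pooled / var_within)


def split_rhat(chains):
    """Split each chain into two halves, then compute R-hat on the halves."""
    half = len(chains[0]) // 2
    if half < 2:
        raise ValueError("need at least 4 samples per chain")
    halves = []
    for c in chains:
        halves.append(c[:half])
        halves.append(c[half:2 * half])  # drops the last sample if the length is odd
    return rhat(halves)


# Two identical chains that both drift upward.
trending = [[float(i) for i in range(8)] for _ in range(2)]
print(rhat(trending), split_rhat(trending))
```

Because the two drifting chains agree with each other, plain R-hat stays below 1 here, while split R-hat compares the first and second halves and comes out well above 1. This is the kind of within-chain non-stationarity that the split variant is meant to expose.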
Posterior Summaries
- quantile(input, probs, dim): Computes quantiles with linear interpolation between adjacent sorted values.
- weighed_quantile(input, probs, log_weights, dim): Computes quantiles of weighted samples using cumulative weight interpolation.
- pi(input, prob, dim): Computes equal-tailed percentile intervals.
- hpdi(input, prob, dim): Computes the highest posterior density interval -- the narrowest interval containing the given probability mass.
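The difference between an equal-tailed interval and an HPDI is easiest to see on a small skewed sample. The following plain-Python sketch illustrates the ideas described above under stated assumptions: the function names are hypothetical stand-ins, inputs are one-dimensional lists, and the window size int(prob * n) for the HPDI search is an assumption about how the probability mass is discretized.

```python
import math


def quantile(samples, prob):
    """Quantile by linear interpolation between adjacent sorted values."""
    xs = sorted(samples)
    pos = prob * (len(xs) - 1)  # fractional index into the sorted list
    lo, hi = math.floor(pos), math.ceil(pos)
    frac = pos - lo
    return xs[lo] * (1 - frac) + xs[hi] * frac


def percentile_interval(samples, prob):
    """Equal-tailed interval: cut (1 - prob) / 2 from each tail."""
    tail = (1 - prob) / 2
    return quantile(samples, tail), quantile(samples, 1 - tail)


def hpdi(samples, prob):
    """Narrowest window of sorted samples spanning int(prob * n) steps."""
    xs = sorted(samples)
    n = len(xs)
    k = int(prob * n)  # window size; an assumption for this sketch
    if not 0 < k < n:
        raise ValueError("prob * n must leave at least one candidate window")
    # Slide the window over the sorted samples and keep the narrowest one.
    start = min(range(n - k), key=lambda i: xs[i + k] - xs[i])
    return xs[start], xs[start + k]


skewed = [0.0, 1.0, 1.1, 1.2, 1.3, 10.0]
print(percentile_interval(skewed, 0.5), hpdi(skewed, 0.5))
```

On skewed data the HPDI concentrates on the dense cluster of values, whereas the equal-tailed interval is fixed by the tail quantiles regardless of where the mass sits.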
Model Comparison
- waic(input, log_weights, pointwise, dim): Computes WAIC and the effective number of parameters following Vehtari & Gelman.
- fit_generalized_pareto(X): Fits the Generalized Pareto Distribution using the Zhang & Stephens method, useful for Pareto-smoothed importance sampling.
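As a rough illustration of what WAIC measures, the sketch below computes it from a matrix of pointwise log-likelihoods in plain Python. It assumes equally weighted posterior draws (the log_weights argument is not modelled), uses the unbiased variance for the effective number of parameters, and sums over observations; all of these are assumptions of this sketch rather than statements about Pyro's code.

```python
import math
import statistics


def waic(log_lik):
    """WAIC from log_lik[s][i]: log-likelihood of observation i under draw s.

    Returns (waic, p_waic), both summed over observations.
    """
    n_draws = len(log_lik)
    if n_draws < 2:
        raise ValueError("need at least 2 posterior draws")
    lpd_total, p_total = 0.0, 0.0
    for i in range(len(log_lik[0])):
        column = [row[i] for row in log_lik]
        # Log pointwise predictive density: log of the mean likelihood,
        # computed with the log-sum-exp trick for numerical stability.
        peak = max(column)
        lpd_total += peak + math.log(
            sum(math.exp(v - peak) for v in column) / n_draws
        )
        # Effective number of parameters: variance of the log-likelihood
        # across draws (unbiased variance is an assumption here).
        p_total += statistics.variance(column)
    return -2.0 * (lpd_total - p_total), p_total


# Two draws, two observations; the draws disagree only on the second one.
print(waic([[-1.0, -2.0], [-1.0, -2.5]]))
```

When every draw assigns the same log-likelihood to an observation, that observation contributes nothing to p_waic, so the penalty grows only where the posterior is uncertain about the fit.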
Scoring Rules
- crps_empirical(pred, truth): Computes the empirical Continuous Ranked Probability Score in O(n log n) time using a sorted-difference formulation.
- energy_score_empirical(pred, truth, pred_batch_size, cdist): Computes the multivariate Energy Score (generalizes CRPS to multiple dimensions) in O(n^2) time.
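The sorted-difference trick mentioned above can be checked against the direct definition CRPS = E|X - y| - 0.5 E|X - X'|. The plain-Python sketch below does this for a single scalar observation; the function names are hypothetical, and the pairwise term is averaged over all n^2 ordered pairs (including each sample paired with itself), which is an assumption of this sketch.

```python
def crps_naive(pred, truth):
    """Direct O(n^2) evaluation of E|X - y| - 0.5 * E|X - X'|."""
    n = len(pred)
    term1 = sum(abs(x - truth) for x in pred) / n
    term2 = sum(abs(a - b) for a in pred for b in pred) / (2 * n * n)
    return term1 - term2


def crps_sorted(pred, truth):
    """Same quantity in O(n log n) using gaps between sorted samples."""
    xs = sorted(pred)
    n = len(xs)
    term1 = sum(abs(x - truth) for x in xs) / n
    # The gap between the k-th and (k+1)-th sorted samples is crossed by
    # k * (n - k) unordered pairs, so it contributes that many times.
    term2 = sum((xs[k] - xs[k - 1]) * k * (n - k) for k in range(1, n)) / (n * n)
    return term1 - term2


pred = [3.0, 0.0, 1.0]
print(crps_naive(pred, 2.0), crps_sorted(pred, 2.0))
```

Both functions return the same value; the sorted version avoids materializing all pairs, which matters once the number of predictive samples is large.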
I/O Contract
| Function | Input | Output |
|---|---|---|
| gelman_rubin | input: Tensor, chain_dim: int, sample_dim: int | Tensor (R-hat values) |
| effective_sample_size | input: Tensor, chain_dim: int, sample_dim: int | Tensor (ESS values) |
| autocorrelation | input: Tensor, dim: int | Tensor (same shape, autocorrelation at each lag) |
| quantile | input: Tensor, probs: list/Tensor, dim: int | Tensor (quantile values) |
| hpdi | input: Tensor, prob: float, dim: int | Tensor (lower and upper bounds) |
| waic | input: Tensor (log likelihoods) | Tuple of (waic, p_waic) |
| crps_empirical | pred: Tensor(N, ...), truth: Tensor(...) | Tensor(...) |
| energy_score_empirical | pred: Tensor(..., N, D), truth: Tensor(..., D) | Tensor |
| fit_generalized_pareto | X: Tensor(N,) | Tuple (k: float, sigma: float) |
Usage Examples
```python
import torch
from pyro.ops.stats import (
    effective_sample_size,
    gelman_rubin,
    hpdi,
    quantile,
    crps_empirical,
)

# MCMC diagnostics with 4 chains, 1000 samples each
samples = torch.randn(1000, 4)  # sample_dim=0, chain_dim=1
rhat = gelman_rubin(samples, chain_dim=1, sample_dim=0)
ess = effective_sample_size(samples, chain_dim=1, sample_dim=0)
print(f"R-hat: {rhat.item():.3f}, ESS: {ess.item():.1f}")

# Posterior summary
samples_flat = torch.randn(5000)
q = quantile(samples_flat, [0.025, 0.5, 0.975])
interval = hpdi(samples_flat, prob=0.95)

# Scoring predictions
pred = torch.randn(100, 50)  # 100 samples, 50 data points
truth = torch.randn(50)
score = crps_empirical(pred, truth)
```
Related Pages
- Pyro_ppl_Pyro_TensorUtils -- Provides next_fast_len used by autocorrelation
- Pyro_ppl_Pyro_StreamingStats -- Online statistics computation
- Pyro_ppl_Pyro_WelfordCovariance -- Online covariance estimation