Implementation:Pyro ppl Pyro Stats


| Property | Value |
|---|---|
| Module | pyro.ops.stats |
| Source | pyro/ops/stats.py |
| Lines | 605 |
| Functions | gelman_rubin, split_gelman_rubin, autocorrelation, autocovariance, effective_sample_size, resample, quantile, weighed_quantile, pi, hpdi, waic, fit_generalized_pareto, crps_empirical, energy_score_empirical |
| Dependencies | torch, pyro.ops.tensor_utils |

Overview

This module provides statistical functions for MCMC convergence diagnostics, posterior summarization, model comparison, and the scoring of predictions. All functions operate on PyTorch tensors and support batched computation.

The module includes:

  • MCMC diagnostics: Gelman-Rubin R-hat, split R-hat, effective sample size, autocorrelation
  • Posterior summaries: Quantiles, weighted quantiles, percentile intervals, highest posterior density intervals (HPDI)
  • Model comparison: WAIC (Widely Applicable Information Criterion)
  • Distribution fitting: Generalized Pareto distribution fitting (for Pareto-smoothed importance sampling)
  • Scoring rules: CRPS (Continuous Ranked Probability Score), Energy Score

Code Reference

MCMC Diagnostics

  • gelman_rubin(input, chain_dim, sample_dim): Computes R-hat convergence diagnostic over multiple chains. Requires at least 2 chains and 2 samples.
  • split_gelman_rubin(input, chain_dim, sample_dim): Computes split R-hat, which splits each chain in half. Requires at least 4 samples.
  • autocorrelation(input, dim): FFT-based autocorrelation using the Stan algorithm. Uses next_fast_len for efficient FFT sizes.
  • autocovariance(input, dim): Computes autocovariance as autocorrelation * variance.
  • effective_sample_size(input, chain_dim, sample_dim): Estimates the effective number of independent samples using the initial positive sequence estimator from Geyer (2011) with a monotone sequence correction.
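
The R-hat and split R-hat diagnostics listed above can be illustrated in plain Python. This is a minimal sketch, assuming the classic Gelman-Rubin formula (mean within-chain variance W and variance of the chain means B/n). It is not Pyro's tensor implementation, and the names `rhat_sketch` and `split_chains` are hypothetical helpers introduced only for illustration.

```python
import math
from statistics import mean, variance


def rhat_sketch(chains):
    """Classic R-hat for a list of equal-length chains (lists of numbers)."""
    n = len(chains[0])
    if len(chains) < 2 or n < 2:
        raise ValueError("need at least 2 chains and 2 samples per chain")
    if any(len(c) != n for c in chains):
        raise ValueError("all chains must have the same length")
    # W: average of the unbiased within-chain variances.
    within = mean([variance(c) for c in chains])
    # B / n: unbiased variance of the per-chain means.
    between_over_n = variance([mean(c) for c in chains])
    # Pooled estimate of the posterior variance.
    var_plus = (n - 1) / n * within + between_over_n
    return math.sqrt(var_plus / within)


def split_chains(chains):
    """Split each chain into first and second halves (odd leftovers are dropped)."""
    half = len(chains[0]) // 2
    return [part for c in chains for part in (c[:half], c[half:2 * half])]


# Two chains that agree with each other but both drift upward.
trending = [[0, 1, 4, 5], [0, 1, 4, 5]]
plain = rhat_sketch(trending)
split = rhat_sketch(split_chains(trending))
print(f"R-hat: {plain:.3f}, split R-hat: {split:.3f}")
```

Values near 1 suggest the chains agree. Splitting each chain in half lets the same statistic flag a trend inside a chain, which the unsplit version cannot see when all chains drift together.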

Posterior Summaries

  • quantile(input, probs, dim): Computes quantiles with linear interpolation between adjacent sorted values.
  • weighed_quantile(input, probs, log_weights, dim): Computes quantiles of weighted samples using cumulative weight interpolation.
  • pi(input, prob, dim): Computes equal-tailed percentile intervals.
  • hpdi(input, prob, dim): Computes the highest posterior density interval -- the narrowest interval containing the given probability mass.
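
The difference between an equal-tailed interval and the narrowest interval is easiest to see on skewed samples. The following is a plain-Python sketch of the ideas described above, assuming quantile positions are computed as `p * (n - 1)` with linear interpolation and that the HPDI is found by sliding a fixed-size window over the sorted samples. The function names are hypothetical and this is not the module's tensor code.

```python
def quantile_sketch(samples, p):
    """Quantile with linear interpolation between adjacent sorted values."""
    xs = sorted(samples)
    pos = p * (len(xs) - 1)          # fractional index into the sorted data
    lo = int(pos)
    hi = min(lo + 1, len(xs) - 1)
    frac = pos - lo
    return xs[lo] * (1 - frac) + xs[hi] * frac


def pi_sketch(samples, prob):
    """Equal-tailed interval: cut (1 - prob) / 2 of the mass from each tail."""
    tail = (1 - prob) / 2
    return quantile_sketch(samples, tail), quantile_sketch(samples, 1 - tail)


def hpdi_sketch(samples, prob):
    """Narrowest window spanning int(prob * n) steps in the sorted samples."""
    xs = sorted(samples)
    n = len(xs)
    width = int(prob * n)
    if not 0 < width < n:
        raise ValueError("prob * n must fall strictly between 0 and n")
    # Try every window start and keep the one with the smallest span.
    start = min(range(n - width), key=lambda i: xs[i + width] - xs[i])
    return xs[start], xs[start + width]


skewed = [0, 1, 2, 3, 10, 20]
print("median:", quantile_sketch(skewed, 0.5))
print("equal-tailed 50%:", pi_sketch(skewed, 0.5))
print("narrowest 50%:", hpdi_sketch(skewed, 0.5))
```

On right-skewed data the narrowest interval hugs the dense region near the mode, while the equal-tailed interval is pulled toward the long tail.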

Model Comparison

  • waic(input, log_weights, pointwise, dim): Computes WAIC and effective number of parameters following Vehtari & Gelman.
  • fit_generalized_pareto(X): Fits the Generalized Pareto Distribution using the Zhang & Stephens method, useful for Pareto-smoothed importance sampling.
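
As a rough illustration of what WAIC measures, the sketch below computes it from a small matrix of log-likelihoods in plain Python. It assumes equally weighted posterior draws and uses the sample variance of the log-likelihood as the per-point penalty; the function name `waic_sketch` is hypothetical, and the module's handling of `log_weights` and `pointwise` is not reproduced here.

```python
import math
from statistics import variance


def waic_sketch(log_lik):
    """log_lik[s][i] is the log-likelihood of data point i under posterior draw s.

    Returns (waic, p_waic) summed over data points.
    """
    num_draws = len(log_lik)
    if num_draws < 2:
        raise ValueError("need at least 2 posterior draws")
    num_points = len(log_lik[0])
    total_lpd = 0.0
    total_p = 0.0
    for i in range(num_points):
        column = [log_lik[s][i] for s in range(num_draws)]
        # Log pointwise predictive density: log of the mean likelihood,
        # computed with the max subtracted for numerical stability.
        m = max(column)
        total_lpd += m + math.log(sum(math.exp(v - m) for v in column) / num_draws)
        # Effective number of parameters: variance of the log-likelihood.
        total_p += variance(column)
    waic = -2.0 * (total_lpd - total_p)
    return waic, total_p


# Two posterior draws, two data points.
waic, p_waic = waic_sketch([[-1.0, -2.0], [-1.5, -2.5]])
print(f"WAIC: {waic:.3f}, p_waic: {p_waic:.3f}")
```

Lower WAIC indicates better estimated out-of-sample fit; the penalty grows when the log-likelihood of a data point varies strongly across posterior draws.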

Scoring Rules

  • crps_empirical(pred, truth): Computes the empirical Continuous Ranked Probability Score in O(n log n) time using a sorted-difference formulation.
  • energy_score_empirical(pred, truth, pred_batch_size, cdist): Computes the multivariate Energy Score (generalizes CRPS to multiple dimensions) in O(n^2) time.
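
The sorted-difference idea behind the O(n log n) CRPS can be checked against the direct O(n^2) definition, CRPS = E|X - y| - 0.5 E|X - X'|. The sketch below does this for a single scalar observation in plain Python. The function names are hypothetical, and batching over data points (which the module handles with tensors) is left out.

```python
def crps_pairwise(pred, truth):
    """Direct O(n^2) definition using all pairs of predictive samples."""
    n = len(pred)
    abs_error = sum(abs(x - truth) for x in pred) / n
    pair_sum = sum(abs(a - b) for a in pred for b in pred)
    return abs_error - pair_sum / (2 * n * n)


def crps_sorted(pred, truth):
    """Same quantity in O(n log n) via gaps between sorted samples."""
    n = len(pred)
    xs = sorted(pred)
    abs_error = sum(abs(x - truth) for x in xs) / n
    # The gap between sorted samples k-1 and k is crossed by k * (n - k)
    # unordered pairs, so the full pairwise sum is twice this weighted sum.
    spread = sum(k * (n - k) * (xs[k] - xs[k - 1]) for k in range(1, n))
    return abs_error - spread / (n * n)


samples = [3.0, -1.0, 0.5, 2.0]
print(crps_pairwise(samples, 0.25), crps_sorted(samples, 0.25))
```

Both functions return the same value up to floating-point rounding; only the cost of the spread term differs. The Energy Score replaces the absolute differences with vector norms, which is why the sorting shortcut no longer applies in multiple dimensions.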

I/O Contract

| Function | Input | Output |
|---|---|---|
| gelman_rubin | input: Tensor, chain_dim: int, sample_dim: int | Tensor (R-hat values) |
| effective_sample_size | input: Tensor, chain_dim: int, sample_dim: int | Tensor (ESS values) |
| autocorrelation | input: Tensor, dim: int | Tensor (same shape, autocorrelation at each lag) |
| quantile | input: Tensor, probs: list/Tensor, dim: int | Tensor (quantile values) |
| hpdi | input: Tensor, prob: float, dim: int | Tensor (lower and upper bounds) |
| waic | input: Tensor (log likelihoods) | Tuple of (waic, p_waic) |
| crps_empirical | pred: Tensor(N, ...), truth: Tensor(...) | Tensor(...) |
| energy_score_empirical | pred: Tensor(..., N, D), truth: Tensor(..., D) | Tensor |
| fit_generalized_pareto | X: Tensor(N,) | Tuple (k: float, sigma: float) |

Usage Examples

import torch
from pyro.ops.stats import (
    effective_sample_size,
    gelman_rubin,
    hpdi,
    quantile,
    crps_empirical,
)

# MCMC diagnostics with 4 chains, 1000 samples each
samples = torch.randn(1000, 4)  # sample_dim=0, chain_dim=1
rhat = gelman_rubin(samples, chain_dim=1, sample_dim=0)
ess = effective_sample_size(samples, chain_dim=1, sample_dim=0)
print(f"R-hat: {rhat.item():.3f}, ESS: {ess.item():.1f}")

# Posterior summary
samples_flat = torch.randn(5000)
q = quantile(samples_flat, [0.025, 0.5, 0.975])  # shape (3,): one value per probability
interval = hpdi(samples_flat, prob=0.95)  # shape (2,): lower and upper bounds

# Scoring predictions
pred = torch.randn(100, 50)  # 100 samples, 50 data points
truth = torch.randn(50)
score = crps_empirical(pred, truth)  # shape (50,): one CRPS value per data point
