Implementation:Pyro ppl Pyro WelfordCovariance
| Property | Value |
|---|---|
| Module | pyro.ops.welford |
| Source | pyro/ops/welford.py |
| Lines | 102 |
| Classes | WelfordCovariance, WelfordArrowheadCovariance |
| Dependencies | torch |
Overview
This module implements Welford's online algorithm for estimating (co)variance from a stream of samples. The algorithm is numerically stable and processes one sample at a time: it stores only the running statistics (mean, accumulated deviation products, and sample count), so memory use does not grow with the number of samples. It is primarily used in Pyro's HMC/NUTS samplers for adapting the mass matrix during warmup.
Two variants are provided:
- WelfordCovariance: Estimates either diagonal or full covariance matrices.
- WelfordArrowheadCovariance: Estimates covariance in arrowhead form (dense top rows plus a diagonal for the remaining entries), which is more memory-efficient for the arrowhead mass matrix structure.
Both classes support Stan-style regularization to improve conditioning during early adaptation.
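The update rule can be illustrated with a minimal scalar sketch in plain Python. The class name RunningVariance is hypothetical and is not part of pyro.ops.welford; it only mirrors the pre-delta/post-delta idea on single numbers and checks the result against a two-pass computation.

```python
import statistics


class RunningVariance:
    """Hypothetical scalar illustration of Welford's update (not Pyro's class)."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, x):
        self.n += 1
        delta_pre = x - self.mean        # deviation from the old mean
        self.mean += delta_pre / self.n  # move the mean toward x
        delta_post = x - self.mean       # deviation from the new mean
        self.m2 += delta_pre * delta_post

    def variance(self):
        if self.n < 2:
            raise ValueError("need at least two samples")
        return self.m2 / (self.n - 1)    # unbiased sample variance


data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
rv = RunningVariance()
for x in data:
    rv.update(x)

# The streaming result agrees with the two-pass sample variance.
assert abs(rv.variance() - statistics.variance(data)) < 1e-12
```

Only three numbers are kept between updates, regardless of how many samples have been seen.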
Code Reference
Class: WelfordCovariance
Constructor:
- diagonal (bool, default True): Whether to estimate only the diagonal variance or the full covariance.
Methods:
- reset(): Resets the internal state (_mean, _m2, n_samples).
- update(sample): Incorporates a new sample using Welford's update rule:
  - Computes the pre-delta and post-delta, i.e. the sample's deviation from the running mean before and after the mean is updated.
  - Updates _m2 with the element-wise product of the two deltas (diagonal) or their outer product (full).
- get_covariance(regularize=True): Returns the estimated covariance, _m2 / (n - 1). With regularization (from Stan), applies shrinkage: cov = (n/(n+5)) * cov + 1e-3 * (5/(n+5)) * I.
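To make the shrinkage formula concrete, here is a small sketch in plain Python. The helper names regularize_variance and regularize_covariance, and their target and weight parameters, are hypothetical and only transcribe the formula quoted above; they are not functions exposed by the module.

```python
def regularize_variance(var, n, target=1e-3, weight=5.0):
    """Shrink a variance estimate from n samples toward `target` (hypothetical helper)."""
    return (n / (n + weight)) * var + target * (weight / (n + weight))


def regularize_covariance(cov, n, target=1e-3, weight=5.0):
    """Scale every entry by n/(n+5), then add the shrinkage term on the diagonal only."""
    scale = n / (n + weight)
    shrinkage = target * (weight / (n + weight))
    return [
        [scale * value + (shrinkage if i == j else 0.0) for j, value in enumerate(row)]
        for i, row in enumerate(cov)
    ]


# With few samples the estimate is pulled strongly toward the small target ...
early = regularize_variance(4.0, n=5)       # 0.5 * 4.0 + 1e-3 * 0.5
# ... while with many samples it is left almost unchanged.
late = regularize_variance(4.0, n=10_000)

full = regularize_covariance([[1.0, 0.5], [0.5, 2.0]], n=5)
```

The weight given to the raw estimate, n/(n+5), approaches 1 as n grows, so the regularization mainly matters early in adaptation.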
Class: WelfordArrowheadCovariance
Constructor:
- head_size (int, default 0): Number of dense rows at the top of the arrowhead.
Methods:
- reset(): Resets _mean, _m2_top, _m2_bottom_diag, n_samples.
- update(sample): Updates arrowhead statistics. The top part uses outer products of the head portion; the bottom part uses element-wise products of the tail portion.
- get_covariance(regularize=True): Returns a tuple (top, bottom_diag) where top = cov[:head_size] and bottom_diag = diag(cov)[head_size:]. Applies Stan-style regularization.
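The arrowhead bookkeeping can be sketched with plain Python lists. ArrowheadSketch is a hypothetical illustration, not Pyro's class: it assumes the head rows are accumulated against the full sample vector (consistent with the (H, D) shape of top) and it returns the unregularized estimate only.

```python
class ArrowheadSketch:
    """Hypothetical list-based illustration of arrowhead statistics (not Pyro's class)."""

    def __init__(self, head_size):
        self.head_size = head_size
        self.n = 0
        self.mean = None
        self.m2_top = None          # head_size rows, each of length D
        self.m2_bottom_diag = None  # length D - head_size

    def update(self, sample):
        d, h = len(sample), self.head_size
        if self.mean is None:
            self.mean = [0.0] * d
            self.m2_top = [[0.0] * d for _ in range(h)]
            self.m2_bottom_diag = [0.0] * (d - h)
        self.n += 1
        delta_pre = [x - m for x, m in zip(sample, self.mean)]
        self.mean = [m + dp / self.n for m, dp in zip(self.mean, delta_pre)]
        delta_post = [x - m for x, m in zip(sample, self.mean)]
        # Dense head: outer product of the head entries with the full vector.
        for i in range(h):
            for j in range(d):
                self.m2_top[i][j] += delta_post[i] * delta_pre[j]
        # Diagonal tail: element-wise product of the tail entries only.
        for k in range(h, d):
            self.m2_bottom_diag[k - h] += delta_post[k] * delta_pre[k]

    def get_covariance(self):
        # Unregularized estimate, i.e. the regularize=False case.
        denom = self.n - 1
        top = [[v / denom for v in row] for row in self.m2_top]
        bottom_diag = [v / denom for v in self.m2_bottom_diag]
        return top, bottom_diag


sketch = ArrowheadSketch(head_size=1)
for s in ([1.0, 2.0, 0.0], [3.0, 6.0, 1.0], [5.0, 10.0, 5.0]):
    sketch.update(s)
top, bottom_diag = sketch.get_covariance()
```

Storage is H * D + (D - H) numbers instead of D * D for a full matrix, which is where the memory saving comes from when H is small.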
I/O Contract
| Class | Method | Input | Output |
|---|---|---|---|
| WelfordCovariance | update | sample: Tensor(D,) | (none, mutates state) |
| WelfordCovariance | get_covariance | regularize: bool | Tensor(D,) if diagonal, Tensor(D, D) if full |
| WelfordArrowheadCovariance | update | sample: Tensor(D,) | (none, mutates state) |
| WelfordArrowheadCovariance | get_covariance | regularize: bool | Tuple (top: Tensor(H, D), bottom_diag: Tensor(D-H,)) |
Usage Examples
```python
import torch
from pyro.ops.welford import WelfordCovariance, WelfordArrowheadCovariance

# Diagonal covariance estimation
welford = WelfordCovariance(diagonal=True)
for _ in range(1000):
    sample = torch.randn(5) * torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
    welford.update(sample)
var = welford.get_covariance()
print(var)  # approximately [1, 4, 9, 16, 25]

# Full covariance estimation
welford_full = WelfordCovariance(diagonal=False)
for _ in range(1000):
    welford_full.update(torch.randn(3))
cov = welford_full.get_covariance()
print(cov.shape)  # torch.Size([3, 3])

# Arrowhead covariance (head_size=2, total_size=10)
welford_ah = WelfordArrowheadCovariance(head_size=2)
for _ in range(1000):
    welford_ah.update(torch.randn(10))
top, bottom_diag = welford_ah.get_covariance()
print(top.shape)  # torch.Size([2, 10])
print(bottom_diag.shape)  # torch.Size([8])
```
Related Pages
- Pyro_ppl_Pyro_StreamingStats -- Uses WelfordCovariance for streaming variance
- Pyro_ppl_Pyro_Arrowhead -- Arrowhead matrix operations for the arrowhead mass matrix
- Pyro_ppl_Pyro_DualAveraging -- Step size adaptation used alongside mass matrix adaptation