
Implementation:Pyro ppl Pyro WelfordCovariance

From Leeroopedia


| Property | Value |
|---|---|
| Module | pyro.ops.welford |
| Source | pyro/ops/welford.py |
| Lines | 102 |
| Classes | WelfordCovariance, WelfordArrowheadCovariance |
| Dependencies | torch |

Overview

This module implements Welford's online algorithm for estimating (co)variance from a stream of samples. The algorithm is numerically stable and never stores past samples: each update touches only the running count, mean, and second-moment accumulator, so memory stays constant in the number of samples (O(D) for a diagonal estimate, O(D²) for a full one, where D is the sample dimension). It is primarily used in Pyro's HMC/NUTS samplers to adapt the mass matrix during warmup.
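To see the recurrence in isolation, here is a minimal pure-Python sketch of Welford's update for a single coordinate, written without torch so the arithmetic is easy to follow. The names welford_variance and naive_variance are illustrative, not part of pyro.ops.welford. The data are offset by 1e9 so that the textbook sum-of-squares formula can lose most or all of its precision to cancellation, while Welford's recurrence recovers the true sample variance of 30.

```python
def welford_variance(samples):
    """Unbiased variance of a stream of floats via Welford's recurrence."""
    n = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the mean
    for x in samples:
        n += 1
        delta_pre = x - mean          # deviation from the old mean
        mean += delta_pre / n         # update the running mean
        delta_post = x - mean         # deviation from the new mean
        m2 += delta_pre * delta_post  # accumulate the second moment
    if n < 2:
        raise ValueError("need at least two samples")
    return m2 / (n - 1)


def naive_variance(samples):
    """Textbook one-pass formula (sum of squares minus squared sum / n)."""
    xs = list(samples)
    n = len(xs)
    s = sum(xs)
    sq = sum(x * x for x in xs)
    return (sq - s * s / n) / (n - 1)


shifted = [1e9 + 4, 1e9 + 7, 1e9 + 13, 1e9 + 16]  # true sample variance is 30
print(welford_variance(shifted))
print(naive_variance(shifted))  # subject to catastrophic cancellation
```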

Two variants are provided:

  • WelfordCovariance: Estimates either diagonal or full covariance matrices.
  • WelfordArrowheadCovariance: Estimates covariance in arrowhead form (the first head_size rows stored densely, plus only the diagonal of the remaining bottom-right block), which needs far less memory than a full matrix when the mass matrix has this arrowhead structure.
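To make the memory comparison concrete, a small arithmetic sketch counting how many second-moment entries each layout keeps for a D-dimensional sample (the running mean adds D more in every case). stored_entries is a hypothetical helper, not part of Pyro, and D = 1000 with head_size = 10 are arbitrary illustrative values.

```python
def stored_entries(dim, head_size):
    """Count second-moment entries kept by each layout (illustrative only)."""
    full = dim * dim                                 # dense D x D matrix
    diagonal = dim                                   # one variance per coordinate
    arrowhead = head_size * dim + (dim - head_size)  # H x D top rows + tail diagonal
    return {"full": full, "diagonal": diagonal, "arrowhead": arrowhead}


print(stored_entries(dim=1000, head_size=10))
# {'full': 1000000, 'diagonal': 1000, 'arrowhead': 10990}
```

With head_size = 0 the arrowhead layout reduces to the diagonal one, and with head_size = D it stores the full matrix.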

Both classes support Stan-style regularization to improve conditioning during early adaptation.
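The shrinkage rule documented for get_covariance, cov = (n/(n+5)) * cov + 1e-3 * (5/(n+5)) * I, can be read as a weighted average of the empirical covariance and 1e-3 * I, with the prior's weight 5/(n+5) fading as samples accumulate. The sketch below simply tabulates the two coefficients; stan_regularization_terms is a hypothetical name used for illustration.

```python
def stan_regularization_terms(n):
    """Scale applied to the empirical covariance and shrinkage added to its diagonal."""
    scale = n / (n + 5.0)
    shrinkage = 1e-3 * (5.0 / (n + 5.0))
    return scale, shrinkage


for n in (2, 10, 100, 1000):
    scale, shrinkage = stan_regularization_terms(n)
    print(f"n={n:5d}  scale={scale:.4f}  shrinkage={shrinkage:.2e}")
```

Early in adaptation (small n) the estimate is pulled noticeably toward 1e-3 * I, which keeps it well conditioned; by n = 1000 the scale is above 0.99 and the shrinkage is negligible.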

Code Reference

Class: WelfordCovariance

Constructor:

  • diagonal (bool, default True): Whether to estimate only diagonal variance or full covariance.

Methods:

  • reset(): Resets internal state (_mean, _m2, n_samples).
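  • update(sample): Incorporates a new sample using Welford's update rule:
    • Computes the pre-delta (sample minus the old running mean), updates the mean, then computes the post-delta (sample minus the new mean).
    • Adds to _m2 the element-wise product of the two deltas (diagonal) or their outer product (full).
  • get_covariance(regularize=True): Returns the unbiased estimate _m2 / (n-1), so at least two samples are needed. With regularization (borrowed from Stan), it applies shrinkage: cov = (n/(n+5)) * cov + 1e-3 * (5/(n+5)) * I; for the diagonal estimate, the shrinkage term is added to every variance.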
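The update and regularization rules listed above can be checked with a dependency-free sketch that uses nested Python lists in place of torch tensors. WelfordCovarianceSketch is a hypothetical name; it follows the described rules but is not Pyro's class, which operates on tensors. In full mode this sketch accumulates the outer product of the post-delta with the pre-delta; because the post-delta equals (n-1)/n times the pre-delta, each increment is symmetric and the order of the factors does not change the result.

```python
class WelfordCovarianceSketch:
    """Illustrative list-based re-implementation of the rules described above."""

    def __init__(self, diagonal=True):
        self.diagonal = diagonal
        self.reset()

    def reset(self):
        self.n_samples = 0
        self._mean = None  # allocated lazily once the dimension is known
        self._m2 = None

    def update(self, sample):
        d = len(sample)
        if self._mean is None:
            self._mean = [0.0] * d
            self._m2 = [0.0] * d if self.diagonal else [[0.0] * d for _ in range(d)]
        self.n_samples += 1
        delta_pre = [x - m for x, m in zip(sample, self._mean)]
        self._mean = [m + dp / self.n_samples for m, dp in zip(self._mean, delta_pre)]
        delta_post = [x - m for x, m in zip(sample, self._mean)]
        if self.diagonal:
            # element-wise product of the two deltas
            self._m2 = [v + a * b for v, a, b in zip(self._m2, delta_pre, delta_post)]
        else:
            # outer product: post-delta (rows) times pre-delta (columns)
            for i in range(d):
                for j in range(d):
                    self._m2[i][j] += delta_post[i] * delta_pre[j]

    def get_covariance(self, regularize=True):
        n = self.n_samples
        if n < 2:
            raise ValueError("need at least two samples")
        scale = n / (n + 5.0) if regularize else 1.0
        shrink = 1e-3 * 5.0 / (n + 5.0) if regularize else 0.0
        if self.diagonal:
            return [scale * v / (n - 1) + shrink for v in self._m2]
        d = len(self._m2)
        return [
            [scale * self._m2[i][j] / (n - 1) + (shrink if i == j else 0.0) for j in range(d)]
            for i in range(d)
        ]


est = WelfordCovarianceSketch(diagonal=False)
for s in [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]]:
    est.update(s)
print(est.get_covariance(regularize=False))
```

For these four points the unregularized result is [[5/3, 10/3], [10/3, 20/3]], since the second coordinate is exactly twice the first.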

Class: WelfordArrowheadCovariance

Constructor:

  • head_size (int, default 0): Number of dense rows at the top of the arrowhead.

Methods:

  • reset(): Resets _mean, _m2_top, _m2_bottom_diag, n_samples.
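  • update(sample): Updates the arrowhead statistics with the same pre-/post-delta recurrence. The top block (_m2_top, shape head_size × D) accumulates outer products of the head portion of the delta with the full delta; the bottom diagonal (_m2_bottom_diag) accumulates element-wise products of the tail portion.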
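  • get_covariance(regularize=True): Returns a tuple (top, bottom_diag) where top = cov[:head_size] and bottom_diag = diag(cov)[head_size:]. Applies the same Stan-style regularization: every entry is scaled by n/(n+5), and the shrinkage term is added to the diagonal entries of both parts.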
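A companion sketch for the arrowhead layout, again in plain Python and under the same assumptions: arrowhead_covariance is a hypothetical function (not Pyro's API) that runs the shared pre-/post-delta recurrence, keeps only the head_size × D top block and the tail diagonal, and adds the shrinkage term to the diagonal entries of both parts. Its output should match the corresponding rows and diagonal of a full-covariance estimate on the same data.

```python
def arrowhead_covariance(samples, head_size, regularize=True):
    """Illustrative arrowhead estimator returning (top, bottom_diag)."""
    d = len(samples[0])
    n = 0
    mean = [0.0] * d
    m2_top = [[0.0] * d for _ in range(head_size)]  # head_size x d dense rows
    m2_bottom_diag = [0.0] * (d - head_size)        # diagonal of the tail block
    for sample in samples:
        n += 1
        delta_pre = [x - m for x, m in zip(sample, mean)]
        mean = [m + dp / n for m, dp in zip(mean, delta_pre)]
        delta_post = [x - m for x, m in zip(sample, mean)]
        # head portion of the post-delta against the full pre-delta
        for i in range(head_size):
            for j in range(d):
                m2_top[i][j] += delta_post[i] * delta_pre[j]
        # element-wise products for the tail coordinates only
        for k in range(d - head_size):
            m2_bottom_diag[k] += delta_post[head_size + k] * delta_pre[head_size + k]
    if n < 2:
        raise ValueError("need at least two samples")
    scale = n / (n + 5.0) if regularize else 1.0
    shrink = 1e-3 * 5.0 / (n + 5.0) if regularize else 0.0
    top = [[scale * v / (n - 1) for v in row] for row in m2_top]
    for i in range(head_size):
        top[i][i] += shrink  # diagonal entries of the dense head block
    bottom_diag = [scale * v / (n - 1) + shrink for v in m2_bottom_diag]
    return top, bottom_diag


samples = [[1.0, 2.0, 0.0], [2.0, 4.0, 1.0], [3.0, 6.0, 0.0], [4.0, 8.0, 1.0]]
top, bottom_diag = arrowhead_covariance(samples, head_size=1, regularize=False)
print(top)
print(bottom_diag)
```

With head_size=1 and D=3, top holds the first row of the 3 × 3 sample covariance and bottom_diag holds its last two diagonal entries.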

I/O Contract

| Class | Method | Input | Output |
|---|---|---|---|
| WelfordCovariance | update | sample: Tensor(D,) | None (mutates internal state) |
| WelfordCovariance | get_covariance | regularize: bool | Tensor(D,) if diagonal, Tensor(D, D) if full |
| WelfordArrowheadCovariance | update | sample: Tensor(D,) | None (mutates internal state) |
| WelfordArrowheadCovariance | get_covariance | regularize: bool | Tuple (top: Tensor(H, D), bottom_diag: Tensor(D-H,)), where H = head_size |

Usage Examples

```python
import torch
from pyro.ops.welford import WelfordCovariance, WelfordArrowheadCovariance

# Diagonal covariance estimation
welford = WelfordCovariance(diagonal=True)
for _ in range(1000):
    sample = torch.randn(5) * torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
    welford.update(sample)
var = welford.get_covariance()
print(var)  # approximately [1, 4, 9, 16, 25]

# Full covariance estimation
welford_full = WelfordCovariance(diagonal=False)
for _ in range(1000):
    welford_full.update(torch.randn(3))
cov = welford_full.get_covariance()
print(cov.shape)  # torch.Size([3, 3])

# Arrowhead covariance (head_size=2, sample dimension 10)
welford_ah = WelfordArrowheadCovariance(head_size=2)
for _ in range(1000):
    welford_ah.update(torch.randn(10))
top, bottom_diag = welford_ah.get_covariance()
print(top.shape)          # torch.Size([2, 10])
print(bottom_diag.shape)  # torch.Size([8])
```
