Implementation:Pyro ppl Pyro StreamingStats
| Property | Value |
|---|---|
| Module | pyro.ops.streaming |
| Source | pyro/ops/streaming.py |
| Lines | 278 |
| Classes | StreamingStats (ABC), CountStats, StatsOfDict, StackStats, CountMeanStats, CountMeanVarianceStats |
| Dependencies | torch, pyro.ops.welford |
Overview
This module provides a framework for computing streaming (online) statistics over samples, designed for use with MCMC and other iterative inference algorithms. The abstract base class StreamingStats defines three core operations:
- update(sample): Incrementally incorporate a new sample.
- merge(other): Combine two independent statistics (e.g., from different MCMC chains).
- get(): Retrieve the computed statistics.
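The merge operation enables parallel computation and aggregation across chains. For the count- and moment-based statistics, the update operation keeps memory constant regardless of the number of samples. StackStats is the exception: it stores every sample, so its memory grows linearly with the sample count.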
Code Reference
Class: StreamingStats (ABC)
Abstract base class defining the streaming statistics interface. All subclasses must implement update, merge, and get.
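The three-method contract can be sketched in plain Python. This is an illustration only, not the code in pyro/ops/streaming.py: the names SketchStats and SketchCount are hypothetical, and the choice to have merge return a new object rather than mutate in place is an assumption based on how the usage examples assign the result of merge.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class SketchStats(ABC):
    """Hypothetical stand-in for the three-method streaming interface."""

    @abstractmethod
    def update(self, sample: Any) -> None:
        """Fold one new sample into the running state."""

    @abstractmethod
    def merge(self, other: "SketchStats") -> "SketchStats":
        """Return a new statistic combining self and other."""

    @abstractmethod
    def get(self) -> Dict[str, Any]:
        """Return the current summary as a dictionary."""


class SketchCount(SketchStats):
    """Toy subclass that only counts samples (similar in spirit to CountStats)."""

    def __init__(self) -> None:
        self.count = 0

    def update(self, sample: Any) -> None:
        # The sample value is ignored; only its arrival is recorded.
        self.count += 1

    def merge(self, other: "SketchCount") -> "SketchCount":
        # Assumption: merging leaves both inputs untouched.
        result = SketchCount()
        result.count = self.count + other.count
        return result

    def get(self) -> Dict[str, Any]:
        return {"count": self.count}


# Two independent "chains" are updated separately and merged afterwards.
chain_a, chain_b = SketchCount(), SketchCount()
for sample in [0.1, 0.2, 0.3]:
    chain_a.update(sample)
for sample in [0.4, 0.5]:
    chain_b.update(sample)
merged_count = chain_a.merge(chain_b).get()
print(merged_count)
```

Because the base class declares all three methods abstract, a subclass that omits any of them cannot be instantiated, which is one way to enforce the rule stated above.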
Class: CountStats
Tracks only the number of samples. Returns {"count": int}.
Class: StatsOfDict
Computes independent statistics for each key in dictionary-valued samples. Accepts a mapping from key to statistic class and a default statistic class.
stats = StatsOfDict({"a": CountStats, "b": CountMeanStats})
stats.update({"a": tensor_a, "b": tensor_b})
summary = stats.get() # {"a": {"count": 1}, "b": {"count": 1, "mean": ...}}
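The per-key dispatch described above might look roughly like the following plain-Python sketch. ToyCount, ToyTotal, and ToyStatsOfDict are hypothetical names, the parameter names types and default are illustrative, and creating each per-key statistic lazily on first sight of the key is an assumption rather than a statement about Pyro's implementation.

```python
from typing import Any, Callable, Dict, Hashable


class ToyCount:
    """Hypothetical statistic that counts samples."""

    def __init__(self) -> None:
        self.count = 0

    def update(self, sample: Any) -> None:
        self.count += 1

    def get(self) -> Dict[str, Any]:
        return {"count": self.count}


class ToyTotal:
    """Hypothetical statistic that counts samples and sums their values."""

    def __init__(self) -> None:
        self.count = 0
        self.total = 0.0

    def update(self, sample: float) -> None:
        self.count += 1
        self.total += sample

    def get(self) -> Dict[str, Any]:
        return {"count": self.count, "total": self.total}


class ToyStatsOfDict:
    """Routes each dictionary entry to its own statistic object."""

    def __init__(
        self,
        types: Dict[Hashable, Callable[[], Any]],
        default: Callable[[], Any] = ToyCount,
    ) -> None:
        self.types = dict(types)
        self.default = default
        self.stats: Dict[Hashable, Any] = {}

    def update(self, sample: Dict[Hashable, Any]) -> None:
        for key, value in sample.items():
            if key not in self.stats:
                # Keys without an explicit entry fall back to the default.
                factory = self.types.get(key, self.default)
                self.stats[key] = factory()
            self.stats[key].update(value)

    def get(self) -> Dict[Hashable, Dict[str, Any]]:
        return {key: stat.get() for key, stat in self.stats.items()}


toy = ToyStatsOfDict({"b": ToyTotal})
toy.update({"a": 1.0, "b": 2.0})
toy.update({"a": 3.0, "b": 4.0})
toy_summary = toy.get()
print(toy_summary)
```

Key "a" has no explicit entry, so it is handled by the default statistic, while key "b" uses the class named in the mapping.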
Class: StackStats
Collects all samples into a single stacked tensor. Returns {"count": int, "samples": Tensor}.
Class: CountMeanStats
Tracks count and running mean using Welford-style incremental updates. Returns {"count": int, "mean": Tensor}.
Class: CountMeanVarianceStats
Tracks count, mean, and diagonal variance using WelfordCovariance(diagonal=True). Supports merging across chains using the parallel Welford algorithm. Returns {"count": int, "mean": Tensor, "variance": Tensor}.
I/O Contract
| Class | update Input | get Output |
|---|---|---|
| CountStats | Any sample | {"count": int} |
| StatsOfDict | Dict[key, sample] | Dict[key, stats_dict] |
| StackStats | Tensor | {"count": int, "samples": Tensor} |
| CountMeanStats | Tensor | {"count": int, "mean": Tensor} |
| CountMeanVarianceStats | Tensor (fixed shape) | {"count": int, "mean": Tensor, "variance": Tensor} |
Usage Examples
import torch
from pyro.ops.streaming import (
    CountMeanVarianceStats,
    StatsOfDict,
    CountMeanStats,
)
# Track mean and variance of MCMC samples
stats = CountMeanVarianceStats()
for _ in range(1000):
    sample = torch.randn(10)
    stats.update(sample)
result = stats.get()
print(f"Count: {result['count']}")
print(f"Mean shape: {result['mean'].shape}") # torch.Size([10])
print(f"Variance shape: {result['variance'].shape}") # torch.Size([10])
# Merge statistics from two chains
chain1_stats = CountMeanVarianceStats()
chain2_stats = CountMeanVarianceStats()
for _ in range(500):
    chain1_stats.update(torch.randn(5))
    chain2_stats.update(torch.randn(5) + 1.0)
merged = chain1_stats.merge(chain2_stats)
print(merged.get()["count"]) # 1000
# Track different statistics per key
stats = StatsOfDict({"loc": CountMeanStats, "scale": CountMeanVarianceStats})
# Feed several samples: a variance estimate needs at least two observations
for _ in range(100):
    stats.update({"loc": torch.randn(()), "scale": torch.rand(())})
summary = stats.get()
Related Pages
- Pyro_ppl_Pyro_WelfordCovariance -- Underlying online covariance algorithm
- Pyro_ppl_Pyro_Stats -- Batch statistical functions (non-streaming)