Implementation:Pyro ppl Pyro StreamingStats

From Leeroopedia


Property      Value
Module        pyro.ops.streaming
Source        pyro/ops/streaming.py
Lines         278
Classes       StreamingStats (ABC), CountStats, StatsOfDict, StackStats, CountMeanStats, CountMeanVarianceStats
Dependencies  torch, pyro.ops.welford

Overview

This module provides a framework for computing streaming (online) statistics over samples, designed for use with MCMC and other iterative inference algorithms. The abstract base class StreamingStats defines three core operations:

  • update(sample): Incrementally incorporate a new sample.
  • merge(other): Combine two independent statistics (e.g., from different MCMC chains).
  • get(): Retrieve the computed statistics.

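The merge operation lets statistics be computed in parallel and then aggregated across chains. For the moment-based classes (CountStats, CountMeanStats, CountMeanVarianceStats), update keeps memory constant regardless of the number of samples; StackStats is the exception, since it retains every sample.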

Code Reference

Class: StreamingStats (ABC)

Abstract base class defining the streaming statistics interface. All subclasses must implement update, merge, and get.

Class: CountStats

Tracks only the number of samples. Returns {"count": int}.
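
The interface and its simplest implementation, a sample counter, can be pictured with a minimal pure-Python sketch. The names StreamingStatsSketch and CountSketch are hypothetical stand-ins, not the Pyro classes, and the sketch assumes that merge returns a new object rather than mutating either operand.

```python
from abc import ABC, abstractmethod


class StreamingStatsSketch(ABC):
    """Hypothetical stand-in for the three-method streaming interface."""

    @abstractmethod
    def update(self, sample) -> None:
        """Fold one new sample into the running state."""

    @abstractmethod
    def merge(self, other):
        """Return a new statistic combining self and other."""

    @abstractmethod
    def get(self):
        """Return the current summary."""


class CountSketch(StreamingStatsSketch):
    """Simplest possible subclass: counts samples and ignores their values."""

    def __init__(self):
        self.count = 0

    def update(self, sample) -> None:
        self.count += 1

    def merge(self, other):
        # Assumption: build a fresh object so neither operand is modified.
        merged = CountSketch()
        merged.count = self.count + other.count
        return merged

    def get(self):
        return {"count": self.count}


# Two "chains" accumulate independently and are combined afterwards.
chain_a, chain_b = CountSketch(), CountSketch()
for x in [0.1, 0.2, 0.3]:
    chain_a.update(x)
for x in [0.4, 0.5]:
    chain_b.update(x)

merged_counts = chain_a.merge(chain_b).get()
print(merged_counts)  # {'count': 5}
```

Because the base class is abstract, a subclass that omits any of the three methods cannot be instantiated, which is what makes the interface enforceable.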

Class: StatsOfDict

Computes independent statistics for each key of dictionary-valued samples. Accepts a mapping from key to statistic class, plus a default statistic class used for keys that are not listed in the mapping.

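import torch
from pyro.ops.streaming import CountMeanStats, CountStats, StatsOfDict

stats = StatsOfDict({"a": CountStats, "b": CountMeanStats})
stats.update({"a": torch.tensor(0.0), "b": torch.tensor([1.0, 2.0])})
summary = stats.get()  # {"a": {"count": 1}, "b": {"count": 1, "mean": tensor([1., 2.])}}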
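
The per-key routing described above can be sketched in plain Python. This is an illustration under stated assumptions, not the Pyro source: PerKeyStats, CountOnly, and RunningSum are hypothetical names, and the fallback to a default statistic on first use of an unlisted key is assumed here via collections.defaultdict.

```python
from collections import defaultdict


class CountOnly:
    """Toy statistic: number of samples seen."""

    def __init__(self):
        self.count = 0

    def update(self, sample):
        self.count += 1

    def get(self):
        return {"count": self.count}


class RunningSum:
    """Toy statistic: count and sum of scalar samples."""

    def __init__(self):
        self.count = 0
        self.total = 0.0

    def update(self, sample):
        self.count += 1
        self.total += sample

    def get(self):
        return {"count": self.count, "total": self.total}


class PerKeyStats:
    """Routes each dictionary entry to its own statistic object."""

    def __init__(self, types, default=CountOnly):
        # Keys without an explicit entry fall back to `default` on first use.
        self.stats = defaultdict(default)
        self.stats.update({key: cls() for key, cls in types.items()})

    def update(self, sample):
        for key, value in sample.items():
            self.stats[key].update(value)

    def get(self):
        return {key: stat.get() for key, stat in self.stats.items()}


per_key = PerKeyStats({"b": RunningSum})
per_key.update({"a": 10.0, "b": 1.0})
per_key.update({"a": 20.0, "b": 3.0})
per_key_summary = per_key.get()
print(per_key_summary)
```

Key "a" has no explicit entry, so it is tracked by the default counter, while "b" gets the statistic that was requested for it.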

Class: StackStats

Collects all samples into a single stacked tensor. Returns {"count": int, "samples": Tensor}.

Class: CountMeanStats

Tracks count and running mean using Welford-style incremental updates. Returns {"count": int, "mean": Tensor}.
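
A minimal sketch of an incremental mean, using plain floats rather than tensors. RunningMeanSketch is a hypothetical name, and the count-weighted merge shown here is one standard way to combine two running means; it is not copied from the Pyro source.

```python
class RunningMeanSketch:
    """Running mean over scalar samples, with a count-weighted merge."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0

    def update(self, sample):
        # Incremental update: shift the mean toward the new sample by 1/count.
        self.count += 1
        self.mean += (sample - self.mean) / self.count

    def merge(self, other):
        merged = RunningMeanSketch()
        merged.count = self.count + other.count
        if merged.count > 0:
            # Each side contributes in proportion to its sample count.
            merged.mean = (
                self.count * self.mean + other.count * other.mean
            ) / merged.count
        return merged


left, right = RunningMeanSketch(), RunningMeanSketch()
for x in [1.0, 2.0, 3.0]:
    left.update(x)
for x in [10.0, 20.0]:
    right.update(x)

merged_mean = left.merge(right)
print(merged_mean.count, merged_mean.mean)
```

The merged result matches the mean of all five samples taken together, without either side ever storing its samples.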

Class: CountMeanVarianceStats

Tracks count, mean, and diagonal variance using WelfordCovariance(diagonal=True). Supports merging across chains using the parallel Welford algorithm. Returns {"count": int, "mean": Tensor, "variance": Tensor}.
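
The Welford update and its parallel merge can be illustrated on scalar samples. This is a sketch under stated assumptions rather than the Pyro implementation: MeanVarianceSketch is a hypothetical name, it works on floats instead of tensors, and it assumes the unbiased (n - 1) normaliser, so a variance is only defined once two samples have been seen.

```python
import statistics


class MeanVarianceSketch:
    """Welford-style running mean and variance over scalar samples."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, sample):
        self.count += 1
        delta_pre = sample - self.mean    # deviation from the old mean
        self.mean += delta_pre / self.count
        delta_post = sample - self.mean   # deviation from the new mean
        self.m2 += delta_pre * delta_post

    def merge(self, other):
        merged = MeanVarianceSketch()
        merged.count = self.count + other.count
        if merged.count == 0:
            return merged
        delta = other.mean - self.mean
        merged.mean = (
            self.count * self.mean + other.count * other.mean
        ) / merged.count
        # Parallel combination: add both sums of squares plus a correction
        # term for the gap between the two means.
        merged.m2 = (
            self.m2 + other.m2
            + delta ** 2 * self.count * other.count / merged.count
        )
        return merged

    def variance(self):
        if self.count < 2:
            raise ValueError("need at least two samples for a variance")
        return self.m2 / (self.count - 1)


chain1 = [2.0, 4.0, 4.0, 5.0]
chain2 = [7.0, 9.0, 10.0]
s1, s2 = MeanVarianceSketch(), MeanVarianceSketch()
for x in chain1:
    s1.update(x)
for x in chain2:
    s2.update(x)

pooled = s1.merge(s2)
reference = statistics.variance(chain1 + chain2)
print(pooled.variance(), reference)
```

The merged variance agrees with the variance computed over the concatenated samples, which is the property that makes per-chain accumulation followed by a merge safe.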

I/O Contract

Class                   update Input          get Output
CountStats              Any sample            {"count": int}
StatsOfDict             Dict[key, sample]     Dict[key, stats_dict]
StackStats              Tensor                {"count": int, "samples": Tensor}
CountMeanStats          Tensor                {"count": int, "mean": Tensor}
CountMeanVarianceStats  Tensor (fixed shape)  {"count": int, "mean": Tensor, "variance": Tensor}

Usage Examples

import torch
from pyro.ops.streaming import (
    CountMeanVarianceStats,
    StatsOfDict,
    CountMeanStats,
)

# Track mean and variance of MCMC samples
stats = CountMeanVarianceStats()
for _ in range(1000):
    sample = torch.randn(10)
    stats.update(sample)

result = stats.get()
print(f"Count: {result['count']}")
print(f"Mean shape: {result['mean'].shape}")       # torch.Size([10])
print(f"Variance shape: {result['variance'].shape}")  # torch.Size([10])

# Merge statistics from two chains
chain1_stats = CountMeanVarianceStats()
chain2_stats = CountMeanVarianceStats()
for _ in range(500):
    chain1_stats.update(torch.randn(5))
    chain2_stats.update(torch.randn(5) + 1.0)

merged = chain1_stats.merge(chain2_stats)
print(merged.get()["count"])  # 1000

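# Track different statistics per key
stats = StatsOfDict({"loc": CountMeanStats, "scale": CountMeanVarianceStats})
# Feed at least two samples: a sample variance is undefined for a single sample
for loc, scale in [(1.0, 0.5), (1.2, 0.7)]:
    stats.update({"loc": torch.tensor(loc), "scale": torch.tensor(scale)})
summary = stats.get()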
