

Implementation:Pyro ppl Pyro SubsampleMessenger

From Leeroopedia
Revision as of 16:25, 16 February 2026 by Admin (auto-imported from implementations/Pyro_ppl_Pyro_SubsampleMessenger.md)


| Attribute | Value |
|---|---|
| File | pyro/poutine/subsample_messenger.py |
| Module | pyro.poutine.subsample_messenger |
| Lines | 218 |
| Parent Class | IndepMessenger |
| Purpose | Extend IndepMessenger with data subsampling and automatic scaling |
| Architecture Role | Intermediate class between IndepMessenger and PlateMessenger; handles the subsampling logic |
| License | Apache-2.0 (Uber Technologies, Inc.) |

Overview

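SubsampleMessenger extends IndepMessenger with data subsampling. It draws a random subset of subsample_size indices from range(size) and multiplies the scale of each site inside the plate by size / subsample_size, so that the scaled minibatch log probability is an unbiased estimate of the full-data log probability, and its gradient an unbiased estimate of the full-data gradient.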
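To see why the factor size / subsample_size yields an unbiased estimate, the following stdlib-only sketch (it does not use Pyro; scaled_subsample_sum and expected_estimate are illustrative names) averages the scaled minibatch sum over every possible subset of a fixed size. Uniform sampling without replacement makes every subset equally likely, so this average is the exact expectation of the estimator, and it equals the full-data sum.

```python
from fractions import Fraction
from itertools import combinations


def scaled_subsample_sum(values, indices, size):
    """Scale a minibatch sum by size / subsample_size, as the plate does."""
    subsample_size = len(indices)
    return Fraction(size, subsample_size) * sum(values[i] for i in indices)


def expected_estimate(values, subsample_size):
    """Average the scaled estimate over every subset of the given size.

    Each subset is equally likely under uniform sampling without replacement,
    so this average is the exact expectation of the estimator.
    """
    size = len(values)
    subsets = list(combinations(range(size), subsample_size))
    total = sum(scaled_subsample_sum(values, s, size) for s in subsets)
    return total / len(subsets)


# Made-up stand-ins for per-datum log probabilities.
log_probs = [-1, -3, -2, -5, -4, -7]
full_sum = sum(log_probs)
print(full_sum, expected_estimate(log_probs, 2))  # both are -22
```

Exact fractions are used so the equality holds without floating-point tolerance.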

The module defines two classes:

  • _Subsample -- An internal Distribution that generates random subsample indices. It returns a random permutation of indices (truncated to subsample_size) and has zero log probability so that the plate can provide the unbiased estimate.
  • SubsampleMessenger -- Extends IndepMessenger with:
    • Subsample generation via the _Subsample distribution, whose sample message is processed through the handler stack via apply_stack.
    • Scale correction: msg["scale"] *= size / subsample_size.
    • Post-processing of param and subsample messages to index-select parameters along the plate dimension.
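The sampling rule of _Subsample can be sketched without PyTorch. This is a hypothetical stdlib mirror, not Pyro API: subsample_indices and subsample_log_prob are illustrative names, and random.shuffle stands in for torch.randperm.

```python
import random


def subsample_indices(size, subsample_size=None, rng=random):
    """Mimic the documented _Subsample.sample rule with the standard library.

    Returns every index when subsample_size is None or >= size; otherwise a
    random permutation of range(size) truncated to subsample_size.
    """
    if subsample_size is None or subsample_size >= size:
        return list(range(size))
    perm = list(range(size))
    rng.shuffle(perm)  # uniformly random permutation
    return perm[:subsample_size]  # distinct indices, drawn without replacement


def subsample_log_prob(indices):
    """The index draw contributes nothing to the log density."""
    return 0.0


rng = random.Random(0)
idx = subsample_indices(10, 4, rng)
print(idx, subsample_log_prob(idx))
```

Returning a zero log probability keeps the index draw itself out of the objective; the correction for subsampling comes entirely from the plate's scale factor.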

The post-processing logic handles parameter subsampling by calling value.index_select(dim, indices) and tracking the subsample indices on the unconstrained parameter via param._pyro_subsample.
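The dimension arithmetic of this post-processing step can be illustrated with plain nested lists. In this sketch, index_select and shape_of are hypothetical helpers that imitate torch.Tensor.index_select; the plate dimension is shifted left by event_dim in the same way as dim = self.dim - event_dim in _postprocess_message.

```python
def shape_of(value):
    """Shape of a rectangular nested list, e.g. (5, 3) for a 5x3 matrix."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        if not value:
            break
        value = value[0]
    return tuple(shape)


def index_select(value, dim, indices):
    """Pick `indices` along axis `dim` of a nested list (negative dims allowed)."""
    if dim < 0:
        dim += len(shape_of(value))
    if dim == 0:
        return [value[i] for i in indices]
    return [index_select(row, dim - 1, indices) for row in value]


# Hypothetical parameter: 5 plate entries, each a length-3 event (event_dim=1).
param = [[10 * r + c for c in range(3)] for r in range(5)]
plate_dim, event_dim = -1, 1
dim = plate_dim - event_dim  # -2: the plate axis sits left of the event axis
print(index_select(param, dim, [4, 1]))
```

Selecting along dim -2 keeps whole events intact and only drops plate entries, which is why event_dim must be known before a parameter can be subsampled.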

Code Reference

_Subsample Distribution

```python
import torch
from pyro.distributions.distribution import Distribution


class _Subsample(Distribution):
    """Internal distribution for random subsample generation (simplified excerpt)."""

    def __init__(self, size, subsample_size, use_cuda=None, device=None):
        self.size = size
        self.subsample_size = subsample_size
        self.device = device or torch.tensor(tuple()).device

    def sample(self, sample_shape=torch.Size()):
        # No subsampling requested: return the full index range.
        if self.subsample_size is None or self.subsample_size >= self.size:
            return torch.arange(self.size, device=self.device)
        # Random permutation truncated to subsample_size (no replacement).
        return torch.randperm(self.size, device=self.device)[: self.subsample_size].clone()

    def log_prob(self, x):
        # Zero, so that the plate's scaling provides the unbiased estimate.
        return torch.tensor(0.0, device=self.device)
```

SubsampleMessenger

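```python
from pyro.poutine.indep_messenger import CondIndepStackFrame, IndepMessenger


class SubsampleMessenger(IndepMessenger):
    """Simplified excerpt: IndepMessenger plus subsampling and scale correction."""

    def __init__(self, name, size=None, subsample_size=None,
                 subsample=None, dim=None, use_cuda=None, device=None):
        # _subsample (omitted from this excerpt) resolves the full size, the
        # subsample size and the index tensor, drawing indices from _Subsample
        # through apply_stack when none are given.
        full_size, self.subsample_size, subsample = self._subsample(
            name, size, subsample_size, subsample, use_cuda, device)
        super().__init__(name, full_size, dim, device)
        self._indices = subsample

    def _process_message(self, msg):
        # Record this plate on the site, with both subsample size and full size.
        frame = CondIndepStackFrame(
            name=self.name, dim=self.dim,
            size=self.subsample_size, counter=self.counter,
            full_size=self.size)
        msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]
        # Rescale so the minibatch log prob estimates the full-data log prob.
        msg["scale"] = msg["scale"] * self.size / self.subsample_size

    def _postprocess_message(self, msg):
        if msg["type"] in ("param", "subsample") and self.dim is not None:
            event_dim = msg["kwargs"].get("event_dim")
            if event_dim is not None:
                # Shift the plate dim left past the value's event dims.
                dim = self.dim - event_dim
                shape = msg["value"].shape
                # Only subsample values that actually span the plate dimension.
                if len(shape) >= -dim and shape[dim] != 1:
                    if self.subsample_size < self.size:
                        # Subsample parameters with known batch semantics.
                        # (The full version also records the indices on the
                        # unconstrained parameter via param._pyro_subsample.)
                        msg["value"] = msg["value"].index_select(dim, self._indices)
```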
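When plates are nested, each SubsampleMessenger prepends its own frame and multiplies the scale again, so a site ends up scaled by the product of the per-plate ratios. The stdlib sketch below mimics _process_message on a plain dict; Frame and process_message are hypothetical stand-ins, not Pyro classes, and in this sketch the inner plate is simply applied first.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Frame:
    """Hypothetical stand-in for CondIndepStackFrame (only the fields used here)."""
    name: str
    dim: int
    size: int       # subsample size
    full_size: int  # full data size


def process_message(msg, name, dim, size, subsample_size):
    """Mirror the documented _process_message: prepend a frame, then rescale."""
    frame = Frame(name, dim, subsample_size, size)
    msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]
    msg["scale"] = msg["scale"] * size / subsample_size
    return msg


msg = {"scale": 1.0, "cond_indep_stack": ()}
process_message(msg, "words", -1, 50, 10)    # inner plate: factor 5
process_message(msg, "docs", -2, 1000, 100)  # outer plate: factor 10
print(msg["scale"], [f.name for f in msg["cond_indep_stack"]])
```

The final scale of 50.0 is 5 times 10, and the outer plate's frame ends up first in the stack because each handler prepends.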

I/O Contract

| Parameter | Type | Description |
|---|---|---|
| name | str | Name of the plate context |
| size | Optional[int] | Total size of the data |
| subsample_size | Optional[int] | Number of indices to draw (if None or >= size, all indices are used) |
| subsample | Optional[torch.Tensor] | Explicit subsample indices |
| dim | Optional[int] | Batch dimension (negative integer) |
| device | Optional[str] | Device for index tensors |

| Message field | Effect |
|---|---|
| msg["cond_indep_stack"] | Prepends a CondIndepStackFrame carrying subsample_size (as size) and full_size |
| msg["scale"] | Multiplied by size / subsample_size for unbiased estimation |
| msg["value"] (param/subsample sites) | Index-selected along the plate dimension when subsample_size < size |
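The scale implied by these arguments can be summarized in a few lines. plate_scale is a hypothetical helper, not part of Pyro; the rule that explicit indices passed without subsample_size imply subsample_size = len(subsample), and the mismatch check, are assumptions consistent with the explicit-indices example on this page.

```python
def plate_scale(size, subsample_size=None, subsample=None):
    """Hypothetical helper: the scale a subsampling plate applies to its sites.

    Assumption: explicit indices without subsample_size mean
    subsample_size = len(subsample). This sketch does not model
    subsample_size > size.
    """
    if subsample_size is None:
        subsample_size = size if subsample is None else len(subsample)
    elif subsample is not None and subsample_size != len(subsample):
        raise ValueError("subsample_size must equal len(subsample)")
    return size / subsample_size


print(plate_scale(10000, 100))                         # random 100 of 10000
print(plate_scale(100, subsample=[0, 5, 10, 15, 20]))  # 5 explicit indices
print(plate_scale(100))                                # no subsampling
```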

Usage Examples

Plate with Subsampling

```python
import torch
import pyro
import pyro.distributions as dist

data = torch.randn(10000)

def model(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    # Subsample 100 of 10000 items; each obs log prob is scaled by 10000 / 100 = 100
    with pyro.plate("data", len(data), subsample_size=100) as idx:
        pyro.sample("obs", dist.Normal(loc, 1), obs=data[idx])

model(data)
```
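The effect of the 100x scale can be checked numerically without Pyro. This sketch makes simplifying assumptions: it fixes loc to a constant instead of sampling it, writes the Normal log density by hand (normal_log_prob is an illustrative helper), and compares the full-data log likelihood with the scaled 100-item estimate.

```python
import math
import random


def normal_log_prob(x, loc, scale=1.0):
    """Log density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


rng = random.Random(0)
data = [rng.gauss(0.0, 1.0) for _ in range(10000)]
loc = 0.3  # fixed value standing in for the sampled "loc"

full = sum(normal_log_prob(x, loc) for x in data)
idx = rng.sample(range(len(data)), 100)  # 100 distinct indices
scale = len(data) / len(idx)             # 10000 / 100
estimate = scale * sum(normal_log_prob(data[i], loc) for i in idx)
print(scale, full, estimate)
```

The two totals should be close but not identical: the scaled estimate is correct only in expectation, and its spread shrinks as subsample_size grows.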

Explicit Subsample Indices

```python
import torch
import pyro
import pyro.distributions as dist
data = torch.randn(100)
with pyro.plate("data", 100, subsample=torch.tensor([0, 5, 10, 15, 20])) as idx:
    pyro.sample("obs", dist.Normal(0, 1), obs=data[idx])
```
