Implementation:Pyro ppl Pyro PlateMessenger

From Leeroopedia
Revision as of 16:24, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Pyro_ppl_Pyro_PlateMessenger.md)


Attribute          Value
File               pyro/poutine/plate_messenger.py
Module             pyro.poutine.plate_messenger
Lines              91
Parent Class       SubsampleMessenger (which extends IndepMessenger)
Purpose            Combined shape inference, independence annotation, and subsampling for pyro.plate
Architecture Role  The messenger underlying the pyro.plate primitive
License            Apache-2.0 (Uber Technologies, Inc.)

Overview

PlateMessenger is the Swiss Army knife of the Pyro plate system. It combines three capabilities:

  1. Independence annotation -- Inherited from IndepMessenger, it tracks conditional independence via CondIndepStackFrame in each message's cond_indep_stack.
  2. Subsampling -- Inherited from SubsampleMessenger, it supports data subsampling and rescales log probabilities by size / subsample_size so that subsampled estimates remain unbiased.
  3. Broadcasting -- It directly calls BroadcastMessenger._pyro_sample to automatically expand distribution batch shapes to match the plate dimensions.

The module also provides block_plate, an experimental context manager for temporarily blocking a single enclosing plate. This is useful for sampling global variables inside a plated context.
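
The three effects listed above can be illustrated with a toy model that uses only the standard library. This is a minimal sketch, not Pyro's implementation: the message is a plain dict, Frame is a hypothetical stand-in for CondIndepStackFrame, toy_plate_process is an invented name, and the broadcasting step is simplified to operate on a shape tuple rather than a distribution.

```python
from collections import namedtuple

# Hypothetical, simplified stand-in for Pyro's CondIndepStackFrame.
Frame = namedtuple("Frame", ["name", "dim", "size"])


def toy_plate_process(msg, name, dim, size, subsample_size):
    """Apply the three plate effects to a plain-dict message (toy model only)."""
    # 1. Independence annotation: prepend a frame to the stack.
    frame = Frame(name, dim, subsample_size)
    msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]
    # 2. Subsampling: rescale the log-probability weight.
    msg["scale"] = msg["scale"] * size / subsample_size
    # 3. Broadcasting: left-pad the batch shape with 1s so that index `dim`
    #    exists, then expand that dimension to the subsample size.
    shape = list(msg["batch_shape"])
    shape = [1] * (-dim - len(shape)) + shape
    if shape[dim] == 1:
        shape[dim] = subsample_size
    msg["batch_shape"] = tuple(shape)
    return msg


msg = {"cond_indep_stack": (), "scale": 1.0, "batch_shape": ()}
toy_plate_process(msg, name="data", dim=-1, size=1000, subsample_size=100)
```

After the call, the message carries one frame for the "data" plate, a scale of 1000 / 100, and a batch shape whose last dimension equals the subsample size.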

Code Reference

PlateMessenger

class PlateMessenger(SubsampleMessenger):
    """Swiss army knife: combines shape inference, independence annotation, and subsampling."""

    def _process_message(self, msg: "Message") -> None:
        super()._process_message(msg)
        BroadcastMessenger._pyro_sample(msg)

    def __enter__(self) -> Optional["torch.Tensor"]:
        super().__enter__()
        if self._vectorized and self._indices is not None:
            return self.indices
        return None

block_plate

@contextmanager
def block_plate(name=None, dim=None, *, strict=True):
    """
    EXPERIMENTAL Context manager to temporarily block a single enclosing plate.

    Useful for sampling auxiliary/global variables inside a plated context.
    """
    if (name is not None) == (dim is not None):
        raise ValueError("Exactly one of name,dim must be specified")

    def predicate(messenger):
        if not isinstance(messenger, PlateMessenger):
            return False
        if name is not None:
            return messenger.name == name
        if dim is not None:
            return messenger.dim == dim
        raise ValueError("Unreachable")

    with block_messengers(predicate) as matches:
        if strict and len(matches) != 1:
            raise ValueError(f"block_plate matched {len(matches)} messengers. ...")
        yield
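
The matching logic in block_plate (the exactly-one-of guard, the name or dim predicate, and the strict count check) can be sketched without Pyro. ToyPlate and match_plates below are hypothetical names introduced for illustration; the real function filters live messengers through block_messengers rather than a list.

```python
from dataclasses import dataclass


@dataclass
class ToyPlate:
    """Hypothetical stand-in for a PlateMessenger on the handler stack."""
    name: str
    dim: int


def match_plates(stack, name=None, dim=None, *, strict=True):
    """Return the plates that block_plate-style matching would select."""
    # The guard is True when both or neither argument is given.
    if (name is not None) == (dim is not None):
        raise ValueError("Exactly one of name,dim must be specified")
    if name is not None:
        matches = [p for p in stack if p.name == name]
    else:
        matches = [p for p in stack if p.dim == dim]
    # In strict mode, anything other than a single match is an error.
    if strict and len(matches) != 1:
        raise ValueError(f"matched {len(matches)} plates")
    return matches


stack = [ToyPlate("batch", -2), ToyPlate("data", -1)]
by_name = match_plates(stack, name="data")
by_dim = match_plates(stack, dim=-2)
```

Passing strict=False turns a zero-match or multi-match result into a quiet return instead of an error, which mirrors the role of the strict flag above.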

I/O Contract

Parameter       Type                    Description
name            str                     Name of the plate context
size            Optional[int]           Total size of the data
subsample_size  Optional[int]           Size of each subsample (for data subsampling)
subsample       Optional[torch.Tensor]  Explicit subsample indices
dim             Optional[int]           Batch dimension (negative integer)
device          Optional[str]           Device for index tensors

Message Effect           Description
msg["cond_indep_stack"]  Prepends a CondIndepStackFrame with plate metadata (from SubsampleMessenger)
msg["scale"]             Multiplied by size / subsample_size for unbiased estimation (from SubsampleMessenger)
msg["fn"]                Distribution batch shape expanded to match plate dimensions (from BroadcastMessenger)

Return value of __enter__: Returns the subsample indices tensor when the plate is vectorized and indices are available; otherwise returns None.
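
The reason a size / subsample_size factor on msg["scale"] gives an unbiased estimate can be checked by brute force on a small example. The sketch below assumes uniform subsampling without replacement and uses made-up log-probability values; it averages the scaled subsample sum over every possible subsample and compares it with the full-data sum.

```python
from fractions import Fraction
from itertools import combinations

# Made-up per-datum log probabilities (exact fractions avoid rounding noise).
log_probs = [Fraction(-3), Fraction(-1), Fraction(-4), Fraction(-2)]
size, subsample_size = len(log_probs), 2
scale = Fraction(size, subsample_size)

full_sum = sum(log_probs)

# Enumerate every subsample of the given size, each equally likely.
subsets = list(combinations(log_probs, subsample_size))
estimates = [scale * sum(subset) for subset in subsets]
mean_estimate = sum(estimates) / len(estimates)
```

Here mean_estimate equals full_sum: each datum appears in the same fraction subsample_size / size of the subsamples, so multiplying by the inverse restores the full-data total in expectation.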

Usage Examples

Basic Plate Usage

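import pyro
import pyro.distributions as dist

def model(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    # Observations are conditionally independent given loc.
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1), obs=data)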

Plate with Subsampling

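import pyro
import pyro.distributions as dist

def model(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    # idx holds the subsample indices; data is assumed to be a torch.Tensor.
    with pyro.plate("data", len(data), subsample_size=64) as idx:
        pyro.sample("obs", dist.Normal(loc, 1), obs=data[idx])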

Using block_plate

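import pyro
import pyro.distributions as dist
from pyro.poutine.plate_messenger import block_plate

def model(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        with block_plate("data"):
            # Sample a global variable inside a plate
            scale = pyro.sample("scale", dist.LogNormal(0, 1))
        # Back under the "data" plate: x is expanded along the plate dimension.
        pyro.sample("x", dist.Normal(loc, scale))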
