

Implementation:Pyro ppl Pyro IndepMessenger

From Leeroopedia


  • File: pyro/poutine/indep_messenger.py
  • Module: pyro.poutine.indep_messenger
  • Length: 148 lines
  • Parent class: Messenger
  • Purpose: Track conditional independence information (plate contexts) for sample sites
  • Architecture role: Core building block for plate semantics; direct parent of SubsampleMessenger, which PlateMessenger in turn subclasses
  • License: Apache-2.0 (Uber Technologies, Inc.)

Overview

IndepMessenger is the base messenger for conditional independence annotation. Each active instance prepends one CondIndepStackFrame to the cond_indep_stack field of every message (sample site) it handles, so nested plate contexts build up a stack of independence information with one frame per enclosing context.

The module defines two key types:

  • CondIndepStackFrame -- A NamedTuple recording the name, dimension, size, counter, and optional full size of a plate context. It includes a vectorized property indicating whether the plate uses a batch dimension.
  • IndepMessenger -- A Messenger that allocates plate dimensions (via _DIM_ALLOCATOR), supports both vectorized (batch dimension) and sequential (iterator) usage patterns, and prepends a CondIndepStackFrame to each message's cond_indep_stack.

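The messenger supports two usage patterns:

  1. Context manager (vectorized) -- Allocates a batch dimension and applies to all enclosed sites.
  2. Iterator (sequential) -- Iterates over indices without allocating a batch dimension; each iteration increments the counter and enters the context once.

The two patterns are mutually exclusive for a given instance: an instance constructed with an explicit dim, or already used as a vectorized context manager, cannot also be iterated.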

Code Reference

CondIndepStackFrame

from typing import NamedTuple, Optional

class CondIndepStackFrame(NamedTuple):
    name: str
    dim: Optional[int]
    size: int
    counter: int
    full_size: Optional[int] = None

    @property
    def vectorized(self) -> bool:
        return self.dim is not None

IndepMessenger

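import torch

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import _DIM_ALLOCATOR


class IndepMessenger(Messenger):
    def __init__(self, name, size, dim=None, device=None):
        # Reject empty plates (the check is skipped while the JIT is tracing).
        if not torch._C._get_tracing_state() and size == 0:
            raise ZeroDivisionError("size cannot be zero")
        super().__init__()
        # None = undecided, True = vectorized (batch dim), False = sequential.
        self._vectorized = None
        if dim is not None:
            self._vectorized = True
        self._indices = None
        self.name = name
        self.dim = dim
        self.size = size
        self.device = device
        self.counter = 0

    def next_context(self):
        """Increment the counter before each sequential iteration."""
        self.counter += 1

    def __enter__(self):
        # Unless already used sequentially, entering the context means vectorized use.
        if self._vectorized is not False:
            self._vectorized = True
        if self._vectorized is True:
            # Claim a batch dimension: the requested one, or the rightmost free one.
            self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim)
        return super().__enter__()

    def __exit__(self, *args):
        if self._vectorized is True:
            _DIM_ALLOCATOR.free(self.name, self.dim)
        return super().__exit__(*args)

    def __iter__(self):
        """Sequential (non-vectorized) iteration over plate indices."""
        if self._vectorized or self.dim is not None:
            raise ValueError(
                "cannot use plate {} as both vectorized and non-vectorized "
                "independence context".format(self.name)
            )
        self._vectorized = False
        self.dim = None
        for i in self.indices:
            self.next_context()
            with self:
                yield i.item() if isinstance(i, torch.Tensor) else i

    @property
    def indices(self):
        # Lazily build the index tensor 0..size-1 on the requested device.
        if self._indices is None:
            self._indices = torch.arange(self.size, dtype=torch.long).to(self.device)
        return self._indices

    def _process_message(self, msg):
        # Put this plate's frame in front of the frames already on the message.
        frame = CondIndepStackFrame(self.name, self.dim, self.size, self.counter)
        msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]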

I/O Contract

Parameters:

  • name (str) -- Name of the independence context.
  • size (int) -- Size of the independent dimension; must be nonzero.
  • dim (Optional[int]) -- Batch dimension index (a negative integer); None for automatic allocation or sequential use.
  • device (Optional[str]) -- Device on which index tensors are created.

Message effect:

  • msg["cond_indep_stack"] -- Prepends a CondIndepStackFrame tuple with the plate's metadata (name, dim, size, counter).

Usage Examples

Vectorized (Batch Dimension) Usage

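import torch
import pyro
import pyro.distributions as dist
from pyro.poutine.indep_messenger import IndepMessenger

loc, scale = torch.tensor(0.0), torch.tensor(1.0)

x_axis = IndepMessenger('outer', 320, dim=-1)
y_axis = IndepMessenger('inner', 200, dim=-2)

with x_axis:
    x_noise = pyro.sample("x_noise", dist.Normal(loc, scale).expand([320]))  # shape (320,)

with y_axis:
    y_noise = pyro.sample("y_noise", dist.Normal(loc, scale).expand([200, 1]))  # shape (200, 1)

with x_axis, y_axis:
    xy_noise = pyro.sample("xy_noise", dist.Normal(loc, scale).expand([200, 320]))  # shape (200, 320)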

Sequential (Iterator) Usage

data = torch.randn(5)  # hypothetical observations; imports as in the previous example
plate = IndepMessenger('data', len(data))
for i in plate:
    pyro.sample(f"obs_{i}", dist.Normal(0.0, 1.0), obs=data[i])
