

Implementation:Pyro ppl Pyro BroadcastMessenger

From Leeroopedia
Revision as of 16:23, 16 February 2026 by Admin (auto-imported from implementations/Pyro_ppl_Pyro_BroadcastMessenger.md)


Attribute      Value
File           pyro/poutine/broadcast_messenger.py
Module         pyro.poutine.broadcast_messenger
Lines          94
Parent class   Messenger
Purpose        Automatically broadcast distribution batch shapes to match enclosing plate contexts
License        Apache-2.0 (Uber Technologies, Inc.)

Overview

BroadcastMessenger automatically expands the batch shape of distributions at sample sites to match the sizes declared by enclosing plate contexts. This makes it easy to write modular Pyro models where sub-components are agnostic of their wrapping plate contexts.

Without this handler, users must manually expand distributions to match the plate dimensions. With BroadcastMessenger, a scalar distribution such as dist.Bernoulli(0.5) is automatically expanded to the correct batch shape, which is derived from the sample site's cond_indep_stack (the frames recorded by the enclosing plates).

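The handler implements a single static method, _pyro_sample. Because the method is static and keeps no state, other messengers can call it directly; pyro.plate does so for each sample site it processes, which is why sites inside pyro.plate contexts are broadcast even without an explicit poutine.broadcast. The method:

  1. Reads the current batch_shape from the distribution, treating size-1 dimensions as free to broadcast.
  2. Iterates over the cond_indep_stack frames, skipping frames with dim=None or size=-1, and lets each remaining plate claim its dimension of the target batch shape, raising ValueError if the existing size at that dimension is neither 1 nor the plate size.
  3. Fills any dimension no plate claimed with its original size (or 1), then expands the distribution to that target batch shape via dist.expand().
  4. Copies the has_rsample attribute from the original distribution onto the expanded one if the two differ.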
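The shape arithmetic in these steps can be tried without installing Pyro. The sketch below is a dependency-free mirror of the Code Reference logic under stated assumptions: Frame is a hypothetical stand-in for a plate frame carrying only the name, dim and size fields the handler reads, broadcast_batch_shape is an illustrative name rather than a Pyro function, and the error message wording is invented.

```python
from typing import List, NamedTuple, Optional, Sequence, Tuple


class Frame(NamedTuple):
    """Hypothetical stand-in for one entry of msg["cond_indep_stack"]."""
    name: str
    dim: Optional[int]  # negative batch dim, or None if the plate has no batch dim
    size: int           # -1 means unknown size; such frames are skipped


def broadcast_batch_shape(actual: Sequence[int], frames: Sequence[Frame]) -> Tuple[int, ...]:
    # Step 1: copy the actual batch shape; size-1 dims are free to broadcast (None).
    target: List[Optional[int]] = [None if s == 1 else s for s in actual]

    # Step 2: each plate with a batch dim claims that dim of the target shape.
    for f in frames:
        if f.dim is None or f.size == -1:
            continue
        assert f.dim < 0
        # Left-pad with free dims so that index f.dim exists.
        target = [None] * (-f.dim - len(target)) + target
        if target[f.dim] is not None and target[f.dim] != f.size:
            # Illustrative message, not Pyro's exact wording.
            raise ValueError(f"shape mismatch inside plate({f.name!r}) at dim {f.dim}")
        target[f.dim] = f.size

    # Step 3: every dim still free came from a size-1 dim or padding, so it is 1.
    return tuple(1 if s is None else s for s in target)


plates = [Frame("batch", -2, 100), Frame("components", -1, 3)]
print(broadcast_batch_shape((), plates))    # (100, 3)
print(broadcast_batch_shape((3,), plates))  # (100, 3)
```

Passing the frames in the opposite order gives the same result, because each plate writes only its own dimension.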

Code Reference

from typing import TYPE_CHECKING

from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.messenger import Messenger
from pyro.util import ignore_jit_warnings

if TYPE_CHECKING:
    from pyro.poutine.runtime import Message


class BroadcastMessenger(Messenger):
    @staticmethod
    @ignore_jit_warnings(["Converting a tensor to a Python boolean"])
    def _pyro_sample(msg: "Message") -> None:
        # Only pending sample sites that hold a Pyro distribution are broadcast
        if (msg["done"] or msg["type"] != "sample"
                or not isinstance(msg["fn"], TorchDistributionMixin)):
            return

        dist = msg["fn"]
        actual_batch_shape = dist.batch_shape
        # Size-1 dims are free to broadcast, so mark them None
        target_batch_shape = [None if size == 1 else size for size in actual_batch_shape]

        for f in msg["cond_indep_stack"]:
            # Skip frames without a batch dim (e.g. sequential plates) or of unknown size
            if f.dim is None or f.size == -1:
                continue
            assert f.dim < 0
            # Extend target shape to accommodate the plate dimension
            prefix_batch_shape = [None] * (-f.dim - len(target_batch_shape))
            target_batch_shape = prefix_batch_shape + target_batch_shape
            if (target_batch_shape[f.dim] is not None
                    and target_batch_shape[f.dim] != f.size):
                raise ValueError("Shape mismatch inside plate('{}') ...".format(...))
            target_batch_shape[f.dim] = f.size

        # Fill remaining None entries with the original size, or 1 for padded dims
        for i in range(-len(target_batch_shape) + 1, 1):
            if target_batch_shape[i] is None:
                target_batch_shape[i] = (
                    actual_batch_shape[i] if len(actual_batch_shape) >= -i else 1
                )

        msg["fn"] = dist.expand(target_batch_shape)
        # expand() may not carry over a custom has_rsample, so copy it back
        if msg["fn"].has_rsample != dist.has_rsample:
            msg["fn"].has_rsample = dist.has_rsample

I/O Contract

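Input: a Message with type "sample", done set to False, a TorchDistributionMixin in msg["fn"], and a cond_indep_stack of plate frames.
Output: none (the method returns None). The message is mutated in place: msg["fn"] is replaced by the distribution expanded to the target batch shape derived from the enclosing plates. Messages that are already done, are not sample sites, or whose fn is not a TorchDistributionMixin are left unchanged. A ValueError is raised if the distribution's batch shape conflicts with a plate size.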
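The contract can also be exercised on a plain dictionary. The sketch below is a hypothetical, Pyro-free mirror of the handler's control flow: Frame, FakeDistribution, target_shape and broadcast_sample are illustrative stand-ins, real Pyro messages carry many more keys, and the target shape is computed with an equivalent right-aligned padding formulation instead of the reference's None bookkeeping. FakeDistribution.expand deliberately drops a custom has_rsample value so that the copy-back step has something to do.

```python
from dataclasses import dataclass
from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple


class Frame(NamedTuple):
    """Hypothetical stand-in for a plate frame (only the fields the handler reads)."""
    name: str
    dim: Optional[int]  # negative batch dim, or None if the plate has no batch dim
    size: int           # -1 means unknown size; such frames are skipped


@dataclass
class FakeDistribution:
    """Hypothetical stand-in for a TorchDistributionMixin distribution."""
    batch_shape: Tuple[int, ...]
    has_rsample: bool = False

    def expand(self, batch_shape: Sequence[int]) -> "FakeDistribution":
        # Returns a new object and deliberately drops any has_rsample override.
        return FakeDistribution(tuple(batch_shape))


def target_shape(actual: Tuple[int, ...], frames: Sequence[Frame]) -> Tuple[int, ...]:
    # Plates that occupy a batch dim, keyed by their (negative) dim.
    plates = {f.dim: f for f in frames if f.dim is not None and f.size != -1}
    ndim = max([len(actual)] + [-d for d in plates])
    padded = (1,) * (ndim - len(actual)) + tuple(actual)  # right-align with 1s
    shape = []
    for d in range(-ndim, 0):
        f = plates.get(d)
        if f is None:
            shape.append(padded[d])  # no plate here: keep the original size
        elif padded[d] in (1, f.size):
            shape.append(f.size)     # size 1 (or already equal) becomes the plate size
        else:
            raise ValueError(f"shape mismatch inside plate({f.name!r}) at dim {d}")
    return tuple(shape)


def broadcast_sample(msg: Dict[str, Any]) -> None:
    # Leave finished, non-sample, and non-distribution messages untouched.
    if msg["done"] or msg["type"] != "sample" or not isinstance(msg["fn"], FakeDistribution):
        return
    dist = msg["fn"]
    msg["fn"] = dist.expand(target_shape(dist.batch_shape, msg["cond_indep_stack"]))
    # Copy has_rsample back if expand() changed it.
    if msg["fn"].has_rsample != dist.has_rsample:
        msg["fn"].has_rsample = dist.has_rsample


msg = {
    "type": "sample",
    "done": False,
    "fn": FakeDistribution((), has_rsample=True),
    "cond_indep_stack": (Frame("batch", -2, 100), Frame("components", -1, 3)),
}
broadcast_sample(msg)  # mutates msg in place and returns None
print(msg["fn"].batch_shape, msg["fn"].has_rsample)  # (100, 3) True
```

The in-place update matches how Pyro messengers cooperate: each handler edits the shared message rather than returning a new one.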

Note: This messenger has no constructor parameters. It is a stateless handler.

Usage Examples

Manual vs Automatic Broadcasting

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

# By hand: expand the distribution to the full plate shape explicitly
def model_by_hand():
    with pyro.plate("batch", 100, dim=-2):
        with pyro.plate("components", 3, dim=-1):
            d = dist.Bernoulli(torch.ones(3) * 0.5).expand_by([100])
            sample = pyro.sample("sample", d)
            assert sample.shape == torch.Size((100, 3))
    return sample

# With broadcast: the scalar Bernoulli is expanded to batch shape (100, 3)
@poutine.broadcast
def model_automatic():
    with pyro.plate("batch", 100, dim=-2):
        with pyro.plate("components", 3, dim=-1):
            sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5)))
            assert sample.shape == torch.Size((100, 3))
    return sample

model_by_hand()
model_automatic()

Using as Context Manager

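import pyro
import pyro.distributions as dist
from pyro import poutine

with poutine.broadcast():
    with pyro.plate("data", 1000):
        x = pyro.sample("x", dist.Normal(0., 1.))
        # The scalar Normal is expanded automatically, so x has shape (1000,)
        assert x.shape == (1000,)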
