Implementation:Pyro ppl Pyro BroadcastMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/broadcast_messenger.py |
| Module | pyro.poutine.broadcast_messenger |
| Lines | 94 |
| Parent Class | Messenger |
| Purpose | Automatically broadcast distribution batch shapes to match enclosing plate contexts |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
BroadcastMessenger automatically expands the batch shape of distributions at sample sites to match the sizes declared by enclosing plate contexts. This makes it easy to write modular Pyro models where sub-components are agnostic of their wrapping plate contexts.
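Without broadcasting, users must manually expand distributions to match the plate dimensions. With BroadcastMessenger, a scalar distribution like dist.Bernoulli(0.5) is automatically expanded to the correct shape based on the plate frames recorded in msg["cond_indep_stack"]. pyro.plate also applies this logic itself (PlateMessenger calls BroadcastMessenger._pyro_sample in its _process_message), so sites inside pyro.plate are broadcast even without an explicit poutine.broadcast wrapper.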
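The handler implements a single static method _pyro_sample that:
- Returns immediately unless the message is a pending (not done) sample site whose msg["fn"] is a TorchDistributionMixin.
- Reads the current batch_shape from the distribution, treating size-1 dimensions as free.
- Iterates over the cond_indep_stack frames (skipping frames whose dim is None or whose size is -1) to determine the target batch shape, raising a ValueError if a plate size conflicts with a non-singleton batch dimension.
- Fills every dimension not claimed by a plate with its original size, or 1 if the distribution had no such dimension.
- Expands the distribution to the target batch shape via dist.expand().
- Preserves the has_rsample attribute from the original distribution.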
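The shape arithmetic in these steps can be sketched in plain Python, without torch, so that it is easy to check by hand. Frame and broadcast_target_shape are hypothetical names for illustration only; Frame carries just the fields (name, dim, size) that _pyro_sample reads from each cond_indep_stack entry, and the loop follows the excerpt under Code Reference.

```python
from typing import List, NamedTuple, Optional, Sequence


class Frame(NamedTuple):
    # Hypothetical stand-in for a plate frame on cond_indep_stack,
    # keeping only the fields the broadcast logic reads.
    name: str
    dim: Optional[int]
    size: int


def broadcast_target_shape(actual: Sequence[int], frames: Sequence[Frame]) -> List[int]:
    """Return the batch shape a distribution would be expanded to inside `frames`."""
    # Size-1 dims are free (None): a plate is allowed to claim them.
    target: List[Optional[int]] = [None if s == 1 else s for s in actual]
    for f in frames:
        if f.dim is None or f.size == -1:
            continue  # frames without an allocated dim do not constrain the shape
        assert f.dim < 0
        # Left-pad with None so that index f.dim exists.
        target = [None] * (-f.dim - len(target)) + target
        if target[f.dim] is not None and target[f.dim] != f.size:
            raise ValueError(f"shape mismatch inside plate({f.name!r})")
        target[f.dim] = f.size
    # Unclaimed dims keep their original size if it exists, else become 1.
    result = []
    for i in range(-len(target), 0):
        if target[i] is None:
            result.append(actual[i] if len(actual) >= -i else 1)
        else:
            result.append(target[i])
    return result


print(broadcast_target_shape((), [Frame("batch", -2, 100), Frame("components", -1, 3)]))
print(broadcast_target_shape((5,), [Frame("p", -3, 2)]))
```

The second call shows why the final fill step exists: dim -2 is claimed by neither the distribution nor a plate, so it is set to 1 rather than left undefined.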
Code Reference
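```python
from typing import TYPE_CHECKING

from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.messenger import Messenger
from pyro.util import ignore_jit_warnings

if TYPE_CHECKING:
    from pyro.poutine.runtime import Message


class BroadcastMessenger(Messenger):
    @staticmethod
    @ignore_jit_warnings(["Converting a tensor to a Python boolean"])
    def _pyro_sample(msg: "Message") -> None:
        # Skip sites that are already done, are not sample sites,
        # or whose fn is not a TorchDistributionMixin
        if (msg["done"] or msg["type"] != "sample"
                or not isinstance(msg["fn"], TorchDistributionMixin)):
            return
        dist = msg["fn"]
        actual_batch_shape = dist.batch_shape
        # Size-1 dims are left free (None) so that a plate can claim them
        target_batch_shape = [None if size == 1 else size for size in actual_batch_shape]
        for f in msg["cond_indep_stack"]:
            if f.dim is None or f.size == -1:
                continue
            assert f.dim < 0
            # Extend target shape to accommodate the plate dimension
            prefix_batch_shape = [None] * (-f.dim - len(target_batch_shape))
            target_batch_shape = prefix_batch_shape + target_batch_shape
            if (target_batch_shape[f.dim] is not None
                    and target_batch_shape[f.dim] != f.size):
                # Full error message abbreviated in this excerpt
                raise ValueError("Shape mismatch inside plate('{}') ...".format(...))
            target_batch_shape[f.dim] = f.size
        # Fill remaining None entries with the original size, or 1 if absent
        for i in range(-len(target_batch_shape) + 1, 1):
            if target_batch_shape[i] is None:
                target_batch_shape[i] = (
                    actual_batch_shape[i] if len(actual_batch_shape) >= -i else 1
                )
        msg["fn"] = dist.expand(target_batch_shape)
        # Carry over has_rsample if the expanded instance reports a different value
        if msg["fn"].has_rsample != dist.has_rsample:
            msg["fn"].has_rsample = dist.has_rsample
```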
I/O Contract
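| Input | Output |
|---|---|
| A Message with type="sample", msg["done"] false, a TorchDistributionMixin in msg["fn"], and a cond_indep_stack | The same message with msg["fn"] expanded to match the target batch shape derived from enclosing plates |
| Any other message (already done, not a sample site, or msg["fn"] not a TorchDistributionMixin) | The message, unchanged |
| A plate size that conflicts with a non-singleton batch dimension | A ValueError ("Shape mismatch inside plate(...)") |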
Note: This messenger has no constructor parameters. It is a stateless handler.
Usage Examples
Manual vs Automatic Broadcasting
```python
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

# Manual expansion: the batch shape is written out explicitly
def model_by_hand():
    with pyro.plate("batch", 100, dim=-2):
        with pyro.plate("components", 3, dim=-1):
            sample = pyro.sample("sample",
                                 dist.Bernoulli(torch.ones(3) * 0.5).expand_by([100]))
            assert sample.shape == torch.Size((100, 3))

# With broadcast: a scalar distribution is expanded automatically
@poutine.broadcast
def model_automatic():
    with pyro.plate("batch", 100, dim=-2):
        with pyro.plate("components", 3, dim=-1):
            sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5)))
            assert sample.shape == torch.Size((100, 3))

model_by_hand()
model_automatic()
```
Using as Context Manager
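```python
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

with poutine.broadcast():
    with pyro.plate("data", 1000):
        x = pyro.sample("x", dist.Normal(0., 1.))
        assert x.shape == (1000,)  # expanded automatically
```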
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_PlateMessenger -- Calls BroadcastMessenger._pyro_sample directly in its _process_message
- Pyro_ppl_Pyro_IndepMessenger -- Provides the CondIndepStackFrame data used for shape inference
- Pyro_ppl_Pyro_Poutine_Handlers -- The poutine.broadcast() factory function