Implementation:Pyro ppl Pyro PlateMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/plate_messenger.py |
| Module | pyro.poutine.plate_messenger |
| Lines | 91 |
| Parent Class | SubsampleMessenger (which extends IndepMessenger) |
| Purpose | Combined shape inference, independence annotation, and subsampling for pyro.plate |
| Architecture Role | The messenger underlying the pyro.plate primitive |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
PlateMessenger is the Swiss army knife of the Pyro plate system, combining three capabilities:
- Independence annotation -- Inherited from IndepMessenger, it records conditional independence by adding a CondIndepStackFrame to each message's cond_indep_stack.
- Subsampling -- Inherited from SubsampleMessenger, it supports data subsampling and rescales log probabilities by size / subsample_size.
- Broadcasting -- It directly calls BroadcastMessenger._pyro_sample to automatically expand distribution batch shapes to match the plate dimensions.
The module also provides block_plate, an experimental context manager for temporarily blocking a single enclosing plate. This is useful for sampling global variables inside a plated context.
Code Reference
PlateMessenger
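```python
from typing import Optional

from pyro.poutine.broadcast_messenger import BroadcastMessenger
from pyro.poutine.subsample_messenger import SubsampleMessenger


class PlateMessenger(SubsampleMessenger):
    """Swiss army knife: combines shape inference, independence annotation, and subsampling."""

    def _process_message(self, msg: "Message") -> None:
        # Independence annotation and subsample scaling, inherited from the parents
        super()._process_message(msg)
        # Shape inference: expand msg["fn"] to the batch shape implied by the plates
        BroadcastMessenger._pyro_sample(msg)

    def __enter__(self) -> Optional["torch.Tensor"]:
        super().__enter__()
        if self._vectorized and self._indices is not None:
            return self.indices
        return None
```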
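The composition in _process_message can be illustrated with a toy model in plain Python. Everything below is hypothetical: ToyFrame, ToyIndep, ToySubsample, ToyPlate and toy_broadcast are simplified stand-ins for the Pyro classes, a message is reduced to a dict with cond_indep_stack and scale, and the way the toy parents split the work does not necessarily mirror the real IndepMessenger and SubsampleMessenger internals. The sketch only shows the call pattern: annotation and scaling arrive through super(), broadcasting is a direct call, and the innermost plate handles a message first.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyFrame:
    # Hypothetical stand-in for CondIndepStackFrame
    name: str
    dim: int
    size: int


def toy_broadcast(msg: dict) -> None:
    # Hypothetical stand-in for BroadcastMessenger._pyro_sample: record the
    # plate size that each batch dim must be expanded to.
    msg["expand_to"] = {f.dim: f.size for f in msg["cond_indep_stack"]}


class ToyIndep:
    def __init__(self, name: str, size: int, dim: int) -> None:
        self.name, self.size, self.dim = name, size, dim

    def _process_message(self, msg: dict) -> None:
        # Independence annotation: prepend this plate's frame
        frame = ToyFrame(self.name, self.dim, self.size)
        msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]


class ToySubsample(ToyIndep):
    def __init__(self, name: str, size: int, subsample_size: int, dim: int) -> None:
        # The frame records the subsample size, i.e. the actual batch extent
        super().__init__(name, subsample_size, dim)
        self.full_size = size

    def _process_message(self, msg: dict) -> None:
        super()._process_message(msg)
        # Subsampling: rescale by full size / subsample size
        msg["scale"] = msg["scale"] * self.full_size / self.size


class ToyPlate(ToySubsample):
    def _process_message(self, msg: dict) -> None:
        super()._process_message(msg)  # annotation + scaling via inheritance
        toy_broadcast(msg)  # broadcasting via a direct, non-inherited call


outer = ToyPlate("a", size=3, subsample_size=3, dim=-2)
inner = ToyPlate("b", size=10, subsample_size=5, dim=-1)
msg = {"cond_indep_stack": (), "scale": 1.0}
for plate in (inner, outer):  # the innermost handler sees a message first
    plate._process_message(msg)

print([f.name for f in msg["cond_indep_stack"]], msg["scale"], msg["expand_to"])
# ['a', 'b'] 2.0 {-2: 3, -1: 5}
```

Because every plate re-runs the broadcast step on the stack it has seen so far, the outermost plate's call is the one that sees every frame.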
block_plate
```python
from contextlib import contextmanager

from pyro.poutine.messenger import block_messengers


@contextmanager
def block_plate(name=None, dim=None, *, strict=True):
    """
    EXPERIMENTAL Context manager to temporarily block a single enclosing plate.
    Useful for sampling auxiliary/global variables inside a plated context.
    """
    if (name is not None) == (dim is not None):
        raise ValueError("Exactly one of name,dim must be specified")

    def predicate(messenger):
        # Only PlateMessengers can be blocked; match on name or on dim
        if not isinstance(messenger, PlateMessenger):
            return False
        if name is not None:
            return messenger.name == name
        if dim is not None:
            return messenger.dim == dim
        raise ValueError("Unreachable")

    # Temporarily remove matching messengers from the active messenger stack
    with block_messengers(predicate) as matches:
        if strict and len(matches) != 1:
            raise ValueError(f"block_plate matched {len(matches)} messengers. ...")
        yield
```
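The matching rules of block_plate can be modelled without Pyro. In this sketch ActivePlate and toy_block_plate are hypothetical, a plain list of active plates stands in for Pyro's messenger stack, and the PlateMessenger type check is dropped because every entry is already a plate. It shows the "exactly one of name, dim" rule, the strict check, and which plates a sample site inside the block would still see.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class ActivePlate:
    # Hypothetical stand-in for an active PlateMessenger
    name: str
    dim: int


def toy_block_plate(
    active: List[ActivePlate],
    name: Optional[str] = None,
    dim: Optional[int] = None,
    strict: bool = True,
) -> List[ActivePlate]:
    """Return the plates that remain visible inside the (toy) block."""
    if (name is not None) == (dim is not None):
        raise ValueError("Exactly one of name,dim must be specified")

    def predicate(plate: ActivePlate) -> bool:
        # Match by name if one was given, otherwise by dim
        return plate.name == name if name is not None else plate.dim == dim

    matches = [p for p in active if predicate(p)]
    if strict and len(matches) != 1:
        raise ValueError(f"block_plate matched {len(matches)} messengers")
    # Every matching plate is hidden for the duration of the block
    return [p for p in active if not predicate(p)]


plates = [ActivePlate("data", -1), ActivePlate("time", -2)]
print([p.name for p in toy_block_plate(plates, name="data")])  # ['time']
print([p.name for p in toy_block_plate(plates, name="missing", strict=False)])  # ['data', 'time']
```

With strict=False a missing plate is silently ignored, which is why the real function defaults to strict=True.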
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| name | str | Name of the plate context |
| size | Optional[int] | Total size of the data |
| subsample_size | Optional[int] | Size of each subsample (for data subsampling) |
| subsample | Optional[torch.Tensor] | Explicit subsample indices |
| dim | Optional[int] | Batch dimension (negative integer) |
| device | Optional[str] | Device for index tensors |
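When dim is omitted, Pyro picks a batch dimension through its internal _DIM_ALLOCATOR in pyro.poutine.runtime. The class below is a simplified, hypothetical model of that bookkeeping rather than the library's code: each active plate claims one negative dim, dim=None takes the rightmost free slot, and an explicit dim that is already taken is rejected. Under this scheme a nested plate without an explicit dim lands to the left of its parent.

```python
from typing import List, Optional


class ToyDimAllocator:
    """Hypothetical model of plate dim allocation."""

    def __init__(self) -> None:
        # Index i holds the name of the plate occupying dim -1 - i (or None)
        self._stack: List[Optional[str]] = []

    def allocate(self, name: str, dim: Optional[int] = None) -> int:
        if name in self._stack:
            raise ValueError(f"duplicate plate name {name!r}")
        if dim is None:
            # Scan leftward from dim -1 for the first free slot
            dim = -1
            while -dim <= len(self._stack) and self._stack[-1 - dim] is not None:
                dim -= 1
        elif dim >= 0:
            raise ValueError("dim must be negative")
        # Grow the stack so that slot -1 - dim exists
        while dim < -len(self._stack):
            self._stack.append(None)
        if self._stack[-1 - dim] is not None:
            raise ValueError(f"dim {dim} is already used by {self._stack[-1 - dim]!r}")
        self._stack[-1 - dim] = name
        return dim

    def free(self, name: str, dim: int) -> None:
        # Release the slot on plate exit and trim unused leftmost slots
        assert self._stack[-1 - dim] == name
        self._stack[-1 - dim] = None
        while self._stack and self._stack[-1] is None:
            self._stack.pop()


alloc = ToyDimAllocator()
print(alloc.allocate("outer"), alloc.allocate("inner"))  # -1 -2
print(alloc.allocate("explicit", dim=-4))  # -4
print(alloc.allocate("auto"))  # -3, the first free slot
```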

| Message Effect | Description |
|---|---|
| msg["cond_indep_stack"] | Prepends a CondIndepStackFrame with plate metadata (from SubsampleMessenger) |
| msg["scale"] | Multiplied by size / subsample_size for unbiased estimation (from SubsampleMessenger) |
| msg["fn"] | Distribution batch shape expanded to match plate dimensions (from BroadcastMessenger) |
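The msg["fn"] effect can be made concrete with a pure-Python model of the target batch shape that BroadcastMessenger._pyro_sample computes before expanding the distribution. The function toy_target_batch_shape is hypothetical, works on plain tuples instead of torch distributions, and skips special cases the real method handles (such as sequential plates, which have no dim). Each enclosing plate is passed as a (dim, size) pair.

```python
from typing import List, Optional, Sequence, Tuple


def toy_target_batch_shape(
    actual: Sequence[int], plates: Sequence[Tuple[int, int]]
) -> Tuple[int, ...]:
    """Hypothetical model of the batch shape a sample site is expanded to.

    actual: the distribution's current batch_shape.
    plates: one (dim, size) pair per enclosing plate; dim is negative.
    """
    # Size-1 dims can broadcast freely, so mark them as unconstrained (None)
    target: List[Optional[int]] = [None if s == 1 else s for s in actual]
    for dim, size in plates:
        # Left-pad with unconstrained dims until target[dim] exists
        if -dim > len(target):
            target = [None] * (-dim - len(target)) + target
        if target[dim] is not None and target[dim] != size:
            raise ValueError(f"batch dim {dim} has size {target[dim]}, plate expects {size}")
        target[dim] = size
    # Remaining unconstrained dims keep their actual size if they had one, else 1
    for i in range(-len(target), 0):
        if target[i] is None:
            target[i] = actual[i] if len(actual) >= -i else 1
    return tuple(target)


print(toy_target_batch_shape((), [(-1, 100)]))  # (100,)
print(toy_target_batch_shape((3,), [(-2, 5)]))  # (5, 3)
print(toy_target_batch_shape((1,), [(-2, 5)]))  # (5, 1)
```

A non-1 batch dim that disagrees with a plate size is an error rather than something to broadcast, which matches how Pyro reports shape mismatches inside plates.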
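Return value of __enter__: the subsample index tensor when the plate is vectorized (used directly as a context manager, as in with pyro.plate(...) as idx) and indices are available; None otherwise, for example inside a sequential plate loop (for i in pyro.plate(...)).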
Usage Examples
Basic Plate Usage
```python
import pyro
import pyro.distributions as dist


def model(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1), obs=data)
```
Plate with Subsampling
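```python
import pyro
import pyro.distributions as dist


def model(data):  # data: a 1-D torch.Tensor, so that data[idx] selects the minibatch
    loc = pyro.sample("loc", dist.Normal(0, 1))
    with pyro.plate("data", len(data), subsample_size=64) as idx:
        pyro.sample("obs", dist.Normal(loc, 1), obs=data[idx])
```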
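The size / subsample_size factor applied to msg["scale"] is what makes a subsampled log-likelihood an unbiased estimate of the full-data one. The exhaustive check below is a plain-Python sketch, assuming indices are drawn uniformly without replacement: averaging the rescaled subsample sum over every possible index set recovers the full sum exactly. The per-datum values are made-up stand-ins for log-likelihood terms, and scaled_subsample_sum is a hypothetical helper.

```python
from fractions import Fraction
from itertools import combinations


def scaled_subsample_sum(values, idx, size):
    # Sum over the subsample, multiplied by size / subsample_size
    return Fraction(size, len(idx)) * sum(values[i] for i in idx)


# Made-up per-datum log-likelihood terms, kept exact with Fraction
values = [Fraction(v) for v in (-1.5, 0.25, 2.0, 3.0, -0.75)]
size, subsample_size = len(values), 2

# Under uniform sampling without replacement every index set is equally likely
subsets = list(combinations(range(size), subsample_size))
estimates = [scaled_subsample_sum(values, idx, size) for idx in subsets]
mean_estimate = sum(estimates) / len(estimates)

print(mean_estimate, sum(values))  # 3 3
```

Individual estimates vary from minibatch to minibatch; only their average matches the full sum, which is the property stochastic variational inference relies on.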
Using block_plate
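```python
import pyro
import pyro.distributions as dist
from pyro.poutine.plate_messenger import block_plate


def model(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        with block_plate("data"):
            # Sample a global variable inside a plate: "scale" gets no "data" batch dim
            scale = pyro.sample("scale", dist.LogNormal(0, 1))
        pyro.sample("x", dist.Normal(loc, scale))
```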
Related Pages
- Pyro_ppl_Pyro_SubsampleMessenger -- Direct parent class providing subsampling
- Pyro_ppl_Pyro_IndepMessenger -- Grandparent class providing independence annotation
- Pyro_ppl_Pyro_BroadcastMessenger -- Called directly for automatic batch shape expansion
- Pyro_ppl_Pyro_Messenger_Base -- block_messengers, used by block_plate
- Pyro_ppl_Pyro_Poutine_Runtime -- _DIM_ALLOCATOR, used for plate dimension management