Implementation:Pyro ppl Pyro Messenger Base
| Attribute | Value |
|---|---|
| File | pyro/poutine/messenger.py |
| Module | pyro.poutine.messenger |
| Lines | 288 |
| Purpose | Base class for all Poutine effect handlers (messengers) |
| Architecture Role | Root of the Messenger class hierarchy; defines the handler protocol |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
Messenger is the base class for all Poutine effect handlers. It defines the protocol by which handlers interact with the Pyro runtime stack and modify the behavior of stochastic functions.
The base Messenger class implements:
- Context manager protocol -- __enter__ pushes the messenger onto _PYRO_STACK and raises ValueError if the same instance is already installed; __exit__ removes it. If the block exits with an exception, __exit__ removes the messenger together with every messenger installed after it.
- Callable protocol -- __call__(fn) wraps a function so that each call executes within the messenger's context.
- Message processing -- _process_message(msg) dispatches to _pyro_{type} methods (e.g., _pyro_sample, _pyro_param); _postprocess_message(msg) dispatches to _pyro_post_{type} methods. Message types with no matching method are ignored.
- Dynamic registration -- the register() and unregister() class methods add and remove operations on messenger classes at runtime.
The module also provides:
- _bound_partial -- A helper class that supports using class methods as handler arguments.
- unwrap(fn) -- Recursively unwraps poutine-wrapped functions to find the original callable.
- block_messengers(predicate) -- An experimental context manager for temporarily removing matching messengers from the stack.
Code Reference
Messenger Class
```python
from functools import partial

from pyro.poutine.runtime import _PYRO_STACK

# Excerpt: _bound_partial and _context_wrap are helpers defined in the same module.


class Messenger:
    def __call__(self, fn):
        if not callable(fn):
            raise ValueError(f"{fn!r} is not callable, ...")
        # Each call of the wrapper runs fn inside `with self:`.
        wraps = _bound_partial(partial(_context_wrap, self, fn))
        return wraps

    def __enter__(self):
        if self not in _PYRO_STACK:
            _PYRO_STACK.append(self)
            return self
        else:
            raise ValueError("cannot install a Messenger instance twice")

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            # Normal exit: this messenger must be the most recently installed one.
            if _PYRO_STACK[-1] == self:
                _PYRO_STACK.pop()
            else:
                raise ValueError("This Messenger is not on the bottom of the stack")
        else:
            # Exception: remove this messenger and everything installed after it.
            if self in _PYRO_STACK:
                loc = _PYRO_STACK.index(self)
                for i in range(loc, len(_PYRO_STACK)):
                    _PYRO_STACK.pop()

    def _process_message(self, msg):
        # Dispatch to e.g. self._pyro_sample(msg) if the subclass defines it.
        method = getattr(self, f"_pyro_{msg['type']}", None)
        if method is not None:
            method(msg)

    def _postprocess_message(self, msg):
        method = getattr(self, f"_pyro_post_{msg['type']}", None)
        if method is not None:
            method(msg)

    def _reset(self):
        pass
```
Dynamic Registration
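```python
@classmethod
def register(cls, fn=None, type=None, post=None):
    """Dynamically add operations to an effect."""
    if fn is None:
        # Called as @SomeMessenger.register(type=...): return a decorator.
        return lambda x: cls.register(x, type=type, post=post)
    if type is None:
        raise ValueError("An operation type name must be provided")
    setattr(cls, "_pyro_" + ("post_" if post else "") + type, staticmethod(fn))
    return fn

@classmethod
def unregister(cls, fn=None, type=None):
    """Dynamically remove operations from an effect."""
    if type is None:
        raise ValueError("An operation type name must be provided")
    for name in ("_pyro_post_" + type, "_pyro_" + type):
        try:
            delattr(cls, name)
        except AttributeError:
            pass  # the class need not define both hooks
    return fn
```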
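A Pyro-free sketch of registration and dispatch working together follows. ToyMessenger and the "my_op" message type are hypothetical. The toy register assumes the decorator form (fn omitted) is supported, as the @MyMessenger.register(type="sample") example under Usage Examples requires. Because the handler is stored as a staticmethod, it receives only msg.

```python
class ToyMessenger:
    """Toy class reusing the dispatch and registration logic described above."""

    def _process_message(self, msg):
        method = getattr(self, f"_pyro_{msg['type']}", None)
        if method is not None:
            method(msg)

    @classmethod
    def register(cls, fn=None, type=None, post=None):
        if fn is None:
            # Decorator form: @ToyMessenger.register(type=...)
            return lambda f: cls.register(f, type=type, post=post)
        if type is None:
            raise ValueError("An operation type name must be provided")
        setattr(cls, "_pyro_" + ("post_" if post else "") + type, staticmethod(fn))
        return fn

    @classmethod
    def unregister(cls, fn=None, type=None):
        if type is None:
            raise ValueError("An operation type name must be provided")
        for name in ("_pyro_post_" + type, "_pyro_" + type):
            try:
                delattr(cls, name)
            except AttributeError:
                pass  # this class never defined that hook
        return fn


@ToyMessenger.register(type="my_op")
def double_it(msg):
    msg["value"] = 2 * msg["value"]


msg = {"type": "my_op", "value": 21}
ToyMessenger()._process_message(msg)
after_register = msg["value"]  # 42: dispatched to the registered handler

ToyMessenger.unregister(type="my_op")
msg = {"type": "my_op", "value": 21}
ToyMessenger()._process_message(msg)
after_unregister = msg["value"]  # 21: no handler, message left unchanged

print(after_register, after_unregister)  # 42 21
```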
block_messengers
```python
from contextlib import contextmanager

from pyro.poutine.runtime import _PYRO_STACK


@contextmanager
def block_messengers(predicate):
    """
    EXPERIMENTAL Context manager to temporarily remove matching messengers
    from _PYRO_STACK. Does not call __exit__() and __enter__().
    """
    blocked = {}
    try:
        for i, messenger in enumerate(_PYRO_STACK):
            if predicate(messenger):
                blocked[i] = messenger
                _PYRO_STACK[i] = Messenger()  # trivial messenger
        yield list(blocked.values())
    finally:
        # Put every blocked messenger back at its original position.
        for i, messenger in blocked.items():
            _PYRO_STACK[i] = messenger
```
Helper Functions
```python
from functools import partial
from typing import Callable


def unwrap(fn: Callable) -> Callable:
    """Recursively unwraps poutines to find the original callable."""
    while True:
        if isinstance(fn, _bound_partial):
            fn = fn.func
            continue
        if isinstance(fn, partial) and len(fn.args) >= 2:
            fn = fn.args[1]  # partial(_context_wrap, messenger, fn) -> fn
            continue
        return fn
```
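The short demo below copies unwrap onto toy wrappers. Assumption: each hypothetical handler layer is a plain partial(_context_wrap, messenger, fn), the inner partial that Messenger.__call__ builds before wrapping it in _bound_partial, so only the second branch of unwrap is exercised; _bound_partial is an empty placeholder subclass so the isinstance check resolves, and the string "messengers" are placeholders.

```python
from functools import partial


class _bound_partial(partial):
    """Placeholder subclass so the isinstance check in unwrap() resolves."""


def unwrap(fn):
    """Copy of the unwrap() loop shown above."""
    while True:
        if isinstance(fn, _bound_partial):
            fn = fn.func
            continue
        if isinstance(fn, partial) and len(fn.args) >= 2:
            fn = fn.args[1]
            continue
        return fn


def _context_wrap(messenger, fn, *args, **kwargs):
    # Toy version: ignores the messenger and just calls fn.
    return fn(*args, **kwargs)


def model(x):
    return x + 1


# Two nested toy "handlers", each shaped like partial(_context_wrap, messenger, fn).
wrapped = partial(_context_wrap, "outer", partial(_context_wrap, "inner", model))
print(wrapped(1))  # 2
print(unwrap(wrapped) is model)  # True
```

Because the second branch treats any partial with at least two positional arguments as a handler wrapper, unwrap(partial(pow, 2, 3)) would return 3 rather than pow; the function is meant for poutine-wrapped callables only.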
I/O Contract
| Method | Input | Output |
|---|---|---|
| __call__(fn) | A callable fn | A wrapped callable that executes fn within this messenger's context (raises ValueError if fn is not callable) |
| __enter__() | (none) | self, after pushing it onto _PYRO_STACK (raises ValueError if this instance is already installed) |
| __exit__(exc_type, exc_value, traceback) | Exception info (or all None) | None; pops self from _PYRO_STACK, or on an exception removes self and every messenger installed after it |
| _process_message(msg) | A Message dict | None (dispatches to the _pyro_{type} method, if defined) |
| _postprocess_message(msg) | A Message dict | None (dispatches to the _pyro_post_{type} method, if defined) |
| register(fn, type, post) | A function, type string, and optional post flag | The registered function (also sets class attribute) |
| unwrap(fn) | A possibly wrapped callable | The original unwrapped callable |
Usage Examples
Subclassing Messenger
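```python
import pyro
import pyro.distributions as dist
from pyro.poutine.messenger import Messenger


class MyMessenger(Messenger):
    def _pyro_sample(self, msg):
        # Called for every sample site before its value is drawn
        print(f"Sampling {msg['name']} from {msg['fn']}")

    def _pyro_post_sample(self, msg):
        # Called after the site's value has been set
        print(f"Sampled {msg['name']} = {msg['value']}")


with MyMessenger():
    z = pyro.sample("z", dist.Normal(0, 1))
```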
Dynamic Registration
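```python
# Sets MyMessenger._pyro_sample to staticmethod(custom_sample_handler),
# replacing the method defined above for all MyMessenger instances.
@MyMessenger.register(type="sample")
def custom_sample_handler(msg):
    # Custom sample handling logic; receives only msg, not the messenger
    pass
```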
Temporarily Blocking Messengers
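```python
import pyro
import pyro.distributions as dist
from pyro.poutine.messenger import block_messengers
from pyro.poutine.scale_messenger import ScaleMessenger
with pyro.poutine.scale(scale=2.0):
    with block_messengers(lambda m: isinstance(m, ScaleMessenger)) as blocked:
        # The enclosing ScaleMessenger is temporarily replaced by a trivial Messenger
        z = pyro.sample("z", dist.Normal(0, 1))
```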
Related Pages
- Pyro_ppl_Pyro_Poutine_Runtime -- The _PYRO_STACK and apply_stack that invoke messenger methods
- Pyro_ppl_Pyro_Poutine_Handlers -- Factory functions that instantiate and compose messengers
- Pyro_ppl_Pyro_BlockMessenger -- Example subclass: blocks messages by setting msg["stop"]
- Pyro_ppl_Pyro_ConditionMessenger -- Example subclass: conditions sample sites
- Pyro_ppl_Pyro_PlateMessenger -- Example subclass: implements plate semantics