Implementation:Pyro ppl Pyro BlockMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/block_messenger.py |
| Module | pyro.poutine.block_messenger |
| Lines | 170 |
| Parent Class | Messenger |
| Purpose | Selectively hide Pyro primitive sites from outer effect handlers |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
BlockMessenger selectively hides Pyro primitive sites from the outside world by setting msg["stop"] = True on messages that match the hiding criteria. Once a message is stopped, it is not passed on to any handler that encloses this one, so outer handlers (for example an enclosing trace) never process it.
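To make the stop mechanism concrete, here is a toy, Pyro-free model of the dispatch loop. It assumes, as we understand Pyro's apply_stack to work, that handlers are visited from innermost to outermost and that the loop breaks as soon as msg["stop"] is set. ToyTrace, ToyBlock and toy_apply_stack are hypothetical names used only for this sketch.

```python
from typing import Any, Callable, Dict, List

Message = Dict[str, Any]


class ToyTrace:
    """Hypothetical stand-in for a trace handler: records every site it sees."""

    def __init__(self) -> None:
        self.seen: List[str] = []

    def process_message(self, msg: Message) -> None:
        self.seen.append(msg["name"])


class ToyBlock:
    """Hypothetical stand-in for BlockMessenger: sets msg["stop"] from a predicate."""

    def __init__(self, hide_fn: Callable[[Message], bool]) -> None:
        self.hide_fn = hide_fn

    def process_message(self, msg: Message) -> None:
        msg["stop"] = bool(self.hide_fn(msg))


def toy_apply_stack(stack: List[Any], msg: Message) -> None:
    # stack[0] is the outermost handler and stack[-1] the innermost,
    # so iterate in reverse to visit the innermost handler first.
    for handler in reversed(stack):
        handler.process_message(msg)
        if msg["stop"]:
            break  # handlers further out never see this message


# Layout mirrors trace(block(trace(fn), hide=["a"])): outer trace, block, inner trace.
outer, inner = ToyTrace(), ToyTrace()
stack = [outer, ToyBlock(lambda m: m["name"] == "a"), inner]
for name in ["a", "b"]:
    toy_apply_stack(stack, {"name": name, "stop": False})

print(inner.seen)  # ['a', 'b']
print(outer.seen)  # ['b']
```

The toy leaves out two steps that, to our knowledge, Pyro still performs for a stopped message: the default processing that actually computes the sample value, and the postprocessing pass over the handlers that did see the message. Blocking therefore hides a site from outer handlers without skipping its execution.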
Default behavior is to block everything. Sites can be selectively hidden or exposed based on:
- hide_fn / expose_fn -- Custom callable predicates
- hide / expose -- Lists of site names
- hide_types / expose_types -- Lists of site types (e.g., "sample", "param", "observe")
- hide_all / expose_all -- Boolean flags for blanket hiding/exposing
The module also includes helper functions _block_fn (the decision logic), _make_default_hide_fn (builds the hide predicate from parameters), and _negate_fn (converts an expose function into a hide function).
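The sketch below approximates how the default predicate can be assembled from the list and flag arguments with functools.partial. It follows our reading of Pyro's helpers, including one detail the parameter list does not make obvious: passing hide or hide_types switches hide_all off, and passing expose or expose_types switches it on. The names block_fn and make_default_hide_fn are hypothetical stand-ins for _block_fn and _make_default_hide_fn, and the sanity checks are simplified; treat this as an illustration rather than a copy of the library source.

```python
from functools import partial
from typing import Any, Callable, Dict, List, Optional

Message = Dict[str, Any]


def block_fn(expose: List[str], expose_types: List[str], hide: List[str],
             hide_types: List[str], hide_all: bool, msg: Message) -> bool:
    # Observed sample sites are matched under the pseudo-type "observe".
    if msg["type"] == "sample" and msg["is_observed"]:
        msg_type = "observe"
    else:
        msg_type = msg["type"]
    is_not_exposed = msg["name"] not in expose and msg_type not in expose_types
    return msg["name"] in hide or msg_type in hide_types or (is_not_exposed and hide_all)


def make_default_hide_fn(
    hide_all: bool = True,
    expose_all: bool = False,
    hide: Optional[List[str]] = None,
    expose: Optional[List[str]] = None,
    hide_types: Optional[List[str]] = None,
    expose_types: Optional[List[str]] = None,
) -> Callable[[Message], bool]:
    # The two blanket flags may not both be set.
    assert not (hide_all and expose_all), "cannot hide and expose a site"
    # Naming sites/types to hide means "hide only these": turn hide_all off.
    # Naming sites/types to expose means "hide everything else": turn it on.
    if hide is None:
        hide = []
    else:
        hide_all = False
    if expose is None:
        expose = []
    else:
        hide_all = True
    if hide_types is None:
        hide_types = []
    else:
        hide_all = False
    if expose_types is None:
        expose_types = []
    else:
        hide_all = True
    assert set(hide).isdisjoint(expose), "cannot hide and expose a site"
    assert set(hide_types).isdisjoint(expose_types), "cannot hide and expose a site type"
    # Freeze the configuration; the result takes a single msg argument.
    return partial(block_fn, expose, expose_types, hide, hide_types, hide_all)


a = {"name": "a", "type": "sample", "is_observed": False}
b = {"name": "b", "type": "sample", "is_observed": False}

only_a_hidden = make_default_hide_fn(hide=["a"])
only_b_exposed = make_default_hide_fn(expose=["b"])
everything_hidden = make_default_hide_fn()

print(only_a_hidden(a), only_a_hidden(b))          # True False
print(only_b_exposed(a), only_b_exposed(b))        # True False
print(everything_hidden(a), everything_hidden(b))  # True True
```

Because the overrides are applied in a fixed order in this sketch, combining hide with expose_types, for instance, leaves hide_all switched on.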
Code Reference
BlockMessenger Class
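```python
from pyro.poutine.messenger import Messenger

# _negate_fn and _make_default_hide_fn are helpers defined in the same module.


class BlockMessenger(Messenger):
    def __init__(
        self,
        hide_fn=None, expose_fn=None,
        hide_all=True, expose_all=False,
        hide=None, expose=None,
        hide_types=None, expose_types=None,
    ):
        super().__init__()
        # hide_fn and expose_fn are mutually exclusive.
        if not (hide_fn is None or expose_fn is None):
            raise ValueError("Only specify one of hide_fn or expose_fn")
        if hide_fn is not None:
            self.hide_fn = hide_fn
        elif expose_fn is not None:
            # Turn "should this site be exposed?" into "should it be hidden?".
            self.hide_fn = _negate_fn(expose_fn)
        else:
            # Build the predicate from the name/type lists and the blanket flags.
            self.hide_fn = _make_default_hide_fn(
                hide_all, expose_all, hide, expose, hide_types, expose_types
            )

    def _process_message(self, msg):
        # A truthy predicate stops the message; False or None lets it continue.
        msg["stop"] = bool(self.hide_fn(msg))
```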
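The constructor gives hide_fn precedence, falls back to a negated expose_fn, and only then builds the default predicate. The toy class below shows that precedence and how the bool() call in _process_message treats a predicate that returns None. ToyBlockMessenger, negate_fn and maybe_hide are hypothetical names, and the default predicate is simplified to "hide everything" instead of being built from the list arguments.

```python
from typing import Any, Callable, Dict, Optional

Message = Dict[str, Any]
Predicate = Callable[[Message], Optional[bool]]


def negate_fn(fn: Predicate) -> Callable[[Message], bool]:
    # An expose predicate becomes a hide predicate: False or None means "hide".
    def negated(msg: Message) -> bool:
        return not fn(msg)

    return negated


class ToyBlockMessenger:
    def __init__(self, hide_fn: Optional[Predicate] = None,
                 expose_fn: Optional[Predicate] = None) -> None:
        if not (hide_fn is None or expose_fn is None):
            raise ValueError("Only specify one of hide_fn or expose_fn")
        if hide_fn is not None:
            self.hide_fn = hide_fn
        elif expose_fn is not None:
            self.hide_fn = negate_fn(expose_fn)
        else:
            # Simplified default for this sketch: block everything.
            self.hide_fn = lambda msg: True

    def process_message(self, msg: Message) -> None:
        # bool() maps a None return to False, i.e. "do not stop".
        msg["stop"] = bool(self.hide_fn(msg))


def run(handler: ToyBlockMessenger, name: str) -> bool:
    msg = {"name": name, "stop": False}
    handler.process_message(msg)
    return msg["stop"]


def maybe_hide(msg: Message) -> Optional[bool]:
    # Returns True for "a" and None (no opinion) for every other site.
    return True if msg["name"] == "a" else None


print(run(ToyBlockMessenger(hide_fn=maybe_hide), "b"))    # False: None -> not hidden
print(run(ToyBlockMessenger(expose_fn=maybe_hide), "b"))  # True: None -> not exposed
print(run(ToyBlockMessenger(), "b"))                      # True: default blocks all

try:
    ToyBlockMessenger(hide_fn=maybe_hide, expose_fn=maybe_hide)
except ValueError as err:
    print(err)  # Only specify one of hide_fn or expose_fn
```

The asymmetry is worth keeping in mind when writing predicates: a hide_fn that returns None exposes the site, while an expose_fn that returns None hides it.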
_block_fn Decision Logic
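```python
def _block_fn(expose, expose_types, hide, hide_types, hide_all, msg):
    # handle observes: an observed sample site is matched as type "observe"
    if msg["type"] == "sample" and msg["is_observed"]:
        msg_type = "observe"
    else:
        msg_type = msg["type"]

    # A site is exposed if its name or its effective type is listed.
    is_not_exposed = (msg["name"] not in expose) and (msg_type not in expose_types)

    # Hide if named in hide, typed in hide_types, or unexposed while hide_all is set.
    if (msg["name"] in hide) or (msg_type in hide_types) or (is_not_exposed and hide_all):
        return True
    else:
        return False
```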
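One consequence of the "handle observes" branch is easy to miss: an observed sample site is matched against the pseudo-type "observe", not "sample". The check below copies the decision logic above into a standalone function (renamed block_fn, with hand-built dictionaries standing in for Pyro messages) to show that expose_types=["sample"] with hide_all=True still hides observed sites.

```python
from typing import Any, Dict, List

Message = Dict[str, Any]


def block_fn(expose: List[str], expose_types: List[str], hide: List[str],
             hide_types: List[str], hide_all: bool, msg: Message) -> bool:
    # Same decision rule as _block_fn above.
    if msg["type"] == "sample" and msg["is_observed"]:
        msg_type = "observe"
    else:
        msg_type = msg["type"]
    is_not_exposed = (msg["name"] not in expose) and (msg_type not in expose_types)
    return (msg["name"] in hide) or (msg_type in hide_types) or (is_not_exposed and hide_all)


latent = {"name": "z", "type": "sample", "is_observed": False}
observed = {"name": "x", "type": "sample", "is_observed": True}
param = {"name": "loc", "type": "param", "is_observed": False}

# Expose "sample" sites and hide everything else.
for msg in (latent, observed, param):
    print(msg["name"], block_fn([], ["sample"], [], [], True, msg))
# z False
# x True    (observed, so its effective type is "observe")
# loc True
```

To keep observed sites visible under this configuration, "observe" has to be listed in expose_types alongside "sample".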
I/O Contract
| Parameter | Type | Description |
|---|---|---|
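| hide_fn | Callable[[Message], Optional[bool]] (default None) | Custom function returning True to hide a site; cannot be combined with expose_fn (raises ValueError) |
| expose_fn | Callable[[Message], Optional[bool]] (default None) | Custom function returning True to expose a site; any other return value, including None, hides it; cannot be combined with hide_fn |
| hide | List[str] or None (default None) | List of site names to hide; passing it switches hide_all off, so only the listed sites are hidden |
| expose | List[str] or None (default None) | List of site names to expose (all others hidden) |
| hide_types | List[str] or None (default None) | List of site types to hide; passing it switches hide_all off |
| expose_types | List[str] or None (default None) | List of site types to expose (all others hidden) |
| hide_all | bool (default True) | Hide all sites that are not explicitly exposed |
| expose_all | bool (default False) | Expose all sites |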
Effect on messages: sets msg["stop"] = bool(hide_fn(msg)), which is True for hidden sites and False otherwise; hidden sites are therefore never seen by outer handlers.
Usage Examples
Hide Specific Sites
```python
import pyro
import pyro.distributions as dist


def fn():
    a = pyro.sample("a", dist.Normal(0., 1.))
    return pyro.sample("b", dist.Normal(a, 1.))


fn_inner = pyro.poutine.trace(fn)
fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"]))

trace_inner = fn_inner.get_trace()
trace_outer = fn_outer.get_trace()

assert "a" in trace_inner      # the inner trace records every site
assert "a" not in trace_outer  # "a" is hidden from the outer trace
assert "b" in trace_outer      # "b" is not hidden
```
Expose Only Specific Sites
```python
# Only expose site "b", hiding everything else
blocked_fn = pyro.poutine.block(fn, expose=["b"])
```
Using Custom Hide Function
```python
# Hide all observed sites
blocked_fn = pyro.poutine.block(fn, hide_fn=lambda msg: msg["is_observed"])
```
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Handlers -- The poutine.block() factory function
- Pyro_ppl_Pyro_Poutine_Runtime -- apply_stack checks msg["stop"] to halt processing
- Pyro_ppl_Pyro_EscapeMessenger -- Uses BlockMessenger in NonlocalExit.reset_stack()