Implementation:Pyro ppl Pyro EscapeMessenger
| Attribute | Value |
|---|---|
| File | `pyro/poutine/escape_messenger.py` |
| Module | `pyro.poutine.escape_messenger` |
| Lines | 43 |
| Parent Class | `Messenger` |
| Purpose | Perform nonlocal exits from poutine execution at specific sample sites |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
EscapeMessenger performs a nonlocal exit from poutine execution by raising a NonlocalExit exception when a user-specified predicate evaluates to True at a sample site.
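This handler is used primarily in the `poutine.queue` composite operation for sequential enumeration over discrete variables. When the escape condition is met at a sample site:
- The message is marked as done (`msg["done"] = True`), so the default sampling step is skipped.
- The message is stopped (`msg["stop"] = True`), so handlers further out in the stack do not process it.
- A continuation function that raises `NonlocalExit(msg)` is stored in `msg["continuation"]`; it is called after the message has been processed by the handler stack.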
The `NonlocalExit` exception carries the site message in its `site` attribute, so the caller (e.g., `poutine.queue`) can catch it, take the partial trace recorded so far, and extend it with values for the escaped site.
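The three effects can be seen in a dependency-free toy. The sketch below is a simplified model of the handler-stack protocol, assuming (as in `pyro.poutine.runtime.apply_stack`) that handlers are visited innermost first until one sets `msg["stop"]`, that default sampling is skipped when `msg["done"]` is set, that only the visited handlers post-process the message, and that `msg["continuation"]` runs last. `ToyEscape`, `ToyTrace`, `toy_apply_stack`, and `sample` are hypothetical names, not Pyro APIs, and the toy "sampler" simply takes the first element of a list.

```python
from typing import Any, Callable, Dict, List, Optional

Message = Dict[str, Any]


class NonlocalExit(Exception):
    # Toy stand-in for pyro.poutine.runtime.NonlocalExit: stores the site message
    def __init__(self, site: Message) -> None:
        super().__init__()
        self.site = site


class ToyEscape:
    # Hypothetical handler mirroring EscapeMessenger._pyro_sample
    def __init__(self, escape_fn: Callable[[Message], bool]) -> None:
        self.escape_fn = escape_fn

    def process(self, msg: Message) -> None:
        if self.escape_fn(msg):
            msg["done"] = True  # skip default sampling
            msg["stop"] = True  # outer handlers will not process this message

            def cont(m: Message) -> None:
                raise NonlocalExit(m)

            msg["continuation"] = cont

    def postprocess(self, msg: Message) -> None:
        pass


class ToyTrace:
    # Hypothetical handler that records each site during post-processing
    def __init__(self) -> None:
        self.sites: Dict[str, Message] = {}

    def process(self, msg: Message) -> None:
        pass

    def postprocess(self, msg: Message) -> None:
        self.sites[msg["name"]] = msg


def toy_apply_stack(stack: List[Any], msg: Message) -> None:
    # stack[0] is the outermost handler, stack[-1] the innermost
    pointer = 0
    for frame in reversed(stack):
        pointer += 1
        frame.process(msg)
        if msg["stop"]:
            break
    if not msg["done"]:
        msg["value"] = msg["support"][0]  # toy "sampler"
    for frame in stack[-pointer:]:  # only the frames that were visited
        frame.postprocess(msg)
    cont: Optional[Callable[[Message], None]] = msg["continuation"]
    if cont is not None:
        cont(msg)  # the escape continuation raises here


def sample(stack: List[Any], name: str, support: List[int]) -> Any:
    msg: Message = {"name": name, "support": support, "value": None,
                    "done": False, "stop": False, "continuation": None}
    toy_apply_stack(stack, msg)
    return msg["value"]


# Equivalent of trace(escape(model)): tracer outermost, escape innermost
tracer = ToyTrace()
stack: List[Any] = [tracer, ToyEscape(lambda m: m["name"] == "b")]
escaped: Optional[Message] = None
try:
    sample(stack, "a", [0, 1])
    sample(stack, "b", [0, 1])
    sample(stack, "c", [0, 1])
except NonlocalExit as exit_:
    escaped = exit_.site
```

Because `ToyTrace` sits outside the escape handler, it records only `a`: the escaped site `b` is stopped before the tracer sees it, and `c` is never reached. The caught exception's `site` holds the unsampled message for `b`.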
Code Reference
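```python
from typing import Callable

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message, NonlocalExit


class EscapeMessenger(Messenger):
    def __init__(self, escape_fn: Callable[[Message], bool]) -> None:
        super().__init__()
        self.escape_fn = escape_fn  # predicate deciding where to exit

    def _pyro_sample(self, msg: Message) -> None:
        if self.escape_fn(msg):
            msg["done"] = True  # skip default sampling at this site
            msg["stop"] = True  # outer handlers do not process this site

            # apply_stack calls the continuation after the stack has been walked
            def cont(m: Message) -> None:
                raise NonlocalExit(m)

            msg["continuation"] = cont
```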
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| `escape_fn` | `Callable[[Message], bool]` | Predicate evaluated at each sample site; returning True triggers a nonlocal exit at that site |
| Message Effect | Description |
|---|---|
| `msg["done"]` | Set to True when the escape condition is met, so the default sampling step is skipped |
| `msg["stop"]` | Set to True so that handlers further out in the stack do not process the message |
| `msg["continuation"]` | Set to a function that raises `NonlocalExit(msg)` |
Exception raised: `NonlocalExit` (defined in `pyro.poutine.runtime`), carrying the site message in its `site` attribute.
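Any callable matching `Callable[[Message], bool]` can serve as `escape_fn`. The sketch below exercises two hypothetical predicates on plain dicts standing in for Pyro messages (only the keys `type`, `name`, `fn`, and `is_observed` are modeled). `escape_at` exits at a named site; `discrete_escape_like` is modeled on the default check used by `poutine.queue`: an unobserved sample site whose distribution has `has_enumerate_support` set and whose name is not yet in the partial trace. `FakeDiscrete` and `FakeContinuous` are placeholders for real distributions.

```python
from functools import partial
from typing import Any, Callable, Container, Dict

Message = Dict[str, Any]
EscapeFn = Callable[[Message], bool]


def escape_at(name: str) -> EscapeFn:
    # Hypothetical factory: exit at the sample site with the given name
    def escape_fn(msg: Message) -> bool:
        return msg["type"] == "sample" and msg["name"] == name
    return escape_fn


def discrete_escape_like(partial_trace: Container[str], msg: Message) -> bool:
    # Modeled on discrete_escape: unobserved, enumerable, and not yet in the trace
    return (
        msg["type"] == "sample"
        and not msg["is_observed"]
        and msg["name"] not in partial_trace
        and getattr(msg["fn"], "has_enumerate_support", False)
    )


class FakeDiscrete:
    has_enumerate_support = True  # placeholder for e.g. a Bernoulli


class FakeContinuous:
    has_enumerate_support = False  # placeholder for e.g. a Normal


def make_msg(name: str, fn: Any, is_observed: bool = False) -> Message:
    return {"type": "sample", "name": name, "fn": fn, "is_observed": is_observed}


# Bind a partial trace that already contains "z1"
escape_fn = partial(discrete_escape_like, {"z1"})
```

Binding the partial trace with `functools.partial` turns the two-argument check into a one-argument predicate, the same pattern `poutine.queue` uses with `discrete_escape`.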
Usage Examples
Using with poutine.queue for Enumeration
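```python
# EscapeMessenger is primarily used internally by poutine.queue
# for sequential enumeration over discrete variables. This mirrors one
# iteration of that loop; `model`, `args`, and `kwargs` are assumed to exist.
import functools
import queue

import pyro.poutine as poutine
from pyro.poutine.runtime import NonlocalExit
from pyro.poutine.util import discrete_escape, enum_extend

q = queue.Queue()
q.put(poutine.Trace())  # start from an empty partial trace

partial_trace = q.get()
traced = poutine.trace(
    poutine.escape(
        poutine.replay(model, trace=partial_trace),
        escape_fn=functools.partial(discrete_escape, partial_trace),
    )
)
try:
    result = traced(*args, **kwargs)  # no escape: the trace is complete
except NonlocalExit as site_container:
    # Reset handler state left behind by the interrupted run
    site_container.reset_stack()
    # Extend the partial trace with each value of the escaped site's support
    for tr in enum_extend(traced.trace.copy(), site_container.site):
        q.put(tr)
```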
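The snippet above performs a single iteration. The dependency-free sketch below runs the whole loop on a toy model with two binary variables, to show how repeated replay, escape, and extension visits every complete execution path. It assumes a FIFO queue (so the search is breadth-first) and dict-based traces; `run_with_replay`, `toy_queue_enumerate`, and the toy `NonlocalExit` are hypothetical simplifications, not Pyro code. Unlike `poutine.queue`, which returns the model's return value from one complete trace per call, it collects all complete traces in a list.

```python
from collections import deque
from typing import Callable, Deque, Dict, List, Sequence

Trace = Dict[str, int]


class NonlocalExit(Exception):
    # Toy stand-in: carries the name and support of the escaped site
    def __init__(self, name: str, support: Sequence[int]) -> None:
        super().__init__(name)
        self.name = name
        self.support = support


def run_with_replay(model: Callable, partial: Trace) -> Trace:
    # Replay values already in `partial`; escape at the first site that is not
    trace: Trace = {}

    def sample(name: str, support: Sequence[int]) -> int:
        if name not in partial:
            raise NonlocalExit(name, support)  # escape: site not yet enumerated
        trace[name] = partial[name]  # replay: reuse the recorded value
        return trace[name]

    model(sample)
    return trace


def toy_queue_enumerate(model: Callable, max_tries: int = 1000) -> List[Trace]:
    # Pop a partial trace, run it, and extend it whenever the run escapes
    queue: Deque[Trace] = deque([{}])
    complete: List[Trace] = []
    tries = 0
    while queue:
        tries += 1
        if tries > max_tries:
            raise ValueError("max tries exceeded")
        partial = queue.popleft()
        try:
            complete.append(run_with_replay(model, partial))
        except NonlocalExit as exit_:
            # One new partial trace per value in the escaped site's support
            for value in exit_.support:
                queue.append({**partial, exit_.name: value})
    return complete


def model(sample: Callable) -> int:
    a = sample("a", [0, 1])
    # The support of "b" depends on a, so the trace structure is dynamic
    b = sample("b", [0, 1] if a else [0])
    return a + b
```

Calling `toy_queue_enumerate(model)` yields the three complete traces `{"a": 0, "b": 0}`, `{"a": 1, "b": 0}`, and `{"a": 1, "b": 1}`; swapping the FIFO for a LIFO would change only the visiting order.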
Custom Escape Function
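```python
import pyro.poutine as poutine

# Escape at the first unobserved sample site whose distribution supports
# enumeration. Unlike discrete_escape, this takes no partial-trace argument,
# so it is meant for a standalone poutine.escape, not as poutine.queue's escape_fn.
def my_escape_fn(msg):
    return (msg["type"] == "sample"
            and not msg["is_observed"]
            and getattr(msg["fn"], "has_enumerate_support", False))

# Calling escaped_model(...) raises NonlocalExit at the first matching site, if any
escaped_model = poutine.escape(model, escape_fn=my_escape_fn)  # `model` assumed
```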
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Runtime -- Defines the `NonlocalExit` exception and continuation handling in `apply_stack`
- Pyro_ppl_Pyro_Poutine_Handlers -- The `poutine.escape()` factory and the `poutine.queue()` composite operation
- Pyro_ppl_Pyro_BlockMessenger -- `NonlocalExit.reset_stack()` resets frames until it reaches a `BlockMessenger` that hides the escaped site