

Implementation:Pyro ppl Pyro EscapeMessenger

From Leeroopedia


File: pyro/poutine/escape_messenger.py
Module: pyro.poutine.escape_messenger
Lines: 43
Parent class: Messenger (pyro.poutine.messenger)
Purpose: Perform nonlocal exits from poutine execution at specific sample sites
License: Apache-2.0 (Uber Technologies, Inc.)

Overview

EscapeMessenger performs a nonlocal exit from poutine execution by raising a NonlocalExit exception when a user-specified predicate evaluates to True at a sample site.

This handler is used primarily by poutine.queue, which wraps a model in trace, escape, and replay handlers to enumerate discrete variables sequentially. When escape_fn returns True at a sample site, _pyro_sample does three things:

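  1. It marks the message as done (msg["done"] = True), so no default sample is drawn for the site.
  2. It stops the message (msg["stop"] = True), so handlers further out in the stack, such as an enclosing trace, never see it.
  3. It sets msg["continuation"] to a function that raises NonlocalExit(msg). The runtime calls this continuation after processing the message, which aborts the model at that site.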

The NonlocalExit exception carries the escaped site's message in its site attribute, so the caller (e.g., poutine.queue) can keep the partial trace recorded before the exit and extend it with values for that site.
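The three steps can be made concrete with a minimal pure-Python sketch of the message protocol. Everything in it (Escape, ToyTrace, ToyEscape, apply_stack, sample, demo) is hypothetical and only mimics Pyro's runtime; in particular, real Pyro handlers also have a postprocessing pass, which is omitted here. It shows that a trace handler sitting outside the escape handler never records the escaped site, and that the site carried by the exception has no value.

```python
# Toy model of the effect-handler message protocol (hypothetical names,
# NOT Pyro's API). Handlers are consulted innermost-first, "stop" ends that
# walk, "done" suppresses the default sampler, and the continuation runs last.
import random


class Escape(Exception):
    """Stand-in for pyro.poutine.runtime.NonlocalExit; carries the site message."""

    def __init__(self, site):
        super().__init__(site["name"])
        self.site = site


class ToyTrace:
    """Records every site message that reaches it (like a trace handler)."""

    def __init__(self):
        self.sites = {}

    def process(self, msg):
        self.sites[msg["name"]] = msg


class ToyEscape:
    """Mirrors EscapeMessenger._pyro_sample: mark done/stop, then exit."""

    def __init__(self, escape_fn):
        self.escape_fn = escape_fn

    def process(self, msg):
        if self.escape_fn(msg):
            msg["done"] = True
            msg["stop"] = True

            def cont(m):
                raise Escape(m)

            msg["continuation"] = cont


def apply_stack(stack, msg):
    # stack[-1] is the innermost handler, so walk the stack in reverse.
    for handler in reversed(stack):
        handler.process(msg)
        if msg["stop"]:
            break  # outer handlers never see this message
    if not msg["done"]:
        msg["value"] = msg["fn"]()  # default behaviour: draw a sample
    if msg["continuation"] is not None:
        msg["continuation"](msg)  # may raise and abort the model
    return msg["value"]


def sample(stack, name, fn):
    msg = {"name": name, "fn": fn, "value": None,
           "done": False, "stop": False, "continuation": None}
    return apply_stack(stack, msg)


def demo():
    trace = ToyTrace()
    # The trace is outermost; the escape handler is innermost (stack[-1]).
    stack = [trace, ToyEscape(lambda msg: msg["name"] == "z")]
    sample(stack, "x", random.random)
    try:
        sample(stack, "z", lambda: random.randint(0, 1))
    except Escape as exc:
        return trace, exc.site
    raise AssertionError("expected a nonlocal exit at 'z'")
```

Calling demo() returns a trace that contains only "x", together with the message for "z", whose done and stop flags are set and whose value is still None.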

Code Reference

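from typing import Callable

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message, NonlocalExit


class EscapeMessenger(Messenger):
    def __init__(self, escape_fn: Callable[[Message], bool]) -> None:
        super().__init__()
        # Predicate deciding, per sample site, whether to exit
        self.escape_fn = escape_fn

    def _pyro_sample(self, msg: Message) -> None:
        if self.escape_fn(msg):
            # Skip the default sampler and hide the site from outer handlers
            msg["done"] = True
            msg["stop"] = True

            # Called by the runtime once the message has been processed
            def cont(m: Message) -> None:
                raise NonlocalExit(m)

            msg["continuation"] = cont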

I/O Contract

Parameters:
  escape_fn (Callable[[Message], bool]): predicate that returns True to trigger a nonlocal exit at the given site

Message fields set when escape_fn returns True:
  msg["done"]: True, so no default sample is drawn for the site
  msg["stop"]: True, so handlers further out in the stack do not process the message
  msg["continuation"]: a function that raises NonlocalExit(msg) after the message has been processed

Exception raised: NonlocalExit (defined in pyro.poutine.runtime), carrying the site message.

Usage Examples

Using with poutine.queue for Enumeration

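# EscapeMessenger is primarily used internally by poutine.queue
# for sequential enumeration over discrete variables. For one partial
# trace taken from the queue, poutine.queue does roughly the following.

import functools

import pyro.poutine as poutine
from pyro.poutine.runtime import NonlocalExit
from pyro.poutine.util import discrete_escape, enum_extend

# Assumed to be defined by the caller: model, args, kwargs, a queue object
# (e.g. queue.Queue) and partial_trace, a Trace taken from that queue.
extend_fn = enum_extend  # poutine.queue's default extend_fn

try:
    traced = poutine.trace(
        poutine.escape(
            poutine.replay(model, trace=partial_trace),
            escape_fn=functools.partial(discrete_escape, partial_trace),
        )
    )
    result = traced(*args, **kwargs)  # no new discrete site: the run completed
except NonlocalExit as site_container:
    # Reset the handlers still on the stack so the model can be re-run
    site_container.reset_stack()
    # One extended trace per value of the escaped site, added to the queue
    for tr in extend_fn(traced.trace.copy(), site_container.site):
        queue.put(tr)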
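The replay, escape, and extend loop can be illustrated without Pyro. The sketch below is a hypothetical toy (NewSite, run, enumerate_model and the choose callback are not Pyro APIs): a partial trace is a plain dict of already-chosen values, replay means looking values up in it, and escaping means raising at the first name missing from it. Catching the exception and enqueuing one extended dict per value in the site's support enumerates every execution path, including paths that visit different sets of sites.

```python
from collections import deque


class NewSite(Exception):
    """Hypothetical nonlocal exit carrying the unseen site's name and support."""

    def __init__(self, name, support):
        super().__init__(name)
        self.name = name
        self.support = support


def run(model, partial):
    """Replay values in `partial`; escape at the first choice not yet in it."""

    def choose(name, support):
        if name in partial:
            return partial[name]  # replay a recorded value
        raise NewSite(name, support)  # nonlocal exit at the new site

    return model(choose)


def enumerate_model(model):
    queue = deque([{}])  # start from an empty partial trace
    complete = []
    while queue:
        partial = queue.popleft()
        try:
            result = run(model, partial)
        except NewSite as exc:
            # Extend the partial trace once per value of the escaped site.
            for value in exc.support:
                queue.append({**partial, exc.name: value})
        else:
            complete.append((partial, result))
    return complete


def model(choose):
    a = choose("a", [0, 1])
    # "b" only exists on the a == 1 branch.
    b = choose("b", [0, 1, 2]) if a == 1 else 0
    return a + b
```

A FIFO deque gives breadth-first order here; poutine.queue instead takes the queue object as an argument, so the caller decides the order in which partial traces are expanded.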

Custom Escape Function

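import pyro.poutine as poutine

# Escape at the first unobserved discrete sample site.
# Every torch Distribution defines enumerate_support (continuous ones raise
# NotImplementedError), so check the has_enumerate_support flag instead.
def my_escape_fn(msg):
    return (msg["type"] == "sample"
            and not msg["is_observed"]
            and getattr(msg["fn"], "has_enumerate_support", False))

# model: any Pyro model callable, assumed to be defined elsewhere
escaped_model = poutine.escape(model, escape_fn=my_escape_fn)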
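When the predicate is combined with replay, as in poutine.queue, it also has to skip sites whose values are already in the partial trace; otherwise it would exit at the first replayed discrete site on every run. Because escape_fn receives only the message, extra state such as the partial trace is bound with functools.partial. The sketch below uses plain dicts and hypothetical stand-in classes (escape_new_discrete, FakeDiscrete, FakeContinuous, make_msg) instead of real Pyro messages and distributions; Pyro's own pyro.poutine.util.discrete_escape performs essentially the same checks.

```python
import functools


def escape_new_discrete(partial_trace, msg):
    """Hypothetical predicate: escape at unobserved discrete sites not yet replayed."""
    return (msg["type"] == "sample"
            and not msg["is_observed"]
            and msg["name"] not in partial_trace
            and getattr(msg["fn"], "has_enumerate_support", False))


# Hypothetical stand-ins for distributions; only the flag matters here.
class FakeDiscrete:
    has_enumerate_support = True


class FakeContinuous:
    has_enumerate_support = False


def make_msg(name, fn, is_observed=False):
    # Only the Message keys the predicate reads are filled in.
    return {"type": "sample", "name": name, "fn": fn, "is_observed": is_observed}


# Bind the partial trace (here a dict keyed by site name) so the result
# matches the Callable[[Message], bool] signature escape_fn expects.
escape_fn = functools.partial(escape_new_discrete, {"z1": None})
```

With this binding, escape_fn fires for a new discrete site such as "z2" but not for the already-recorded "z1", an observed site, or a continuous site.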
