Implementation:Pyro ppl Pyro BlockMessenger

From Leeroopedia


Attribute      Value
File           pyro/poutine/block_messenger.py
Module         pyro.poutine.block_messenger
Lines          170
Parent Class   Messenger
Purpose        Selectively hide Pyro primitive sites from outer effect handlers
License        Apache-2.0 (Uber Technologies, Inc.)

Overview

BlockMessenger selectively hides Pyro primitive sites from outer effect handlers by setting msg["stop"] = True on every message that matches the hiding criteria. Once a message is stopped, no handler wrapped around this one (an outer handler, further up the stack) processes it, while handlers nested inside the block still see the site.
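The stopping behavior can be pictured with a toy handler stack. The sketch below is a simplified model, not Pyro's implementation: RecordingHandler, BlockingHandler, and apply_stack are hypothetical names, and messages are plain dicts carrying only "name" and "stop".

```python
class RecordingHandler:
    """Toy handler that records the name of every message it processes."""

    def __init__(self):
        self.seen = []

    def process(self, msg):
        self.seen.append(msg["name"])


class BlockingHandler:
    """Toy handler that sets msg["stop"] from a hide predicate."""

    def __init__(self, hide_fn):
        self.hide_fn = hide_fn

    def process(self, msg):
        msg["stop"] = bool(self.hide_fn(msg))


def apply_stack(stack, msg):
    # stack[0] is the outermost handler, stack[-1] the innermost.
    # Walk from innermost to outermost and halt once a handler stops the message.
    for handler in reversed(stack):
        handler.process(msg)
        if msg["stop"]:
            break
    return msg


outer = RecordingHandler()
inner = RecordingHandler()
block = BlockingHandler(lambda msg: msg["name"] == "a")

# The block sits between the two recorders.
stack = [outer, block, inner]
for name in ("a", "b"):
    apply_stack(stack, {"name": name, "stop": False})

print(inner.seen)  # ['a', 'b']
print(outer.seen)  # ['b']
```

The inner recorder sees both sites because it runs before the block; the outer recorder never sees "a" because the loop halts as soon as the block marks the message as stopped.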

By default, every site is blocked. Sites can be selectively hidden or exposed using:

  • hide_fn / expose_fn -- Custom callable predicates (at most one of the two may be given; when either is supplied, the remaining parameters are not consulted)
  • hide / expose -- Lists of site names
  • hide_types / expose_types -- Lists of site types (e.g., "sample", "param", or "observe", where "observe" denotes a sample site with is_observed set)
  • hide_all / expose_all -- Boolean flags for blanket hiding/exposing

The module also includes helper functions _block_fn (the decision logic), _make_default_hide_fn (builds the hide predicate from parameters), and _negate_fn (converts an expose function into a hide function).
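Turning an expose predicate into a hide predicate amounts to logical negation. The following is a minimal sketch of that idea, assuming the negation is a plain `not`; negate, expose_latents, and the example messages are hypothetical stand-ins rather than Pyro's own code.

```python
def negate(expose_fn):
    """Hypothetical stand-in for _negate_fn: hide whatever expose_fn does not expose."""

    def hide_fn(msg):
        return not expose_fn(msg)

    return hide_fn


def expose_latents(msg):
    # Expose only unobserved sample sites.
    return msg["type"] == "sample" and not msg["is_observed"]


hide_fn = negate(expose_latents)

latent = {"name": "z", "type": "sample", "is_observed": False}
observed = {"name": "x", "type": "sample", "is_observed": True}
param = {"name": "w", "type": "param", "is_observed": False}

decisions = {m["name"]: bool(hide_fn(m)) for m in (latent, observed, param)}
print(decisions)  # {'z': False, 'x': True, 'w': True}


def returns_none(msg):
    # A predicate that falls through without returning a value.
    return None


# Under this sketch, a None result from an expose predicate means "not exposed".
print(negate(returns_none)(latent))  # True
```

One consequence of negation worth keeping in mind: an expose predicate that returns None for a site ends up hiding it, whereas a hide predicate that returns None leaves the site visible, since the messenger applies bool() to the result.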

Code Reference

BlockMessenger Class

class BlockMessenger(Messenger):
    def __init__(
        self,
        hide_fn=None, expose_fn=None,
        hide_all=True, expose_all=False,
        hide=None, expose=None,
        hide_types=None, expose_types=None,
    ):
        super().__init__()
        if not (hide_fn is None or expose_fn is None):
            raise ValueError("Only specify one of hide_fn or expose_fn")
        if hide_fn is not None:
            self.hide_fn = hide_fn
        elif expose_fn is not None:
            self.hide_fn = _negate_fn(expose_fn)
        else:
            self.hide_fn = _make_default_hide_fn(
                hide_all, expose_all, hide, expose, hide_types, expose_types
            )

    def _process_message(self, msg):
        msg["stop"] = bool(self.hide_fn(msg))

_block_fn Decision Logic

def _block_fn(expose, expose_types, hide, hide_types, hide_all, msg):
    # handle observes
    if msg["type"] == "sample" and msg["is_observed"]:
        msg_type = "observe"
    else:
        msg_type = msg["type"]

    is_not_exposed = (msg["name"] not in expose) and (msg_type not in expose_types)

    if (msg["name"] in hide) or (msg_type in hide_types) or (is_not_exposed and hide_all):
        return True
    else:
        return False

I/O Contract

Parameter      Type                                  Description
hide_fn        Callable[[Message], Optional[bool]]   Custom function returning True to hide a site
expose_fn      Callable[[Message], Optional[bool]]   Custom function returning True to expose a site
hide           List[str] or None                     Site names to hide
expose         List[str] or None                     Site names to expose (all others hidden)
hide_types     List[str] or None                     Site types to hide
expose_types   List[str] or None                     Site types to expose (all others hidden)
hide_all       bool (default True)                   Hide every site that is not explicitly exposed
expose_all     bool (default False)                  Expose all sites
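How the list and flag parameters combine is decided by _make_default_hide_fn, which this page does not list. The sketch below is an assumption-based reconstruction rather than the library's code: it supposes that passing hide or hide_types switches blanket hiding off, while passing expose or expose_types switches it on. That assumption is consistent with the "Hide Specific Sites" example, where hiding "a" leaves "b" visible. The names make_hide_fn and decide are hypothetical.

```python
from functools import partial


def decide(expose, expose_types, hide, hide_types, hide_all, msg):
    # Decision rule mirroring _block_fn from the Code Reference section.
    if msg["type"] == "sample" and msg["is_observed"]:
        msg_type = "observe"
    else:
        msg_type = msg["type"]
    is_not_exposed = (msg["name"] not in expose) and (msg_type not in expose_types)
    return (msg["name"] in hide) or (msg_type in hide_types) or (is_not_exposed and hide_all)


def make_hide_fn(hide_all=True, hide=None, expose=None, hide_types=None, expose_types=None):
    # ASSUMPTION: an explicit hide list or hide-type list means "hide only these",
    # while an explicit expose list or expose-type list means "hide all but these".
    if hide is None:
        hide = []
    else:
        hide_all = False
    if expose is None:
        expose = []
    else:
        hide_all = True
    if hide_types is None:
        hide_types = []
    else:
        hide_all = False
    if expose_types is None:
        expose_types = []
    else:
        hide_all = True
    # Bind the configuration so the result takes only the message.
    return partial(decide, expose, expose_types, hide, hide_types, hide_all)


a = {"name": "a", "type": "sample", "is_observed": False}
b = {"name": "b", "type": "sample", "is_observed": False}

hide_a = make_hide_fn(hide=["a"])
expose_b = make_hide_fn(expose=["b"])
default = make_hide_fn()

print(hide_a(a), hide_a(b))      # True False
print(expose_b(a), expose_b(b))  # True False
print(default(a), default(b))    # True True
```

Binding the configuration with functools.partial leaves a one-argument predicate, the same shape as a user-supplied hide_fn, so the messenger can treat both cases uniformly.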

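Effect on messages: Sets msg["stop"] to the boolean result of the hide predicate. The flag is therefore True for hidden sites, which prevents outer handlers from seeing them, and False for all other sites.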

Usage Examples

Hide Specific Sites

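import pyro
import pyro.distributions as dist
import pyro.poutine

def fn():
    a = pyro.sample("a", dist.Normal(0., 1.))
    return pyro.sample("b", dist.Normal(a, 1.))

# The inner trace sits inside the block, so it records every site.
fn_inner = pyro.poutine.trace(fn)
# The outer trace wraps the block, so site "a" is hidden from it.
fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"]))
trace_inner = fn_inner.get_trace()
trace_outer = fn_outer.get_trace()
assert "a" in trace_inner      # recorded by the inner trace
assert "a" not in trace_outer  # hidden from the outer trace
assert "b" in trace_outer      # not hidden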

Expose Only Specific Sites

# Only expose site "b", hiding everything else
blocked_fn = pyro.poutine.block(fn, expose=["b"])

Using Custom Hide Function

# Hide all observed sites
blocked_fn = pyro.poutine.block(fn, hide_fn=lambda msg: msg["is_observed"])
