Implementation:Pyro ppl Pyro Poutine Handlers
| Attribute | Value |
|---|---|
| File | pyro/poutine/handlers.py |
| Module | pyro.poutine.handlers |
| Lines | 677 |
| Purpose | User-facing factory functions that create and compose Poutine effect handlers |
| Architecture Role | Public API entry point for all Poutine effect handlers |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
handlers.py serves as the primary user-facing module for the Poutine effect handling subsystem in Pyro. It provides convenient factory functions that wrap the underlying Messenger classes, allowing users to apply effect handlers as higher-order functions, decorators, or context managers.
The module defines two categories of operations:
- Primitive operations -- thin wrappers around individual Messenger classes, generated via the _make_handler decorator factory. These include: block, broadcast, collapse, condition, do, enum, escape, equalize, infer_config, lift, mask, reparam, replay, scale, seed, substitute, trace, and uncondition.
- Composite operations -- higher-level constructs that combine multiple primitive handlers, such as queue (for sequential enumeration) and markov (for Markov dependency declaration).
Every factory function follows a uniform pattern: when called with a callable fn, it returns a wrapped version of fn; when called without fn (or with fn=None), it returns the Messenger instance directly, which can then be used as a context manager.
Code Reference
_make_handler
The internal decorator factory that generates handler wrapper functions from Messenger classes:
import collections.abc
import functools

def _make_handler(msngr_cls, module=None):
    def handler_decorator(func):
        @functools.wraps(func)
        def handler(fn=None, *args, **kwargs):
            # Reject a non-callable, non-iterable first positional argument.
            if fn is not None and not (
                callable(fn) or isinstance(fn, collections.abc.Iterable)
            ):
                raise ValueError(
                    f"{fn} is not callable, did you mean to pass it as a keyword arg?"
                )
            msngr = msngr_cls(*args, **kwargs)
            # With fn: return the wrapped fn. Without fn: return the Messenger.
            return (
                functools.update_wrapper(msngr(fn), fn, updated=())
                if fn is not None
                else msngr
            )
        return handler
    return handler_decorator
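To make the dispatch logic concrete, the following self-contained sketch reproduces the factory pattern with a hypothetical ToyMessenger standing in for a real Messenger class. ToyMessenger, toy, and model are illustrative names that do not exist in Pyro, and the module argument is left out for brevity.

```python
import collections.abc
import functools


class ToyMessenger:
    """Hypothetical stand-in for a Messenger: tags the wrapped function's result."""

    def __init__(self, tag="toy"):
        self.tag = tag

    def __call__(self, fn):
        # Calling the messenger on a callable returns a new callable.
        def wrapped(*args, **kwargs):
            return (self.tag, fn(*args, **kwargs))

        return wrapped


def _make_handler(msngr_cls):
    def handler_decorator(func):
        @functools.wraps(func)
        def handler(fn=None, *args, **kwargs):
            if fn is not None and not (
                callable(fn) or isinstance(fn, collections.abc.Iterable)
            ):
                raise ValueError(
                    f"{fn} is not callable, did you mean to pass it as a keyword arg?"
                )
            msngr = msngr_cls(*args, **kwargs)
            return (
                functools.update_wrapper(msngr(fn), fn, updated=())
                if fn is not None
                else msngr
            )

        return handler

    return handler_decorator


@_make_handler(ToyMessenger)
def toy(fn=None, tag="toy"):
    """Toy handler: this body never runs, only its name and docstring are kept."""


def model(x):
    """Double the input."""
    return 2 * x


wrapped = toy(model, tag="a")  # fn given: returns a wrapped callable
bare = toy(tag="b")            # no fn: returns the messenger instance
```

In this sketch, wrapped(3) returns ("a", 6) and keeps the name model, toy(tag="b") returns the ToyMessenger instance itself, and toy(1.0) raises ValueError because 1.0 is neither callable nor iterable.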
Handler Factory Functions (Selected)
Each primitive handler is generated via @_make_handler(MessengerClass). For example:
@_make_handler(BlockMessenger)
def block(
    fn=None, hide_fn=None, expose_fn=None,
    hide_all=True, expose_all=False,
    hide=None, expose=None,
    hide_types=None, expose_types=None,
): ...

@_make_handler(ConditionMessenger)
def condition(fn, data): ...

@_make_handler(TraceMessenger)
def trace(fn=None, graph_type=None, param_only=None): ...
queue (Composite Operation)
def queue(fn=None, queue=None, max_tries=None,
          extend_fn=None, escape_fn=None, num_samples=None):
    """
    Used in sequential enumeration over discrete variables.
    Given a stochastic function and a queue,
    return a return value from a complete trace in the queue.
    """
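The general idea of sequential enumeration with a queue of partial traces can be sketched in plain Python. This is not Pyro's implementation: NeedsValue, run_with, enumerate_returns, and toy_model are hypothetical names, and unlike poutine.queue, which returns the value from one complete trace per call, the sketch keeps looping until the queue is empty.

```python
from collections import deque


class NeedsValue(Exception):
    """Hypothetical escape signal: the model reached a site with no assigned value."""

    def __init__(self, name, support):
        super().__init__(name)
        self.name = name
        self.support = support


def run_with(model, assignment):
    """Run the model, replaying `assignment` and escaping at the first unassigned site."""

    def sample(name, support):
        if name not in assignment:
            raise NeedsValue(name, support)
        return assignment[name]

    return model(sample)


def enumerate_returns(model):
    """Collect (assignment, return value) for every complete assignment."""
    queue = deque([{}])  # start from the empty partial trace
    results = []
    while queue:
        partial = queue.popleft()
        try:
            value = run_with(model, partial)
        except NeedsValue as site:
            # Extend the queue with one longer partial trace per support value.
            for v in site.support:
                queue.append({**partial, site.name: v})
        else:
            results.append((partial, value))
    return results


def toy_model(sample):
    a = sample("a", [0, 1])
    # Site "b" only exists on the branch where a == 1.
    b = sample("b", [0, 1]) if a == 1 else 0
    return a + b


results = enumerate_returns(toy_model)
```

Because the model is re-run from the start for each partial trace, branches that create different sets of sites (such as "b" above) are handled without knowing the full site list in advance.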
markov (Composite Operation)
def markov(fn=None, history=1, keep=False, dim=None, name=None):
    """
    Markov dependency declaration.
    Can be used as a context manager, decorator, or iterator.
    """
I/O Contract
| Function | Input | Output |
|---|---|---|
| Any handler(fn, ...) | A callable fn plus handler-specific keyword arguments | A wrapped callable with the same signature as fn |
| Any handler(None, ...) or handler(...) | Handler-specific keyword arguments only (no fn) | A Messenger instance usable as a context manager |
| queue(fn, queue, ...) | A stochastic function, a queue data structure, and configuration parameters | A callable that executes enumeration over discrete variables |
| markov(fn, ...) | A callable, iterable, or None; plus history, keep, dim, name | A MarkovMessenger or a wrapped callable |
All handler functions support three usage patterns:
- Higher-order function: wrapped = poutine.trace(model)
- Decorator: @poutine.condition(data={"z": 1.0})
- Context manager: with poutine.trace() as tr:
Usage Examples
Basic Handler Composition
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

# Condition and trace a model
conditioned_model = poutine.condition(model, data={"z": 1.0})
traced_model = poutine.trace(conditioned_model)
Computing Monte Carlo ELBO
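# guide is a stochastic function assumed to be defined elsewhere;
# each ... stands for the arguments passed to the guide and the model.
guide_tr = poutine.trace(guide).get_trace(...)
# Replay the model against the guide's trace so both score the same sample values.
model_tr = poutine.trace(
    poutine.replay(conditioned_model, trace=guide_tr)
).get_trace(...)
monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()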
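The subtraction above is a single-sample estimate of log p(x, z) - log q(z) with z drawn from the guide. As a worked illustration of the arithmetic, the sketch below evaluates the same quantity by hand for an assumed toy model, z ~ Normal(0, 1) and x | z ~ Normal(z, 1), with a Normal guide. The model, the guide parameters, and the helper names are hypothetical and are not taken from the example above.

```python
import math
import random


def normal_log_prob(value, loc, scale):
    """Log density of Normal(loc, scale) evaluated at value."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def single_sample_elbo(z, x, guide_loc, guide_scale):
    # Guide term log q(z): the analogue of guide_tr.log_prob_sum().
    log_q = normal_log_prob(z, guide_loc, guide_scale)
    # Model term log p(z) + log p(x | z) at the guide's z:
    # the analogue of model_tr.log_prob_sum() after replay.
    log_p = normal_log_prob(z, 0.0, 1.0) + normal_log_prob(x, z, 1.0)
    return log_p - log_q


rng = random.Random(0)
z = rng.gauss(0.5, 1.0)  # draw z from the guide q(z) = Normal(0.5, 1.0)
estimate = single_sample_elbo(z, x=1.0, guide_loc=0.5, guide_scale=1.0)
```

A single draw gives a noisy estimate; averaging single_sample_elbo over many draws of z from the guide reduces the variance.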
Using Handlers as Context Managers
with poutine.condition(data={"z": 1.0}):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(0., s))
    y = z ** 2
Markov Dependency as Iterator
for i in poutine.markov(range(10)):
    z = pyro.sample(f"z_{i}", dist.Normal(0, 1))
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- The base Messenger class that all handlers extend
- Pyro_ppl_Pyro_Poutine_Runtime -- Core runtime infrastructure (apply_stack, effectful)
- Pyro_ppl_Pyro_Trace_Struct -- The Trace data structure populated by poutine.trace
- Pyro_ppl_Pyro_BlockMessenger -- BlockMessenger implementation
- Pyro_ppl_Pyro_ConditionMessenger -- ConditionMessenger implementation
- Pyro_ppl_Pyro_ReplayMessenger -- ReplayMessenger implementation
- Pyro_ppl_Pyro_MarkovMessenger -- MarkovMessenger implementation