
Implementation:Pyro ppl Pyro Poutine Handlers

From Leeroopedia
Revision as of 16:24, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Pyro_ppl_Pyro_Poutine_Handlers.md)


Attribute | Value
File | pyro/poutine/handlers.py
Module | pyro.poutine.handlers
Lines | 677
Purpose | User-facing factory functions that create and compose Poutine effect handlers
Architecture Role | Public API entry point for all Poutine effect handlers
License | Apache-2.0 (Uber Technologies, Inc.)

Overview

handlers.py serves as the primary user-facing module for the Poutine effect handling subsystem in Pyro. It provides convenient factory functions that wrap the underlying Messenger classes, allowing users to apply effect handlers as higher-order functions, decorators, or context managers.

The module defines two categories of operations:

  • Primitive operations -- thin wrappers around individual Messenger classes, generated via the _make_handler decorator factory. These include: block, broadcast, collapse, condition, do, enum, escape, equalize, infer_config, lift, mask, reparam, replay, scale, seed, substitute, trace, and uncondition.
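  • Composite operations -- higher-level constructs written as ordinary functions rather than generated by _make_handler. queue (for sequential enumeration) combines several primitive handlers, and markov (for Markov dependency declaration) wraps a MarkovMessenger and additionally accepts an iterable.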

Every factory function follows a uniform pattern: when called with a callable fn, it returns a wrapped version of fn; when called without fn (or with fn=None), it returns the Messenger instance directly, which can then be used as a context manager.

Code Reference

_make_handler

The internal decorator factory that generates handler wrapper functions from Messenger classes:

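import collections.abc
import functools

def _make_handler(msngr_cls, module=None):
    def handler_decorator(func):
        # Copy the name and docstring of the stub function onto the handler.
        @functools.wraps(func)
        def handler(fn=None, *args, **kwargs):
            # Reject a non-callable, non-iterable first positional argument:
            # it is most likely a handler option passed without its keyword.
            if fn is not None and not (
                callable(fn) or isinstance(fn, collections.abc.Iterable)
            ):
                raise ValueError(
                    f"{fn} is not callable, did you mean to pass it as a keyword arg?"
                )
            msngr = msngr_cls(*args, **kwargs)
            # With fn: return the wrapped callable, carrying fn's metadata.
            # Without fn: return the Messenger itself (a context manager).
            return (
                functools.update_wrapper(msngr(fn), fn, updated=())
                if fn is not None
                else msngr
            )
        return handler
    return handler_decorator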
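
The dispatch logic above can be exercised without Pyro. The following is a minimal sketch, not Pyro code: ToyMessenger, toy, and add are hypothetical names introduced for illustration, and the factory is a trimmed copy of the excerpt above with the module argument left out.

```python
import collections.abc
import functools


class ToyMessenger:
    """Hypothetical stand-in for a Messenger: a context manager that also wraps callables."""

    def __init__(self, label="toy"):
        self.label = label
        self.active = False

    def __enter__(self):
        self.active = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.active = False
        return False

    def __call__(self, fn):
        # Run fn with this messenger active, mirroring how a handler wraps a model.
        def wrapped(*args, **kwargs):
            with self:
                return fn(*args, **kwargs)

        return wrapped


def _make_handler(msngr_cls):
    # Trimmed copy of the factory shown above (assumption: no module argument).
    def handler_decorator(func):
        @functools.wraps(func)
        def handler(fn=None, *args, **kwargs):
            if fn is not None and not (
                callable(fn) or isinstance(fn, collections.abc.Iterable)
            ):
                raise ValueError(
                    f"{fn} is not callable, did you mean to pass it as a keyword arg?"
                )
            msngr = msngr_cls(*args, **kwargs)
            return (
                functools.update_wrapper(msngr(fn), fn, updated=())
                if fn is not None
                else msngr
            )

        return handler

    return handler_decorator


@_make_handler(ToyMessenger)
def toy(fn=None, label="toy"): ...  # stub: only its name and docstring are used


def add(a, b):
    """Add two numbers."""
    return a + b


wrapped_add = toy(add, label="wrapped")  # fn given: a wrapped callable
messenger = toy(label="ctx")             # no fn: the messenger instance

try:
    toy(3)  # a non-callable positional argument is rejected
    rejected = False
except ValueError:
    rejected = True
```

The stub body is never executed; the factory only borrows the stub's metadata, which is why the real handler definitions end in a bare `...`.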

Handler Factory Functions (Selected)

Each primitive handler is generated via @_make_handler(MessengerClass). For example:

@_make_handler(BlockMessenger)
def block(
    fn=None, hide_fn=None, expose_fn=None,
    hide_all=True, expose_all=False,
    hide=None, expose=None,
    hide_types=None, expose_types=None,
): ...

@_make_handler(ConditionMessenger)
def condition(fn, data): ...

@_make_handler(TraceMessenger)
def trace(fn=None, graph_type=None, param_only=None): ...

queue (Composite Operation)

def queue(fn=None, queue=None, max_tries=None,
          extend_fn=None, escape_fn=None, num_samples=None):
    """
    Used in sequential enumeration over discrete variables.
    Given a stochastic function and a queue,
    return a return value from a complete trace in the queue.
    """

markov (Composite Operation)

def markov(fn=None, history=1, keep=False, dim=None, name=None):
    """
    Markov dependency declaration.
    Can be used as a context manager, decorator, or iterator.
    """

I/O Contract

Function | Input | Output
Any handler(fn, ...) | A callable fn plus handler-specific keyword arguments | A wrapped callable with the same signature as fn
Any handler(None, ...) or handler(...) | Handler-specific keyword arguments only (no fn) | A Messenger instance usable as a context manager
queue(fn, queue, ...) | A stochastic function, a queue data structure, and configuration parameters | A callable that executes enumeration over discrete variables
markov(fn, ...) | A callable, iterable, or None; plus history, keep, dim, name | A MarkovMessenger or a wrapped callable

All primitive handler functions support three usage patterns:

  1. Higher-order function: wrapped = poutine.trace(model)
  2. Decorator: @poutine.condition(data={"z": 1.0})
  3. Context manager: with poutine.trace() as tr:

Usage Examples

Basic Handler Composition

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

# Condition and trace a model
conditioned_model = poutine.condition(model, data={"z": 1.0})
traced_model = poutine.trace(conditioned_model)

Computing Monte Carlo ELBO

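# `guide` is a guide function taking the same arguments as the model;
# `...` stands for those arguments.
guide_tr = poutine.trace(guide).get_trace(...)
# Replay the model against the guide's samples so both traces share latent values.
model_tr = poutine.trace(
    poutine.replay(conditioned_model, trace=guide_tr)
).get_trace(...)
# Single-sample estimate: log p(x, z) - log q(z)
monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()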

Using Handlers as Context Managers

with poutine.condition(data={"z": 1.0}):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(0., s))
    y = z ** 2

Markov Dependency as Iterator

for i in poutine.markov(range(10)):
    z = pyro.sample(f"z_{i}", dist.Normal(0, 1))
