Implementation:Pyro ppl Pyro Poutine Runtime

From Leeroopedia


File: pyro/poutine/runtime.py
Module: pyro.poutine.runtime
Length: 528 lines
Purpose: Core runtime infrastructure for Pyro's effect handling system
Architecture role: Foundational layer (global stack, message types, dimension allocation, and stack execution)
License: Apache-2.0 (Uber Technologies, Inc.)

Overview

runtime.py is the foundational module of the Poutine effect handling system. It defines the global state, data structures, and core algorithms that enable Pyro's composable effect handler architecture.

The module provides:

  • Global state -- The _PYRO_STACK (a list of active Messenger instances) and _PYRO_PARAM_STORE (the global parameter store).
  • Message and InferDict types -- TypedDict definitions that standardize how information flows between handlers and sample sites. Message represents the full site metadata (type, name, fn, value, scale, mask, etc.), while InferDict holds inference-specific configuration (enumeration mode, auxiliary flags, TMC settings).
  • Dimension allocators -- _DimAllocator for plate dimensions and _EnumAllocator for enumeration dimensions.
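  • Stack execution -- apply_stack, the core algorithm that routes each message through the handler stack in two passes: _process_message from the innermost (most recently entered) handler outward, then _postprocess_message over the visited handlers in the opposite order.
  • The effectful decorator -- Wraps functions so that, when at least one handler is active, calls to them are intercepted by the handler stack; otherwise the function is called directly.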
  • Utility functions -- am_i_wrapped(), get_mask(), get_plates(), and _inspect().

Code Reference

Global State

from typing import List

from pyro.params.param_store import ParamStoreDict

# the global pyro stack
_PYRO_STACK: List["Messenger"] = []

# the global ParamStore
_PYRO_PARAM_STORE = ParamStoreDict()

InferDict

A TypedDict for per-site inference configuration:

class InferDict(TypedDict, total=False):
    enumerate: Literal["sequential", "parallel"]
    expand: bool
    is_auxiliary: bool
    is_observed: bool
    num_samples: int
    obs: Optional[torch.Tensor]
    prior: "TorchDistributionMixin"
    tmc: Literal["diagonal", "mixture"]
    was_observed: bool
    # ... internal keys prefixed with _

Message

A TypedDict representing a trace site message:

class Message(TypedDict, Generic[_P, _T], total=False):
    type: str
    name: Optional[str]
    fn: Callable[_P, _T]
    is_observed: bool
    args: Tuple
    kwargs: Dict
    value: Optional[_T]
    scale: Union[torch.Tensor, float]
    mask: Union[bool, torch.Tensor, None]
    cond_indep_stack: Tuple["CondIndepStackFrame", ...]
    done: bool
    stop: bool
    continuation: Optional[Callable[["Message"], None]]
    infer: Optional[InferDict]
    # ... additional computed fields
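The sketch below shows how a `total=False` TypedDict lets a site be described by a plain dict that type checkers can still verify. `ToyMessage`, `ToyInferDict` and `make_message` are hypothetical and cover only a subset of the fields above. The defaults (scale 1.0, no mask, not done, empty `infer`) are assumptions chosen to be neutral, and masks are plain booleans rather than tensors.

```python
from typing import Any, Callable, Dict, Literal, Optional, Tuple, TypedDict


class ToyInferDict(TypedDict, total=False):
    # A small subset of the keys listed for InferDict
    enumerate: Literal["sequential", "parallel"]
    is_auxiliary: bool


class ToyMessage(TypedDict, total=False):
    # A subset of Message's fields; total=False means keys may be omitted
    type: str
    name: Optional[str]
    fn: Callable[..., Any]
    is_observed: bool
    args: Tuple
    kwargs: Dict
    value: Any
    scale: float
    mask: Optional[bool]
    done: bool
    stop: bool
    continuation: Optional[Callable[["ToyMessage"], None]]
    infer: ToyInferDict


def make_message(type_: str, name: str, fn: Callable[..., Any], *args: Any,
                 obs: Any = None, infer: Optional[ToyInferDict] = None,
                 **kwargs: Any) -> ToyMessage:
    """Hypothetical helper: fill every field with a neutral default."""
    return ToyMessage(
        type=type_, name=name, fn=fn,
        is_observed=obs is not None,
        args=args, kwargs=kwargs,
        value=obs,             # an observed value pre-fills the site's value
        scale=1.0, mask=None,  # no rescaling and no masking
        done=False, stop=False, continuation=None,
        infer=infer if infer is not None else {},
    )


msg = make_message("sample", "z", lambda: 0.5, infer={"enumerate": "parallel"})
```

An observed site starts with `value` already filled in and `is_observed=True`, so later processing can tell it apart from a latent one.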

apply_stack

The core stack execution algorithm:

def apply_stack(initial_msg: Message) -> None:
    stack = _PYRO_STACK
    msg = initial_msg
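    pointer = 0
    # Pass 1: innermost (most recently entered) handler first, moving outward
    for frame in reversed(stack):
        pointer = pointer + 1
        frame._process_message(msg)
        if msg["stop"]:
            # handlers further out never see this message
            break
    # Default behavior, e.g. calling msg["fn"] when no handler supplied a value
    default_process_message(msg)
    # Pass 2: only the frames visited above, in list (outer-to-inner) order
    for frame in stack[-pointer:]:
        frame._postprocess_message(msg)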
    cont = msg["continuation"]
    if cont is not None:
        cont(msg)

effectful Decorator

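import functools

def effectful(fn=None, type=None):
    """
    Wrapper for calling apply_stack to apply any active effects.
    """
    # Called as @effectful(type=...): return a decorator that awaits fn
    if fn is None:
        return functools.partial(effectful, type=type)

    @functools.wraps(fn)
    def _fn(*args, name=None, infer=None, obs=None, **kwargs):
        if not am_i_wrapped():
            # No active handlers: plain call (name, infer and obs are not forwarded)
            return fn(*args, **kwargs)
        else:
            # Package the call as a Message and let the handler stack process it
            msg = Message(type=type, name=name, fn=fn, ...)
            apply_stack(msg)
            return msg["value"]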
    return _fn

_DimAllocator

class _DimAllocator:
    """Dimension allocator for plate. Single global instance."""
    def allocate(self, name: str, dim: Optional[int]) -> int: ...
    def free(self, name: str, dim: int) -> None: ...
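The idea behind plate dimension allocation can be sketched as below: dims are negative and counted from the right, `dim=None` picks the rightmost free dim, and a clash raises an error. `ToyDimAllocator` is a hypothetical, simplified illustration, not a copy of Pyro's `_DimAllocator`.

```python
from typing import List, Optional


class ToyDimAllocator:
    """Hypothetical simplified allocator: dims are negative, counted from the right."""

    def __init__(self) -> None:
        # slot i holds the plate name using dim -1 - i (None means free)
        self._slots: List[Optional[str]] = []

    def allocate(self, name: str, dim: Optional[int] = None) -> int:
        if name in self._slots:
            raise ValueError(f"duplicate plate {name!r}")
        if dim is None:
            # pick the rightmost free dim
            dim = -1
            while -dim <= len(self._slots) and self._slots[-1 - dim] is not None:
                dim -= 1
        elif dim >= 0:
            raise ValueError(f"expected a negative dim, got {dim}")
        while -dim > len(self._slots):  # grow so that slot -1 - dim exists
            self._slots.append(None)
        if self._slots[-1 - dim] is not None:
            raise ValueError(
                f"plates {name!r} and {self._slots[-1 - dim]!r} collide at dim={dim}"
            )
        self._slots[-1 - dim] = name
        return dim

    def free(self, name: str, dim: int) -> None:
        assert self._slots[-1 - dim] == name
        self._slots[-1 - dim] = None
        while self._slots and self._slots[-1] is None:  # trim trailing free slots
            self._slots.pop()


allocator = ToyDimAllocator()
outer_dim = allocator.allocate("outer")  # first automatic allocation
inner_dim = allocator.allocate("inner")  # next dim to the left
```

Freeing `"inner"` makes dim -2 available again, and a manual request such as `dim=-3` leaves the gap at -2 for the next automatic allocation.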

NonlocalExit

class NonlocalExit(Exception):
    """
    Exception for exiting nonlocally from poutine execution.
    Used by EscapeMessenger to return site information.
    """
    def __init__(self, site: Message, *args, **kwargs): ...
    def reset_stack(self) -> None: ...

I/O Contract

  • apply_stack(msg) -- Input: a Message dict representing a trace site. Output: None (mutates msg in place with processed values).
  • effectful(fn, type) -- Input: a callable fn and a type string (e.g., "sample"). Output: a wrapped callable that routes calls through the handler stack.
  • default_process_message(msg) -- Input: a Message dict. Output: None (sets msg["value"] by calling msg["fn"] if not already done).
  • am_i_wrapped() -- Input: none. Output: bool, True if the Pyro stack is non-empty.
  • get_mask() -- Input: none. Output: the current mask from enclosing handlers (None, bool, or torch.Tensor).
  • get_plates() -- Input: none. Output: the current plate context as a Tuple[CondIndepStackFrame, ...].
  • _DimAllocator.allocate(name, dim) -- Input: a plate name and an optional dimension. Output: an int dimension index (negative).
  • NonlocalExit(site) -- Input: a Message dict. Output: an exception carrying the site information.
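The default step from the I/O contract can be sketched as a standalone function. `toy_default_process_message`, `fn` and `calls` are hypothetical, and the exact skip conditions (site already done, observed, or holding a value) are an assumption about what "if not already done" covers.

```python
from typing import Any, Dict, List


def toy_default_process_message(msg: Dict[str, Any]) -> None:
    """Hypothetical sketch of the default step applied after the handlers run.

    Assumption: fn is skipped when the site is already done, is observed,
    or already has a value; otherwise fn(*args, **kwargs) supplies the value.
    """
    if msg["done"] or msg["is_observed"] or msg["value"] is not None:
        msg["done"] = True
        return
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
    msg["done"] = True  # prevent fn from being called a second time


calls: List[int] = []


def fn(x: int) -> int:
    calls.append(x)  # record each real call
    return x * 10


fresh = {"fn": fn, "args": (4,), "kwargs": {}, "value": None,
         "done": False, "is_observed": False}
observed = dict(fresh, value=7, is_observed=True)

toy_default_process_message(fresh)     # fn runs: value becomes 40
toy_default_process_message(observed)  # observed value is kept, fn not called
```

Marking the message `done` makes the step idempotent: processing `fresh` a second time does not call `fn` again.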

Usage Examples

How apply_stack Works

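import pyro
import pyro.distributions as dist
from pyro import poutine

# The stack execution follows a two-pass protocol:
# 1. Process: innermost handler first (reversed stack); each calls _process_message,
#    and a handler that sets msg["stop"] hides the message from handlers further out
# 2. Default: default_process_message calls fn to sample, unless a handler already
#    supplied a value or marked the site done
# 3. Postprocess: the handlers visited in step 1, in the opposite order, call _postprocess_message
# 4. Continuation: if msg["continuation"] is set, call it

# This is invoked automatically when using pyro.sample() inside a handler context
with poutine.trace() as tr:
    z = pyro.sample("z", dist.Normal(0, 1))
    # Internally: pyro.sample builds a "sample" Message and calls apply_stack(msg)

# The trace handler recorded the site
print(tr.trace.nodes["z"]["value"])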

Using get_mask for Efficiency

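import pyro
from pyro import poutine

def model():
    # Skip expensive computation when an enclosing poutine.mask handler sets mask=False
    # (e.g. during prediction, when the log density need not be computed)
    if poutine.get_mask() is not False:
        log_density = my_expensive_computation()  # user-defined placeholder
        pyro.factor("foo", log_density)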

Checking Wrapped State

from pyro.poutine.runtime import am_i_wrapped

def my_function():
    if am_i_wrapped():
        # We are inside a poutine context
        ...
    else:
        # Direct execution without effect handling
        ...
