Implementation:Pyro ppl Pyro Poutine Runtime
| Attribute | Value |
|---|---|
| File | pyro/poutine/runtime.py |
| Module | pyro.poutine.runtime |
| Lines | 528 |
| Purpose | Core runtime infrastructure for Pyro's effect handling system |
| Architecture Role | Foundational layer: global stack, message types, dimension allocation, and stack execution |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
runtime.py is the foundational module of the Poutine effect handling system. It defines the global state, data structures, and core algorithms that enable Pyro's composable effect handler architecture.
The module provides:
- Global state -- The _PYRO_STACK (a list of active Messenger instances) and _PYRO_PARAM_STORE (the global parameter store).
- Message and InferDict types -- TypedDict definitions that standardize how information flows between handlers and sample sites. Message represents the full site metadata (type, name, fn, value, scale, mask, etc.), while InferDict holds inference-specific configuration (enumeration mode, auxiliary flags, TMC settings).
- Dimension allocators -- _DimAllocator for plate dimensions and _EnumAllocator for enumeration dimensions.
- Stack execution -- apply_stack, the core algorithm that processes messages through the handler stack in two passes: bottom-to-top (iterating reversed(_PYRO_STACK), so the innermost handler goes first), then top-to-bottom.
- The effectful decorator -- Wraps functions so that calls to them are intercepted by the handler stack.
- Utility functions -- am_i_wrapped(), get_mask(), get_plates(), and _inspect().
Code Reference
Global State
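```python
from typing import List

from pyro.params.param_store import ParamStoreDict

# the global pyro stack ("Messenger" is a forward reference to the base Messenger class)
_PYRO_STACK: List["Messenger"] = []
# the global ParamStore
_PYRO_PARAM_STORE = ParamStoreDict()
```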
InferDict
A TypedDict for per-site inference configuration:
```python
from typing import Literal, Optional, TypedDict

import torch


class InferDict(TypedDict, total=False):
    enumerate: Literal["sequential", "parallel"]
    expand: bool
    is_auxiliary: bool
    is_observed: bool
    num_samples: int
    obs: Optional[torch.Tensor]
    prior: "TorchDistributionMixin"  # forward reference to Pyro's distribution mixin
    tmc: Literal["diagonal", "mixture"]
    was_observed: bool
    # ... internal keys prefixed with _
```
Message
A TypedDict representing a trace site message:
```python
# Excerpt: imports and the module-level type variables _P and _T are omitted;
# "CondIndepStackFrame" is a forward reference.
class Message(TypedDict, Generic[_P, _T], total=False):
    type: str
    name: Optional[str]
    fn: Callable[_P, _T]
    is_observed: bool
    args: Tuple
    kwargs: Dict
    value: Optional[_T]
    scale: Union[torch.Tensor, float]
    mask: Union[bool, torch.Tensor, None]
    cond_indep_stack: Tuple["CondIndepStackFrame", ...]
    done: bool
    stop: bool
    continuation: Optional[Callable[["Message"], None]]
    infer: Optional[InferDict]
    # ... additional computed fields
```
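At runtime a Message (and the InferDict stored under its "infer" key) is an ordinary dict, and total=False means a handler may see any subset of the declared keys. The hypothetical helper make_toy_message below builds such a dict with neutral defaults (no scaling, no mask, no enclosing plates, nothing done yet). Those defaults are assumptions chosen for illustration, not a copy of the fields that Pyro itself fills in.

```python
from typing import Any, Callable, Dict, Optional, Tuple


def make_toy_message(
    site_type: str,
    name: Optional[str],
    fn: Callable[..., Any],
    args: Tuple[Any, ...] = (),
    kwargs: Optional[Dict[str, Any]] = None,
    obs: Any = None,
    infer: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: build a plain-dict site message with neutral defaults."""
    return {
        "type": site_type,
        "name": name,
        "fn": fn,
        "is_observed": obs is not None,
        "args": args,
        "kwargs": {} if kwargs is None else kwargs,
        "value": obs,  # an observed value pre-fills the site's value
        "scale": 1.0,  # no log-probability scaling yet
        "mask": None,  # no enclosing mask yet
        "cond_indep_stack": (),  # no enclosing plates yet
        "done": False,
        "stop": False,
        "continuation": None,
        # InferDict is total=False, so any subset of its keys (even none) is valid.
        "infer": {} if infer is None else infer,
    }


msg = make_toy_message("sample", "x", lambda: 0.0, obs=1.5, infer={"enumerate": "parallel"})
```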
apply_stack
The core stack execution algorithm:
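```python
def apply_stack(initial_msg: Message) -> None:
    stack = _PYRO_STACK
    msg = initial_msg
    pointer = 0
    # Pass 1: process, innermost handler (end of the list) first
    for frame in reversed(stack):
        pointer = pointer + 1
        frame._process_message(msg)
        if msg["stop"]:
            break  # handlers further out never see this message
    # Default behavior, e.g. call msg["fn"] if no handler supplied a value
    default_process_message(msg)
    # Pass 2: postprocess only the frames visited above, in the opposite order
    for frame in stack[-pointer:]:
        frame._postprocess_message(msg)
    cont = msg["continuation"]
    if cont is not None:
        cont(msg)
```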
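To see the two passes and the early exit on msg["stop"] concretely, here is a self-contained toy with the same loop structure. RecordingHandler and toy_apply_stack are hypothetical names, plain dicts stand in for Message, and the default step is simplified to "call fn only if no value is set yet".

```python
from typing import Any, Dict, List


class RecordingHandler:
    """Hypothetical handler that logs each hook call and can optionally set msg["stop"]."""

    def __init__(self, label: str, log: List[str], stop: bool = False) -> None:
        self.label = label
        self.log = log
        self.stop = stop

    def _process_message(self, msg: Dict[str, Any]) -> None:
        self.log.append(f"process:{self.label}")
        if self.stop:
            msg["stop"] = True  # handlers further out (earlier in the list) are skipped

    def _postprocess_message(self, msg: Dict[str, Any]) -> None:
        self.log.append(f"post:{self.label}")


def toy_apply_stack(stack: List[RecordingHandler], msg: Dict[str, Any]) -> None:
    pointer = 0
    # Pass 1: innermost handler (end of the list) first; stop early if asked.
    for frame in reversed(stack):
        pointer += 1
        frame._process_message(msg)
        if msg["stop"]:
            break
    # Simplified default step: call fn only if no handler supplied a value.
    if msg["value"] is None:
        msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
    msg["done"] = True
    # Pass 2: only the `pointer` frames visited in pass 1, outermost of them first.
    # (If the stack is empty, pointer stays 0 and stack[-0:] is just the empty stack.)
    for frame in stack[-pointer:]:
        frame._postprocess_message(msg)
    if msg["continuation"] is not None:
        msg["continuation"](msg)


log: List[str] = []
stack = [
    RecordingHandler("outer", log),
    RecordingHandler("middle", log, stop=True),
    RecordingHandler("inner", log),
]
msg = {"fn": lambda: 42, "args": (), "kwargs": {}, "value": None,
       "done": False, "stop": False, "continuation": None}
toy_apply_stack(stack, msg)
# log == ["process:inner", "process:middle", "post:middle", "post:inner"]
```

Only the handlers that saw the message in the first pass receive a postprocess call, which is why the loop tracks pointer instead of walking the whole stack a second time.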
effectful Decorator
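```python
import functools


def effectful(fn=None, type=None):
    """
    Wrapper for calling apply_stack to apply any active effects.
    """
    @functools.wraps(fn)
    def _fn(*args, name=None, infer=None, obs=None, **kwargs):
        if not am_i_wrapped():
            # No active handlers: call fn directly, with no effects applied.
            return fn(*args, **kwargs)
        else:
            # Build a site message, let the handler stack process it, return its value.
            msg = Message(
                type=type,
                name=name,
                fn=fn,
                # ... remaining fields elided
            )
            apply_stack(msg)
            return msg["value"]
    return _fn
```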
_DimAllocator
```python
class _DimAllocator:
    """Dimension allocator for plate. Single global instance."""

    def allocate(self, name: str, dim: Optional[int]) -> int: ...

    def free(self, name: str, dim: int) -> None: ...
```
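The listing gives only signatures. One simple way to hand out negative plate dims and reuse freed ones is a list whose slot i records the plate that owns dim -1 - i. ToyDimAllocator below is a hypothetical sketch of that idea, not a copy of Pyro's implementation.

```python
from typing import List, Optional


class ToyDimAllocator:
    """Hypothetical plate-dim allocator: slot i records the name that owns dim -1 - i."""

    def __init__(self) -> None:
        self._slots: List[Optional[str]] = []

    def allocate(self, name: str, dim: Optional[int] = None) -> int:
        if name in self._slots:
            raise ValueError(f"duplicate plate {name!r}")
        if dim is None:
            # Automatic allocation: take the rightmost free dim (-1, then -2, ...).
            dim = -1
            while -dim <= len(self._slots) and self._slots[-1 - dim] is not None:
                dim -= 1
        elif dim >= 0:
            raise ValueError(f"expected a negative dim, got {dim}")
        # Grow the slot list until slot -1 - dim exists.
        while dim < -len(self._slots):
            self._slots.append(None)
        owner = self._slots[-1 - dim]
        if owner is not None:
            raise ValueError(f"dim {dim} is already used by {owner!r}")
        self._slots[-1 - dim] = name
        return dim

    def free(self, name: str, dim: int) -> None:
        idx = -1 - dim
        assert self._slots[idx] == name, "freeing a dim owned by another plate"
        self._slots[idx] = None
        # Drop trailing free slots so the list stays as short as possible.
        while self._slots and self._slots[-1] is None:
            self._slots.pop()


alloc = ToyDimAllocator()
d_outer = alloc.allocate("outer")  # -1
d_inner = alloc.allocate("inner")  # -2
alloc.free("outer", d_outer)
d_new = alloc.allocate("new")  # -1 again, since it was freed
```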
NonlocalExit
```python
class NonlocalExit(Exception):
    """
    Exception for exiting nonlocally from poutine execution.
    Used by EscapeMessenger to return site information.
    """

    def __init__(self, site: Message, *args, **kwargs): ...

    def reset_stack(self) -> None: ...
```
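The idea behind a non-local exit is that an exception carries the site message out of arbitrarily deep model code to whoever is waiting for it. The sketch below mimics that control flow with plain exceptions; ToyNonlocalExit and run_until are hypothetical names, and the sketch does not model the stack reset that reset_stack performs.

```python
from typing import Any, Callable, Dict, List


class ToyNonlocalExit(Exception):
    """Hypothetical analogue of NonlocalExit: carries the site message that triggered it."""

    def __init__(self, site: Dict[str, Any]) -> None:
        super().__init__(site.get("name"))
        self.site = site


def run_until(model: Callable[[Callable[[str, Any], Any]], Any], stop_at: str) -> Dict[str, Any]:
    """Run `model`, escaping as soon as it reaches the site named `stop_at`."""

    def site(name: str, value: Any) -> Any:
        msg = {"name": name, "value": value}
        if name == stop_at:
            raise ToyNonlocalExit(msg)  # jump straight out of the model
        return value

    try:
        model(site)
    except ToyNonlocalExit as exc:
        return exc.site
    raise LookupError(f"site {stop_at!r} was never reached")


visited: List[str] = []


def model(site: Callable[[str, Any], Any]) -> None:
    for name, value in [("a", 1), ("b", 2), ("c", 3)]:
        visited.append(name)
        site(name, value)


escaped = run_until(model, "b")  # site "c" is never executed
```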
I/O Contract
| Function/Class | Input | Output |
|---|---|---|
| apply_stack(msg) | A Message dict representing a trace site | None (mutates msg in place with processed values) |
| effectful(fn, type) | A callable fn and a type string (e.g., "sample") | A wrapped callable that routes calls through the handler stack |
| default_process_message(msg) | A Message dict | None (sets msg["value"] by calling msg["fn"], unless the message is already done, observed, or has a value) |
| am_i_wrapped() | (none) | bool -- True if the Pyro stack is non-empty |
| get_mask() | (none) | None, bool, or torch.Tensor -- the current mask from enclosing handlers |
| get_plates() | (none) | Tuple[CondIndepStackFrame, ...] -- the current plate context |
| _DimAllocator.allocate(name, dim) | A plate name and optional dimension | An int dimension index (negative) |
| NonlocalExit(site) | A Message dict | Exception carrying the site information |
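default_process_message is the step that turns a site into a value when no handler has done so. The hypothetical re-statement below assumes the skip condition is "already done, observed, or value already set", uses plain dicts, and counts calls to show that fn never runs twice for the same message.

```python
from typing import Any, Dict, List


def toy_default_process_message(msg: Dict[str, Any]) -> None:
    """Hypothetical re-statement of the default step: fill msg["value"] only if needed."""
    if msg["done"] or msg["is_observed"] or msg["value"] is not None:
        # A handler already produced a value, or the site is observed: keep it.
        msg["done"] = True
        return
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
    msg["done"] = True  # prevents fn from being called a second time


calls: List[int] = []


def fn() -> str:
    calls.append(1)
    return "sampled"


fresh = {"fn": fn, "args": (), "kwargs": {}, "value": None, "done": False, "is_observed": False}
observed = {"fn": fn, "args": (), "kwargs": {}, "value": "data", "done": False, "is_observed": True}
toy_default_process_message(fresh)
toy_default_process_message(observed)  # keeps "data"; fn is not called
toy_default_process_message(fresh)  # already done: a no-op
```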
Usage Examples
How apply_stack Works
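```python
import pyro
import pyro.distributions as dist
from pyro import poutine

# The stack execution follows a two-pass protocol:
# 1. Process: bottom-to-top over reversed(_PYRO_STACK), innermost handler first;
#    each handler calls _process_message, and setting msg["stop"] ends this pass
# 2. Default: default_process_message calls fn to sample, unless a handler or an observation already set the value
# 3. Postprocess: top-to-bottom over the handlers visited in step 1, each calling _postprocess_message
# 4. Continuation: if msg["continuation"] is set, call it
# This is invoked automatically when using pyro.sample() inside a handler context
with poutine.trace() as tr:
    z = pyro.sample("z", dist.Normal(0, 1))
# Internally: pyro.sample builds a Message and calls apply_stack(msg) on it
print(tr.trace.nodes["z"]["value"])  # the value recorded for site "z" (the same tensor as z)
```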
Using get_mask for Efficiency
```python
import pyro
from pyro import poutine


def model():
    # Skip expensive computation when mask is False (e.g. during prediction)
    if poutine.get_mask() is not False:
        log_density = my_expensive_computation()  # placeholder for a user-defined function
        pyro.factor("foo", log_density)
```
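One way a function like get_mask() can see masks set by enclosing handlers is the approach suggested by the _inspect() utility listed in the overview: push a side-effect-free probe message through the stack and read the field back. The sketch below models that idea with hypothetical toy_inspect, toy_get_mask, and ToyMaskHandler; the AND-combination of nested masks is an assumption of the sketch.

```python
from typing import Any, Dict, List, Optional

# Hypothetical toy handler stack (stands in for _PYRO_STACK).
_STACK: List["ToyMaskHandler"] = []


class ToyMaskHandler:
    """Hypothetical mask handler: AND-combines its mask into each message it processes."""

    def __init__(self, mask: bool) -> None:
        self.mask = mask

    def __enter__(self) -> "ToyMaskHandler":
        _STACK.append(self)
        return self

    def __exit__(self, *exc_info: Any) -> None:
        _STACK.pop()

    def process(self, msg: Dict[str, Any]) -> None:
        # With tensor masks one would combine with `&`; plain bools use `and` here.
        msg["mask"] = self.mask if msg["mask"] is None else (msg["mask"] and self.mask)


def toy_inspect() -> Dict[str, Any]:
    # Send a probe message through the stack (innermost first) and hand it back.
    msg: Dict[str, Any] = {"type": "inspect", "mask": None}
    for handler in reversed(_STACK):
        handler.process(msg)
    return msg


def toy_get_mask() -> Optional[bool]:
    return toy_inspect()["mask"]


outside = toy_get_mask()  # None: no handler has set a mask
with ToyMaskHandler(True):
    only_true = toy_get_mask()  # True
    with ToyMaskHandler(False):
        inside = toy_get_mask()  # False
```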
Checking Wrapped State
```python
from pyro.poutine.runtime import am_i_wrapped


def my_function():
    if am_i_wrapped():
        # We are inside a poutine context
        ...
    else:
        # Direct execution without effect handling
        ...
```
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- The base Messenger class that populates the _PYRO_STACK
- Pyro_ppl_Pyro_Poutine_Handlers -- User-facing factory functions that compose effect handlers
- Pyro_ppl_Pyro_Trace_Struct -- The Trace data structure that stores processed Messages
- Pyro_ppl_Pyro_IndepMessenger -- Uses _DimAllocator for plate dimension management
- Pyro_ppl_Pyro_EscapeMessenger -- Uses NonlocalExit for non-local control flow