Implementation:Pyro ppl Pyro Pyro Markov

From Leeroopedia


Metadata

| Field | Value |
|---|---|
| Implementation ID | Pyro_ppl_Pyro_Pyro_Markov |
| Title | pyro.markov / MarkovMessenger |
| Project | Pyro (pyro-ppl/pyro) |
| File | pyro/poutine/markov_messenger.py, lines 16-101 |
| Implements | Pyro_ppl_Pyro_Markov_Dependency |
| Repository | https://github.com/pyro-ppl/pyro |

Summary

pyro.markov is a Pyro primitive that declares the Markov dependency structure of sequential (and recursive) models. It is implemented by the MarkovMessenger class, a reentrant effect handler that records, for each sample site, which earlier sites are still in scope under a sliding history window. Enumeration-based inference such as TraceEnum_ELBO uses this scope information to reuse the enumeration dimensions of sites that have fallen out of the window, so for chain-structured models memory stays bounded instead of growing exponentially with sequence length.

Signature

The public API is accessed via pyro.markov:

pyro.markov(
    fn=None,
    history=1,
    keep=False,
    dim=None,
    name=None
)

The underlying class:

class MarkovMessenger(ReentrantMessenger):
    def __init__(
        self,
        history: int = 1,
        keep: bool = False,
        dim: Optional[int] = None,
        name: Optional[str] = None,
    ) -> None: ...

Import

import pyro

# Used via pyro.markov, not imported directly
for t in pyro.markov(range(T), history=1):
    ...

Parameters

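| Parameter | Type | Default | Description |
|---|---|---|---|
| fn | iterable, callable, or None | None | If a non-callable iterable (e.g. range(T)), returns a messenger that wraps iteration with Markov scoping. If a callable, wraps the function so every call runs inside the Markov context (useful for recursive models). If None, returns the messenger itself, usable as a context manager or decorator. |
| history | int | 1 | Number of previous contexts visible from the current context. A value of 1 gives a first-order Markov chain (each step sees only the previous step); 0 makes steps independent, similar to pyro.plate. Must be >= 0. |
| keep | bool | False | If True, a frame is kept after its context exits, so sibling branches later entered through the same messenger at the same depth can depend on it. If False, the frame is dropped on exit, making siblings independent given their shared ancestors. |
| dim | int or None | None | Intended dimension for vectorized Markov. Interface stub: any non-None value currently raises NotImplementedError. |
| name | str or None | None | Intended unique name for matching markov sites between models and guides. Interface stub: any non-None value currently raises NotImplementedError. |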
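The history parameter only controls which stack frames are consulted when a sample site is processed. The dependency-free sketch below mirrors the range(max(0, pos - history), pos + 1) expression from _pyro_sample; the helper name visible_positions is hypothetical and not part of Pyro.

```python
from typing import List


def visible_positions(pos: int, history: int) -> List[int]:
    """Stack positions whose site names enter a site's _markov_scope.

    Hypothetical helper mirroring MarkovMessenger._pyro_sample:
    positions max(0, pos - history) .. pos, inclusive.
    """
    if history < 0:
        raise ValueError("history must be >= 0")
    return list(range(max(0, pos - history), pos + 1))


# Inside a single pyro.markov(range(T)) loop, pos equals the step t.
print(visible_positions(5, history=1))  # [4, 5]: current and previous step
print(visible_positions(5, history=2))  # [3, 4, 5]
print(visible_positions(5, history=0))  # [5]: only the current step
print(visible_positions(0, history=2))  # [0]: clipped at the first step
```

Position pos itself is always included, which is why an observation at step t can depend on the state sampled earlier in the same step.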

Returns

| Usage pattern | Return value | Description |
|---|---|---|
| pyro.markov(range(T)) | MarkovMessenger (iterable) | Iterating over it yields the values of range(T) while managing Markov scope. |
| pyro.markov(fn) | callable | fn wrapped so that each call runs inside the Markov context. |
| pyro.markov() | MarkovMessenger | A context manager (also usable as a decorator) for manual scope control. |
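The three return values come from a simple dispatch on fn: None yields the bare messenger, a non-callable is treated as an iterable, and anything callable is wrapped. The sketch below reproduces that dispatch with a stand-in class so it runs without Pyro; ToyMarkovMessenger and markov_like are hypothetical names, and the stand-in performs no effect handling.

```python
import functools
from typing import Any, Callable, Iterable, Optional


class ToyMarkovMessenger:
    """Hypothetical stand-in for MarkovMessenger (no effect handling)."""

    def __init__(self, history: int = 1, keep: bool = False) -> None:
        assert history >= 0
        self.history = history
        self.keep = keep
        self._iterable: Optional[Iterable[Any]] = None

    def generator(self, iterable: Iterable[Any]) -> "ToyMarkovMessenger":
        # Remember the iterable; iteration happens in __iter__.
        self._iterable = iterable
        return self

    def __iter__(self):
        assert self._iterable is not None
        yield from self._iterable

    def __enter__(self) -> "ToyMarkovMessenger":
        return self

    def __exit__(self, *exc) -> None:
        return None

    def __call__(self, fn: Callable) -> Callable:
        # Run fn inside this context on every call.
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            with self:
                return fn(*args, **kwargs)
        return wrapped


def markov_like(fn=None, history: int = 1, keep: bool = False):
    """Hypothetical mirror of the argument dispatch in pyro.markov."""
    if fn is None:
        # Context manager, or decorator factory such as @markov_like(history=2).
        return ToyMarkovMessenger(history=history, keep=keep)
    if not callable(fn):
        # Iterable usage, e.g. markov_like(range(T)).
        return ToyMarkovMessenger(history=history, keep=keep).generator(fn)
    # Decorator usage: wrap the callable.
    return ToyMarkovMessenger(history=history, keep=keep)(fn)


print(list(markov_like(range(3))))       # [0, 1, 2]
print(type(markov_like()).__name__)      # ToyMarkovMessenger
print(markov_like(lambda x: x + 1)(41))  # 42
```

Because the messenger is itself callable, pyro.markov() with only keyword arguments also works as a decorator factory for recursive model functions.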

Usage Patterns

Wrapping an Iterable (Most Common)

# First-order HMM (transition, emission, data and T defined as in the
# complete example below)
state = 0
for t in pyro.markov(range(T)):
    state = pyro.sample("state_{}".format(t),
                        dist.Categorical(transition[state]))
    pyro.sample("obs_{}".format(t),
                dist.Normal(emission[state], 1.0),
                obs=data[t])

With Custom History

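# Second-order Markov model: state_t may depend on the two previous states.
# some_second_order_transition is a placeholder for a user-defined distribution.
states = [0, 0]
for t in pyro.markov(range(T), history=2):
    state = pyro.sample("state_{}".format(t),
                        some_second_order_transition(states[-1], states[-2]))
    states.append(state)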

As a Context Manager

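# Manual scope management for non-standard iteration. Enter each step inside
# the previous one, as pyro.markov(range(T)) does; exiting between steps
# (a bare `with markov:` per iteration) drops the previous step's frame.
from contextlib import ExitStack
markov = pyro.markov(history=1)
with ExitStack() as stack:
    for t in range(T):
        stack.enter_context(markov)
        state = pyro.sample("state_{}".format(t), ...)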

With keep=True for Dependent Branches

# Tree-structured model where siblings can depend on each other.
# keep only links siblings entered through the SAME messenger instance.
markov = pyro.markov(history=1, keep=True)
with markov:
    left = pyro.sample("left", ...)
with markov:
    right = pyro.sample("right", ...)  # "left" is still in scope here

Internal Mechanism

The MarkovMessenger (lines 16-101) extends ReentrantMessenger and maintains:

  • _pos: Current position in the stack (starts at -1; incremented on __enter__, which appends a new empty frame if the stack is not yet that deep, and decremented on __exit__)
  • _stack: A list of sets, where each set contains the names of sample sites registered at that stack position

On Each Sample Site (_pyro_sample, lines 84-101)

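When a sample site is encountered:

  1. Sites already marked done, and the internal _Subsample sites that pyro.plate uses for subsampling, are skipped.
  2. The site's infer dict is given a _markov_scope Counter (created if absent).
  3. For every stack position in the history window, max(0, pos - history) through pos inclusive, the site names stored in that frame are added to the counter.
  4. _markov_depth is incremented, so it counts how many enclosing Markov contexts have processed the site.
  5. The current site name is added to the current stack frame.

def _pyro_sample(self, msg):
    # Skip finished sites and subsample bookkeeping sites.
    if msg["done"] or type(msg["fn"]).__name__ == "_Subsample":
        return
    infer = msg["infer"]
    # site name -> number of enclosing Markov contexts in which it is visible
    scope = infer.setdefault("_markov_scope", Counter())
    for pos in range(max(0, self._pos - self.history), self._pos + 1):
        scope.update(self._stack[pos])
    infer["_markov_depth"] = 1 + infer.get("_markov_depth", 0)
    # Register this site so later sites within the window can see it.
    self._stack[self._pos].add(msg["name"])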
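Every enclosing MarkovMessenger runs this handler on the same message, so nested contexts accumulate into one Counter and one depth value. The Pyro source notes that a Counter is used rather than a set so that a site can go out of scope when any one of its Markov contexts exits, leaving that accounting to consumers such as EnumMessenger. The sketch below replays the handler's bookkeeping on plain dicts; ToyMarkov and sample are hypothetical names, and the keep=False exit logic is simplified from __exit__.

```python
from collections import Counter
from typing import Dict, List, Set


class ToyMarkov:
    """Hypothetical stand-in for MarkovMessenger's sample bookkeeping."""

    def __init__(self, history: int = 1) -> None:
        self.history = history
        self._pos = -1
        self._stack: List[Set[str]] = []

    def enter(self) -> None:
        self._pos += 1
        if len(self._stack) <= self._pos:
            self._stack.append(set())

    def exit(self) -> None:  # keep=False behaviour: drop the frame
        self._stack.pop()
        self._pos -= 1

    def process(self, msg: Dict) -> None:
        # Same statements as _pyro_sample, applied to a plain dict message.
        infer = msg["infer"]
        scope = infer.setdefault("_markov_scope", Counter())
        for pos in range(max(0, self._pos - self.history), self._pos + 1):
            scope.update(self._stack[pos])
        infer["_markov_depth"] = 1 + infer.get("_markov_depth", 0)
        self._stack[self._pos].add(msg["name"])


def sample(name: str, handlers: List[ToyMarkov]) -> Dict:
    """Send one site through the given handlers, innermost first."""
    msg = {"name": name, "infer": {}}
    for handler in handlers:
        handler.process(msg)
    return msg["infer"]


outer, inner = ToyMarkov(), ToyMarkov()
outer.enter()
inner.enter()
a = sample("a", [inner, outer])  # "a" is registered under two contexts
b = sample("b", [inner, outer])
inner.exit()                     # the inner context closes
c = sample("c", [outer])

print(a["_markov_depth"])                           # 2
print(b["_markov_scope"]["a"])                      # 2
print(c["_markov_scope"]["a"], c["_markov_depth"])  # 1 1
```

After the inner context exits, "c" sees "a" through only one context, so its count (1) no longer matches the depth at which "a" was registered (2).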

On Context Exit (__exit__, lines 78-82)

When exiting a Markov context:

  • If keep=False: the current stack frame is popped (sites go out of scope)
  • If keep=True: the frame remains in the stack (sites stay visible to later siblings)
  • The position counter is decremented
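The effect of keep on siblings follows from the enter/exit logic alone: __enter__ only appends a new frame when none is retained at that depth. The sketch below is a hypothetical, Pyro-free reimplementation of __enter__, __exit__ and the scope lookup (ToyMarkov and sibling_scope are illustrative names), returning a plain set instead of a Counter for readability.

```python
from typing import List, Set


class ToyMarkov:
    """Hypothetical reimplementation of MarkovMessenger's frame handling."""

    def __init__(self, history: int = 1, keep: bool = False) -> None:
        self.history = history
        self.keep = keep
        self._pos = -1
        self._stack: List[Set[str]] = []

    def __enter__(self) -> "ToyMarkov":
        self._pos += 1
        if len(self._stack) <= self._pos:
            self._stack.append(set())  # new frame only if none is retained here
        return self

    def __exit__(self, *exc) -> None:
        if not self.keep:
            self._stack.pop()  # drop the frame: its sites leave scope
        self._pos -= 1

    def sample(self, name: str) -> Set[str]:
        """Register a site and return the names visible to it."""
        scope: Set[str] = set()
        for pos in range(max(0, self._pos - self.history), self._pos + 1):
            scope |= self._stack[pos]
        self._stack[self._pos].add(name)
        return scope


def sibling_scope(keep: bool) -> Set[str]:
    m = ToyMarkov(history=1, keep=keep)
    with m:
        m.sample("left")
    with m:  # a sibling entered through the SAME messenger
        return m.sample("right")


print(sibling_scope(keep=True))   # {'left'}
print(sibling_scope(keep=False))  # set()
```

Two separate pyro.markov(...) objects each own their own _stack, so keep only links siblings entered through one shared messenger, such as a recursive function decorated with @pyro.markov.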

Iterator Protocol (__iter__, lines 65-70)

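When used to wrap an iterable, the messenger enters its own context once per yielded value without exiting in between. Because the messenger is reentrant, these contexts nest: _pos increases by one at each step (so at step t it equals t for a single top-level chain), and the ExitStack exits all of them together when the iteration finishes.

def __iter__(self):
    with ExitStack() as stack:  # exits every step's context when iteration ends
        for value in self._iterable:
            # Nest a new Markov frame for this step; nothing is exited here.
            stack.enter_context(self)
            yield value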
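The nesting can be observed with any reentrant context object. The sketch below logs enter, body and exit events for a generator shaped like __iter__; LoggingContext and markov_iter are hypothetical names, and pos plays the role of _pos.

```python
from contextlib import ExitStack
from typing import Iterable, Iterator, List


class LoggingContext:
    """Hypothetical reentrant context that tracks a depth counter like _pos."""

    def __init__(self) -> None:
        self.pos = -1
        self.log: List[str] = []

    def __enter__(self) -> "LoggingContext":
        self.pos += 1
        self.log.append(f"enter pos={self.pos}")
        return self

    def __exit__(self, *exc) -> None:
        self.log.append(f"exit pos={self.pos}")
        self.pos -= 1


def markov_iter(ctx: LoggingContext, iterable: Iterable[int]) -> Iterator[int]:
    # Same shape as MarkovMessenger.__iter__.
    with ExitStack() as stack:
        for value in iterable:
            stack.enter_context(ctx)
            yield value


ctx = LoggingContext()
for t in markov_iter(ctx, range(3)):
    ctx.log.append(f"body t={t} pos={ctx.pos}")
print("\n".join(ctx.log))
```

The log shows three enters interleaved with the loop body and then three exits in reverse order (pos 2, 1, 0) once the loop ends. Frames older than the history window therefore remain on the stack until the end of the loop; they are simply no longer consulted when _markov_scope is built.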

Complete Example

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

hidden_dim = 3
T = 50  # sequence length

@config_enumerate
def hmm(data):
    transition = pyro.sample(
        "transition",
        dist.Dirichlet(torch.ones(hidden_dim, hidden_dim)).to_event(1)
    )
    emission = pyro.sample(
        "emission_loc",
        dist.Normal(torch.zeros(hidden_dim), 10.0).to_event(1)
    )
    state = 0
    # pyro.markov enables O(T*K^2) enumeration instead of O(K^T)
    for t in pyro.markov(range(T)):
        state = pyro.sample(
            "state_{}".format(t),
            dist.Categorical(transition[state])
        )
        pyro.sample(
            "obs_{}".format(t),
            dist.Normal(emission[state], 1.0),
            obs=data[t]
        )

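# No @config_enumerate here: the guide has no discrete sites to enumerate.
def guide(data):
    # Dirichlet concentrations only need to be positive, not on a simplex.
    transition_q = pyro.param(
        "transition_q",
        torch.ones(hidden_dim, hidden_dim),
        constraint=dist.constraints.positive
    )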
    pyro.sample(
        "transition",
        dist.Dirichlet(transition_q).to_event(1)
    )
    emission_loc_q = pyro.param("emission_loc_q", torch.zeros(hidden_dim))
    pyro.sample(
        "emission_loc",
        dist.Normal(emission_loc_q, 1.0).to_event(1)
    )

pyro.clear_param_store()
data = torch.randn(T)
elbo = TraceEnum_ELBO(max_plate_nesting=0)
svi = SVI(hmm, guide, Adam({"lr": 0.01}), loss=elbo)

for step in range(500):
    loss = svi.step(data)
    if step % 100 == 0:
        print("step {} loss = {:.2f}".format(step, loss))
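The O(T*K^2) versus O(K^T) comment in hmm refers to summing out the discrete chain. As a rough, Pyro-free illustration: brute force enumerates all K^T state paths, while the forward recursion carries only a length-K message and eliminates one step at a time, which is the kind of contraction that bounded enumeration dimensions make possible. pyro.markov does not run this recursion itself; it supplies the scope information the enumeration machinery relies on. The parameters init, trans, emit and obs are made-up values with discrete observations, not the Gaussian model above.

```python
import itertools
import math
from typing import List, Sequence

# Hypothetical toy parameters: K = 2 hidden states, 3 observation symbols.
init = [0.6, 0.4]
trans = [[0.7, 0.3],
         [0.2, 0.8]]
emit = [[0.5, 0.4, 0.1],
        [0.1, 0.3, 0.6]]
obs = [0, 2, 1, 2]  # T = 4 observations


def brute_force(obs: Sequence[int]) -> float:
    """Sum p(path, obs) over all K**T hidden paths."""
    K, total = len(init), 0.0
    for path in itertools.product(range(K), repeat=len(obs)):
        p = init[path[0]] * emit[path[0]][obs[0]]
        for t in range(1, len(obs)):
            p *= trans[path[t - 1]][path[t]] * emit[path[t]][obs[t]]
        total += p
    return total


def forward(obs: Sequence[int]) -> float:
    """Forward algorithm: O(T * K**2) work, only K numbers kept per step."""
    K = len(init)
    alpha: List[float] = [init[k] * emit[k][obs[0]] for k in range(K)]
    for t in range(1, len(obs)):
        alpha = [
            sum(alpha[j] * trans[j][k] for j in range(K)) * emit[k][obs[t]]
            for k in range(K)
        ]
    return sum(alpha)


print(brute_force(obs), forward(obs))
print(math.isclose(brute_force(obs), forward(obs)))  # True
```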

Related Pages

Implements Principle

Related Implementations
