Implementation:Pyro ppl Pyro Pyro Markov
Metadata
| Field | Value |
|---|---|
| Implementation ID | Pyro_ppl_Pyro_Pyro_Markov |
| Title | pyro.markov / MarkovMessenger |
| Project | Pyro (pyro-ppl/pyro) |
| File | pyro/poutine/markov_messenger.py, Lines 16-101 |
| Implements | Pyro_ppl_Pyro_Markov_Dependency |
| Repository | https://github.com/pyro-ppl/pyro |
Summary
pyro.markov is a Pyro primitive that declares Markov dependency structure for sequential models. Internally, it is implemented by the MarkovMessenger class, which is a reentrant effect handler that tracks which sample sites are in scope based on a sliding history window. This enables memory-efficient enumeration by allowing the inference engine to contract out old enumeration dimensions.
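To see why eliminating old enumeration dimensions matters, the following stdlib-only sketch compares two ways of marginalizing the hidden states of a tiny HMM. It does not use Pyro; the function names and the toy probabilities are illustrative assumptions. Brute force sums over all K^T state paths, while the forward recursion sums out each state as soon as its successor has absorbed it, which is the kind of incremental contraction that Markov structure makes possible.

```python
import itertools


def brute_force_likelihood(init, trans, emit_probs):
    """Sum the joint probability over all K**T state paths (exponential in T)."""
    K, T = len(init), len(emit_probs)
    total = 0.0
    for path in itertools.product(range(K), repeat=T):
        p = init[path[0]] * emit_probs[0][path[0]]
        for t in range(1, T):
            p *= trans[path[t - 1]][path[t]] * emit_probs[t][path[t]]
        total += p
    return total


def forward_likelihood(init, trans, emit_probs):
    """Sum out each previous state as soon as it is no longer needed: O(T * K**2)."""
    K = len(init)
    alpha = [init[k] * emit_probs[0][k] for k in range(K)]
    for t in range(1, len(emit_probs)):
        alpha = [
            sum(alpha[j] * trans[j][k] for j in range(K)) * emit_probs[t][k]
            for k in range(K)
        ]
    return sum(alpha)


# Toy numbers (assumed for illustration): emit_probs[t][k] = p(obs_t | state k).
init = [0.6, 0.4]
trans = [[0.7, 0.3], [0.2, 0.8]]
emit_probs = [[0.9, 0.1], [0.5, 0.5], [0.2, 0.8], [0.3, 0.7]]

brute = brute_force_likelihood(init, trans, emit_probs)
forward = forward_likelihood(init, trans, emit_probs)
```

Both functions return the same marginal likelihood; only the forward version keeps its working set at K numbers regardless of T.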
Signature
The public API is accessed via pyro.markov:
pyro.markov(
    fn=None,
    history=1,
    keep=False,
    dim=None,
    name=None
)
The underlying class:
class MarkovMessenger(ReentrantMessenger):
    def __init__(
        self,
        history: int = 1,
        keep: bool = False,
        dim: Optional[int] = None,
        name: Optional[str] = None,
    ) -> None
Import
import pyro
# Used via pyro.markov, not imported directly
for t in pyro.markov(range(T), history=1):
    ...
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| fn | iterable, callable, or None | None | If an iterable (e.g., range(T)), wraps iteration with Markov scoping. If a callable, wraps the function. If None, returns the messenger for use as a context manager. |
| history | int | 1 | Number of previous contexts visible from the current context. A value of 1 means first-order Markov (each step sees only the previous step). Must be >= 0. |
| keep | bool | False | If True, frames are kept (replayable) after exiting, enabling dependent branches. If False, frames are dropped, making branches independent. |
| dim | int or None | None | Optional dimension for vectorized Markov. Passing a non-None value currently raises NotImplementedError. |
| name | str or None | None | Optional name for matching between models and guides. Not yet supported; leave as None. |
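As a rough illustration of the history parameter, the helper below lists which context positions are visible from a given position. The name visible_positions is hypothetical; the range expression is the one used in _pyro_sample on this page.

```python
def visible_positions(pos, history):
    """Stack positions whose sites are in scope at `pos`.

    Hypothetical helper: mirrors range(max(0, pos - history), pos + 1).
    """
    return list(range(max(0, pos - history), pos + 1))


first_order = visible_positions(5, history=1)   # current step plus one predecessor
second_order = visible_positions(5, history=2)  # current step plus two predecessors
no_history = visible_positions(5, history=0)    # only the current step
at_start = visible_positions(0, history=1)      # window is clipped at position 0
```

With history=0 only the current context is visible, so consecutive steps are treated as unrelated.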
Returns
| Usage Pattern | Return Type | Description |
|---|---|---|
| pyro.markov(range(T)) | Iterator[int] | An iterator that yields values from the iterable while managing Markov scope. |
| pyro.markov(fn) | callable | A wrapped callable with Markov scoping. |
| pyro.markov() | MarkovMessenger | A context manager for manual scope control. |
Usage Patterns
Wrapping an Iterable (Most Common)
# First-order HMM
for t in pyro.markov(range(T)):
    state = pyro.sample("state_{}".format(t),
                        dist.Categorical(transition[prev_state]))
    obs = pyro.sample("obs_{}".format(t),
                      dist.Normal(emission[state], 1.0),
                      obs=data[t])
    prev_state = state  # carry the state forward to the next step
With Custom History
# Second-order Markov model
for t in pyro.markov(range(T), history=2):
    state = pyro.sample("state_{}".format(t),
                        some_second_order_transition(states[-1], states[-2]))
    states.append(state)  # record the state for the following steps
As a Context Manager
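# Manual scope management for non-standard iteration.
# Each `with` here enters and exits in sequence, so the steps are sibling
# frames rather than the nested frames produced by pyro.markov(range(T)).
markov = pyro.markov(history=1)
for t in range(T):
    with markov:
        state = pyro.sample("state_{}".format(t), ...)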
With keep=True for Dependent Branches
# Tree-structured model where siblings can depend on each other.
# Kept frames live on the messenger instance, so both branches reuse one instance.
branch = pyro.markov(history=1, keep=True)
with branch:
    left = pyro.sample("left", ...)
with branch:
    right = pyro.sample("right", ...)  # can depend on "left"
Internal Mechanism
The MarkovMessenger (lines 16-101) extends ReentrantMessenger and maintains:
- _pos: Current position in the stack (starts at -1, incremented on __enter__)
- _stack: A list of sets, where each set contains the names of sample sites registered at that stack position
On Each Sample Site (_pyro_sample, lines 84-101)
When a sample site is encountered:
- The site's infer dict is augmented with a _markov_scope counter
- For positions within the history window (max(0, pos - history) to pos), the site names from those stack frames are added to the scope counter
- A _markov_depth integer tracks the nesting level
- The current site name is added to the current stack frame
def _pyro_sample(self, msg):
    if msg["done"] or type(msg["fn"]).__name__ == "_Subsample":
        return
    infer = msg["infer"]
    scope = infer.setdefault("_markov_scope", Counter())
    for pos in range(max(0, self._pos - self.history), self._pos + 1):
        scope.update(self._stack[pos])
    infer["_markov_depth"] = 1 + infer.get("_markov_depth", 0)
    self._stack[self._pos].add(msg["name"])
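The bookkeeping above can be imitated without Pyro to see what ends up in each site's scope. The sketch below is a toy under stated assumptions: ToyMarkov is a hypothetical class, its sample method stands in for _pyro_sample, it omits _markov_depth and the message dict, and it assumes that entering a context appends a fresh set whenever the stack is too short.

```python
from collections import Counter


class ToyMarkov:
    """Stdlib-only imitation of the scope bookkeeping (not Pyro's class)."""

    def __init__(self, history=1, keep=False):
        self.history = history
        self.keep = keep
        self._pos = -1
        self._stack = []

    def __enter__(self):
        self._pos += 1
        if len(self._stack) <= self._pos:  # assumption: allocate frames lazily
            self._stack.append(set())
        return self

    def __exit__(self, *exc):
        if not self.keep:
            self._stack.pop()
        self._pos -= 1
        return False

    def sample(self, name):
        """Return the names in scope for `name`, then register `name`."""
        scope = Counter()
        for pos in range(max(0, self._pos - self.history), self._pos + 1):
            scope.update(self._stack[pos])
        self._stack[self._pos].add(name)
        return scope


markov = ToyMarkov(history=1)
with markov:
    s0 = markov.sample("state_0")
    with markov:
        s1 = markov.sample("state_1")
        with markov:
            s2 = markov.sample("state_2")
```

In this toy, state_2 sees state_1 but not state_0, which matches a first-order window: once a site falls outside the window, nothing later refers to it.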
On Context Exit (__exit__, lines 78-82)
When exiting a Markov context:
- If keep=False: the current stack frame is popped (sites go out of scope)
- If keep=True: the frame remains in the stack (sites stay visible to later siblings)
- The position counter is decremented
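The effect of keep on sibling contexts can be sketched with plain Python data structures. The function sibling_scope is hypothetical and only mimics the push, pop, and position logic listed above; it assumes a fresh set is appended on entry when no frame exists at the current position.

```python
def sibling_scope(keep):
    """Enter one frame twice in a row; report what the second visit can see."""
    stack, pos = [], -1

    def enter():
        nonlocal pos
        pos += 1
        if len(stack) <= pos:  # assumption: allocate a frame only if missing
            stack.append(set())

    def exit_():
        nonlocal pos
        if not keep:
            stack.pop()  # keep=False: registered sites go out of scope
        pos -= 1

    enter()
    stack[pos].add("left")
    exit_()

    enter()
    visible = set(stack[pos])  # sites already registered in this frame
    stack[pos].add("right")
    exit_()
    return visible


dropped = sibling_scope(keep=False)
kept = sibling_scope(keep=True)
```

With keep=False the second branch starts from an empty frame; with keep=True it finds "left" already registered, which is what lets a sibling depend on it.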
Iterator Protocol (__iter__, lines 65-70)
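When used to wrap an iterable, the messenger enters its context once for each value yielded. The contexts are not exited between values: the ExitStack keeps them open, nested one inside the other, and unwinds them all when iteration finishes: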
def __iter__(self):
    with ExitStack() as stack:
        for value in self._iterable:
            stack.enter_context(self)
            yield value
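The nesting behaviour of this pattern can be checked with a small stand-in. DepthTracker and iterate are hypothetical names; the generator body follows the __iter__ listing above, with the tracker passed in explicitly instead of self.

```python
from contextlib import ExitStack


class DepthTracker:
    """Hypothetical reentrant context manager that only counts nesting depth."""

    def __init__(self):
        self.depth = 0

    def __enter__(self):
        self.depth += 1
        return self

    def __exit__(self, *exc):
        self.depth -= 1
        return False


def iterate(tracker, iterable):
    """Enter `tracker` once per value; ExitStack unwinds everything at the end."""
    with ExitStack() as stack:
        for value in iterable:
            stack.enter_context(tracker)
            yield value


tracker = DepthTracker()
# Record the nesting depth seen by the loop body at each step.
depths = [tracker.depth for _ in iterate(tracker, range(3))]
```

The depth grows by one per step and returns to zero once the generator is exhausted, which is why earlier frames are still on the stack when a later step looks back through its history window.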
Complete Example
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

hidden_dim = 3
T = 50  # sequence length


@config_enumerate
def hmm(data):
    transition = pyro.sample(
        "transition",
        dist.Dirichlet(torch.ones(hidden_dim, hidden_dim)).to_event(1)
    )
    emission = pyro.sample(
        "emission_loc",
        dist.Normal(torch.zeros(hidden_dim), 10.0).to_event(1)
    )
    state = 0
    # pyro.markov enables O(T*K^2) enumeration instead of O(K^T)
    for t in pyro.markov(range(T)):
        state = pyro.sample(
            "state_{}".format(t),
            dist.Categorical(transition[state])
        )
        pyro.sample(
            "obs_{}".format(t),
            dist.Normal(emission[state], 1.0),
            obs=data[t]
        )


@config_enumerate
def guide(data):
    # Dirichlet concentrations only need to be positive, and the all-ones
    # initial value does not lie on the simplex, so use a positive constraint.
    transition_q = pyro.param(
        "transition_q",
        torch.ones(hidden_dim, hidden_dim),
        constraint=dist.constraints.positive
    )
    pyro.sample(
        "transition",
        dist.Dirichlet(transition_q).to_event(1)
    )
    emission_loc_q = pyro.param("emission_loc_q", torch.zeros(hidden_dim))
    pyro.sample(
        "emission_loc",
        dist.Normal(emission_loc_q, 1.0).to_event(1)
    )


data = torch.randn(T)
elbo = TraceEnum_ELBO(max_plate_nesting=0)
svi = SVI(hmm, guide, Adam({"lr": 0.01}), loss=elbo)
for step in range(500):
    loss = svi.step(data)
Related Pages
Implements Principle
- Pyro_ppl_Pyro_Markov_Dependency
Related Implementations
- Pyro_ppl_Pyro_Config_Enumerate -- Configures which discrete sites are enumerated; Markov structure makes that enumeration memory-efficient.
- Pyro_ppl_Pyro_Contract_Tensor_Tree -- Exploits Markov structure during tensor contraction to contract dimensions incrementally.
- Pyro_ppl_Pyro_Infer_Discrete -- Uses Markov structure for efficient forward-backward and Viterbi decoding.