
Implementation:Pyro ppl Pyro ReplayMessenger



Attribute          Value
File               pyro/poutine/replay_messenger.py
Module             pyro.poutine.replay_messenger
Lines              91
Parent Class       Messenger
Purpose            Replay previously recorded values at sample and param sites
Architecture Role  Core component of the trace-replay pattern used in variational inference
License            Apache-2.0 (Uber Technologies, Inc.)

Overview

ReplayMessenger replays previously recorded values at sample and param sites, making the execution follow the same random choices as a previous run. This is fundamental to the trace-replay pattern used extensively in variational inference.

The handler can replay from two sources:

  • A Trace object -- values and infer dicts from matching sample sites in the trace are injected.
  • A params dict -- constrained parameter values are injected at matching param sites.

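For sample sites (only names that appear in the trace are affected; all other sites sample normally):

  • Sites that are observed in the current run are skipped, so their observations are kept.
  • The matching site in the source trace must be of type "sample" and must not be observed there; otherwise a RuntimeError is raised.
  • Both the value and the infer dict are copied from the source.
  • The message is marked as done so the default handler does not draw a new sample.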

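For param sites (only names present in the params dict are affected):

  • The value must have an unconstrained attribute, which marks it as a constrained parameter value such as one returned by pyro.param; this is enforced with an assert.
  • The message is marked as done and its value is replaced by the saved value.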

Code Reference

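from pyro.poutine.messenger import Messenger

class ReplayMessenger(Messenger):
    def __init__(self, trace=None, params=None):
        super().__init__()
        # At least one replay source is required
        if trace is None and params is None:
            raise ValueError("must provide trace or params to replay against")
        self.trace = trace
        self.params = params

    def _pyro_sample(self, msg):
        name = msg["name"]
        # Only sites recorded in the source trace are replayed
        if self.trace is not None and name in self.trace:
            guide_msg = self.trace.nodes[name]
            # Sites observed in the current run keep their observations
            if msg["is_observed"]:
                return None
            # The source site must be a latent (non-observed) sample site
            if guide_msg["type"] != "sample" or guide_msg["is_observed"]:
                raise RuntimeError("site {} must be sampled in trace".format(name))
            # Mark as done so the default handler does not draw a new sample
            msg["done"] = True
            msg["value"] = guide_msg["value"]
            msg["infer"] = guide_msg["infer"]

    def _pyro_param(self, msg):
        name = msg["name"]
        if self.params is not None and name in self.params:
            # Replayed values must be constrained values, e.g. as returned by pyro.param
            assert hasattr(self.params[name], "unconstrained"), \
                "param {} must be constrained value".format(name)
            # Mark as done so the param store is not consulted
            msg["done"] = True
            msg["value"] = self.params[name]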

I/O Contract

Parameter  Type                               Description
trace      Optional[Trace]                    A Trace object whose sample site values will be replayed
params     Optional[Dict[str, torch.Tensor]]  A dict mapping param names to constrained values to replay

At least one of trace or params must be provided; otherwise the constructor raises a ValueError.

Message Effect  Description
msg["done"]     Set to True for replayed sites, so the default handler skips sampling (sample sites) or the param-store lookup (param sites)
msg["value"]    Set to the recorded value from the trace (sample sites) or the params dict (param sites)
msg["infer"]    Copied from the source trace (sample sites only)

Usage Examples

Basic Trace Replay

import torch
import pyro
import pyro.distributions as dist

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

# Record one run, then replay it: z is reused even though x changes
old_trace = pyro.poutine.trace(model).get_trace(1.0)
replayed_model = pyro.poutine.replay(model, trace=old_trace)
assert replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"]

Trace-Replay Pattern for ELBO

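# Standard trace-replay pattern in variational inference.
# Assumes model(data) and guide(data) are defined and share latent site names.
guide_tr = pyro.poutine.trace(guide).get_trace(data)
# Re-run the model with the guide's latent values replayed
model_tr = pyro.poutine.trace(
    pyro.poutine.replay(model, trace=guide_tr)
).get_trace(data)
# Single-sample estimate of the ELBO: log p(x, z) - log q(z)
elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()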

Replaying Parameters

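# constrained_value must carry an `unconstrained` attribute,
# e.g. a tensor previously returned by pyro.param("my_param")
saved_params = {"my_param": constrained_value}
replayed = pyro.poutine.replay(model, params=saved_params)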
