Implementation:Pyro ppl Pyro ReplayMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/replay_messenger.py |
| Module | pyro.poutine.replay_messenger |
| Lines | 91 |
| Parent Class | Messenger |
| Purpose | Replay previously recorded values at sample and param sites |
| Architecture Role | Core component of the trace-replay pattern used in variational inference |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
ReplayMessenger replays previously recorded values at sample and param sites, making the execution follow the same random choices as a previous run. This is fundamental to the trace-replay pattern used extensively in variational inference.
The handler can replay from two sources, either or both of which may be supplied:
- A Trace object -- values and infer dicts from matching sample sites in the trace are injected.
- A params dict -- constrained parameter values are injected at matching param sites.
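For sample sites (only names that appear in the source trace are affected; other sites run normally):
- Sites that are observed in the current execution are skipped (not replayed).
- The matching site in the source trace must be of type `"sample"` and must not be observed there; otherwise a `RuntimeError` is raised.
- Both the `value` and the `infer` dict are copied from the source.
- The message is marked as `done` to prevent further sampling.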
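For param sites (only names present in the `params` dict are affected):
- The value must have an `unconstrained` attribute, indicating that it is a constrained value taken from the param store. This is checked with an `assert`, so a plain tensor raises `AssertionError`.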
Code Reference
```python
from pyro.poutine.messenger import Messenger


class ReplayMessenger(Messenger):
    def __init__(self, trace=None, params=None):
        super().__init__()
        # At least one replay source is required
        if trace is None and params is None:
            raise ValueError("must provide trace or params to replay against")
        self.trace = trace
        self.params = params

    def _pyro_sample(self, msg):
        name = msg["name"]
        # Only sites recorded in the source trace are replayed
        if self.trace is not None and name in self.trace:
            guide_msg = self.trace.nodes[name]
            # Sites observed in the current run keep their observation
            if msg["is_observed"]:
                return None
            # The source site must have been a latent sample site
            if guide_msg["type"] != "sample" or guide_msg["is_observed"]:
                raise RuntimeError("site {} must be sampled in trace".format(name))
            # Reuse the recorded value and infer dict; "done" stops further sampling
            msg["done"] = True
            msg["value"] = guide_msg["value"]
            msg["infer"] = guide_msg["infer"]

    def _pyro_param(self, msg):
        name = msg["name"]
        if self.params is not None and name in self.params:
            # Values read from the param store carry an `.unconstrained` attribute
            assert hasattr(self.params[name], "unconstrained"), \
                "param {} must be constrained value".format(name)
            msg["done"] = True
            msg["value"] = self.params[name]
```
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| trace | Optional[Trace] | A Trace object whose sample site values will be replayed |
| params | Optional[Dict[str, torch.Tensor]] | A dict of param names to constrained values to replay |
At least one of `trace` or `params` must be provided; otherwise the constructor raises `ValueError`.
| Message Effect | Description |
|---|---|
| msg["done"] | Set to True for replayed sites (prevents further sampling) |
| msg["value"] | Set to the recorded value from the trace or params dict |
| msg["infer"] | Copied from the source trace for sample sites |
Usage Examples
Basic Trace Replay
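```python
import torch
import pyro
import pyro.distributions as dist

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

# Record one run of the model with x = 1.0
old_trace = pyro.poutine.trace(model).get_trace(1.0)
# The replayed run reuses the recorded z, so a different x gives the same result
replayed_model = pyro.poutine.replay(model, trace=old_trace)
assert replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"]
```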
Trace-Replay Pattern for ELBO
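```python
import pyro

# Standard trace-replay pattern in variational inference
# (assumes `model`, `guide` and `data` are already defined)
guide_tr = pyro.poutine.trace(guide).get_trace(data)  # record z ~ q(z)
model_tr = pyro.poutine.trace(
    pyro.poutine.replay(model, trace=guide_tr)  # re-run the model at the guide's z
).get_trace(data)
# Single-sample ELBO estimate: log p(x, z) - log q(z)
elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
```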
Replaying Parameters
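```python
# Values must come from the param store so that they carry `.unconstrained`
saved_params = {"s": pyro.param("s")}  # e.g. the "s" param of the model above
replayed = pyro.poutine.replay(model, params=saved_params)
```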
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Handlers -- The `poutine.replay()` factory function
- Pyro_ppl_Pyro_Trace_Struct -- The `Trace` object used as the replay source
- Pyro_ppl_Pyro_ConditionMessenger -- Related: conditioning fixes values as observations
- Pyro_ppl_Pyro_GuideMessenger -- Alternative to trace-replay: interleaved execution