Implementation:Pyro ppl Pyro GuideMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/guide.py |
| Module | pyro.poutine.guide |
| Lines | 159 |
| Parent Classes | TraceMessenger (which extends Messenger) and ABC |
| Purpose | Abstract base class for effect-based guide implementations |
| Architecture Role | Enables interleaved model/guide computation for variational inference |
| License | Apache-2.0 (Pyro project contributors) |
Overview
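GuideMessenger is an abstract base class for implementing variational guides (approximate posterior distributions) as effect handlers. In the traditional pattern, a separate guide function is run first and the model is then replayed against the resulting guide trace. GuideMessenger instead interleaves model and guide computation in a single forward pass of the model: each latent sample site is intercepted as the model reaches it, and its prior is replaced by a posterior.
Derived classes must implement the get_posterior(name, prior) method. It is called at each latent sample site with the site's prior distribution, which is already conditioned on upstream posterior samples, and returns either a posterior distribution or a torch.Tensor. A returned tensor is wrapped in a Delta distribution and therefore acts as a point estimate. Observed sites and plate subsample sites are never passed to get_posterior.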
Key features:
- Interleaved execution -- The guide intercepts each latent sample site during model execution and replaces its prior with a posterior distribution.
- Single-pass trace generation -- Both the model trace and the guide trace are produced by a single call and can be retrieved with get_traces().
- Prior access -- The original prior distribution is stored in msg["infer"]["prior"], so it remains available after msg["fn"] has been replaced by the posterior.
- Upstream value access -- upstream_value(name) returns the value of a previously sampled (or deterministic) site, for use inside get_posterior.
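Assuming get_traces splits the single recorded trace roughly as the bullets describe, the split can be sketched as a toy on plain dicts. split_traces and the site layout below are hypothetical, not Pyro's Trace API: latent sites appear in both traces and share one value, with the prior as the model-side distribution and the posterior as the guide-side one, while observed sites appear only in the model trace.

```python
def split_traces(trace):
    """Hypothetical toy: split one interleaved trace into (model_trace, guide_trace).

    Each site is a dict with keys "fn", "value", "is_observed" and "infer";
    latent sites hold the posterior in "fn" and the prior in infer["prior"].
    """
    model_trace, guide_trace = {}, {}
    for name, site in trace.items():
        if site["is_observed"]:
            model_trace[name] = dict(site)  # observed: model only, likelihood unchanged
            continue
        guide_trace[name] = dict(site)  # guide side: posterior as "fn"
        model_site = dict(site)
        model_site["fn"] = site["infer"]["prior"]  # model side: prior as "fn"
        model_trace[name] = model_site
    return model_trace, guide_trace


# A recorded single-pass trace; the distributions are placeholder strings.
trace = {
    "z": {"fn": "Normal(3.0, 0.5)", "value": 2.9, "is_observed": False,
          "infer": {"prior": "Normal(0.0, 1.0)"}},
    "x": {"fn": "Normal(2.9, 1.0)", "value": 0.5, "is_observed": True, "infer": {}},
}
model_trace, guide_trace = split_traces(trace)
```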
Code Reference
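```python
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple

import torch

import pyro.distributions as dist
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import site_is_subsample

if TYPE_CHECKING:
    from pyro.poutine.runtime import Message


class GuideMessenger(TraceMessenger, ABC):
    def __init__(self, model: Callable) -> None:
        super().__init__()
        # Wrapped in a tuple so the model is not registered as a submodule
        # (relevant when a subclass is also an nn.Module)
        self._model = (model,)

    @property
    def model(self) -> Callable:
        return self._model[0]

    def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
        """Draws posterior samples and replays model against them."""
        self.args_kwargs = args, kwargs  # model inputs, readable from get_posterior
        try:
            with self:
                self.model(*args, **kwargs)
        finally:
            del self.args_kwargs
        model_trace, guide_trace = self.get_traces()
        # Collect the values of all sample sites in the model trace
        samples = {}
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                samples[name] = site["value"]
        return samples

    def _pyro_sample(self, msg: "Message") -> None:
        # Observed sites and plate subsample sites are left untouched
        if msg["is_observed"] or site_is_subsample(msg):
            return
        prior = msg["fn"]
        msg["infer"]["prior"] = prior  # keep the prior after msg["fn"] is replaced
        posterior = self.get_posterior(msg["name"], prior)
        if isinstance(posterior, torch.Tensor):
            # A raw value becomes a point mass with the prior's event_dim
            posterior = dist.Delta(posterior, event_dim=prior.event_dim)
        if posterior.batch_shape != prior.batch_shape:
            posterior = posterior.expand(prior.batch_shape)
        msg["fn"] = posterior

    @abstractmethod
    def get_posterior(self, name, prior):
        """Compute posterior distribution given a prior. Must be implemented."""
        raise NotImplementedError

    def upstream_value(self, name: str) -> Optional[torch.Tensor]:
        """Access the value of an upstream sample or deterministic site."""
        return self.trace.nodes[name]["value"]

    def get_traces(self) -> Tuple[Trace, Trace]:
        """Extract (model_trace, guide_trace) pair after execution."""
        ...  # body omitted from this excerpt
```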
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| model | Callable | The generative model function to be guided |

| Method | Input | Output |
|---|---|---|
| __call__(*args, **kwargs) | Model arguments | Dict[str, torch.Tensor] mapping each sample site name (latent, observed, or deterministic) to its value |
| get_posterior(name, prior) | Site name and prior distribution (conditioned on upstream posterior samples) | A posterior Distribution, or a torch.Tensor that is wrapped in a Delta distribution (point estimate) |
| upstream_value(name) | Name of an upstream sample or deterministic site | Optional[torch.Tensor] -- the value recorded at that site (KeyError if the site has not been reached yet) |
| get_traces() | (none); call after __call__ | Tuple[Trace, Trace] -- (model_trace, guide_trace) |
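To illustrate what the (model_trace, guide_trace) pair returned by get_traces() is for, here is a toy single-sample ELBO estimate in plain Python, assuming a trace-based ELBO of the usual form log p(x, z) - log q(z). The dict layout, normal_log_prob and elbo_estimate are hypothetical stand-ins, not Pyro's Trace or ELBO classes; the point is that the model trace scores every site (prior and likelihood) while the guide trace scores only the latent sites, and both use the same sampled z.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log density of a univariate normal distribution."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def elbo_estimate(model_trace, guide_trace):
    """Single-sample ELBO: sum of model log-probs minus sum of guide log-probs.

    Toy layout (hypothetical): each site is {"fn": (loc, scale), "value": float}.
    """
    log_p = sum(normal_log_prob(s["value"], *s["fn"]) for s in model_trace.values())
    log_q = sum(normal_log_prob(s["value"], *s["fn"]) for s in guide_trace.values())
    return log_p - log_q


# One interleaved pass drew z = 0.2 from the guide and observed x = 0.5.
model_trace = {
    "z": {"fn": (0.0, 1.0), "value": 0.2},  # prior p(z)
    "x": {"fn": (0.2, 1.0), "value": 0.5},  # likelihood p(x | z) at the guide's z
}
guide_trace = {"z": {"fn": (0.2, 0.5), "value": 0.2}}  # posterior q(z)
elbo = elbo_estimate(model_trace, guide_trace)
```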
Usage Examples
Implementing a Custom Guide
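```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.poutine.guide import GuideMessenger

def model(x):
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    # Observed site: skipped by GuideMessenger._pyro_sample
    pyro.sample("x", dist.Normal(z, 1.0).expand([len(x)]).to_event(1), obs=x)

class MyGuide(GuideMessenger):
    def get_posterior(self, name, prior):
        if name == "z":
            # Learnable parameters live in Pyro's param store
            loc = pyro.param("z_loc", lambda: torch.tensor(0.0))
            scale = pyro.param("z_scale", lambda: torch.tensor(1.0),
                               constraint=constraints.positive)
            return dist.Normal(loc, scale)
        # Fall back to prior for other sites
        return prior

x_data = torch.randn(10)
guide = MyGuide(model)
samples = guide(x_data)  # dict containing "z" and the observed "x"
model_trace, guide_trace = guide.get_traces()
```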
Accessing Upstream Values
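```python
import pyro.distributions as dist
from pyro.poutine.guide import GuideMessenger

class AutoregressiveGuide(GuideMessenger):
    def get_posterior(self, name, prior):
        if name == "z2":
            # "z1" must be sampled before "z2" in the model; this reads the
            # value already drawn at "z1" during the current pass
            z1_value = self.upstream_value("z1")
            # Use z1 to parameterize the posterior for z2
            return dist.Normal(z1_value, 1.0)
        return prior
```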
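The ordering requirement behind upstream_value can be seen in a dependency-free toy. ToyAutoregressiveGuide, Normal and the explicit guide.sample(...) calls are hypothetical simplifications (the model calls the handler directly rather than going through an effect stack), and the sketch assumes, like the Code Reference excerpt, that upstream_value simply reads the value recorded in the trace. Because z1 is drawn from the guide before the model reaches z2, both z2's prior and its posterior are built from that same draw of z1.

```python
import random


class Normal:
    """Toy scalar normal; a hypothetical stand-in for pyro.distributions.Normal."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)


class ToyAutoregressiveGuide:
    """Hypothetical toy: records each site in self.trace as soon as it is sampled."""

    def __init__(self, seed=0):
        self.trace, self.rng = {}, random.Random(seed)

    def upstream_value(self, name):
        # Only sites that already ran are recorded; later sites raise KeyError.
        return self.trace[name]["value"]

    def get_posterior(self, name, prior):
        if name == "z2":
            return Normal(self.upstream_value("z1"), 0.01)
        return Normal(10.0, 0.01)  # hypothetical posterior for z1

    def sample(self, name, prior):
        posterior = self.get_posterior(name, prior)
        value = posterior.sample(self.rng)
        self.trace[name] = {"prior": prior, "fn": posterior, "value": value}
        return value


def model(guide):
    z1 = guide.sample("z1", Normal(0.0, 1.0))
    guide.sample("z2", Normal(z1, 1.0))  # the prior for z2 depends on the drawn z1


guide = ToyAutoregressiveGuide()
model(guide)
```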
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Base handler protocol inherited via TraceMessenger
- Pyro_ppl_Pyro_Poutine_Handlers -- No direct factory function; GuideMessenger is used via subclassing
- Pyro_ppl_Pyro_Trace_Struct -- The Trace objects returned by get_traces()
- Pyro_ppl_Pyro_ReplayMessenger -- Alternative approach: separate trace-replay pattern
- Pyro_ppl_Pyro_ConditionMessenger -- Observed sites are skipped by GuideMessenger._pyro_sample