Implementation:Pyro ppl Pyro ConditionMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/condition_messenger.py |
| Module | pyro.poutine.condition_messenger |
| Lines | 72 |
| Parent Class | Messenger |
| Purpose | Convert sample sites into observed sites with given values |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
ConditionMessenger converts sample sites into observed sites by injecting fixed values from a data dictionary or a Trace object. When a sample site's name matches a key in the provided data, the site's value is set to the data value, and is_observed is set to True unless that value is None.
This is equivalent to passing obs=value as a keyword argument to pyro.sample(), but it is applied from outside the model, so the model code does not need to change. That makes it central to Bayesian inference: it declares which variables are observed, and inference then targets the posterior over the remaining latent sites.
The data source can be either:
- A dict mapping site names to `torch.Tensor` values.
- A Trace object, from which values are extracted via `trace.nodes[name]["value"]`.
Code Reference
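```python
from typing import TYPE_CHECKING, Dict, Union
import torch
from pyro.poutine.messenger import Messenger
from pyro.poutine.trace_struct import Trace

if TYPE_CHECKING:
    from pyro.poutine.runtime import Message


class ConditionMessenger(Messenger):
    def __init__(self, data: Union[Dict[str, torch.Tensor], Trace]) -> None:
        super().__init__()
        self.data = data

    def _pyro_sample(self, msg: "Message") -> None:
        name = msg["name"]
        if name in self.data:
            if isinstance(self.data, Trace):
                msg["value"] = self.data.nodes[name]["value"]
            else:
                msg["value"] = self.data[name]
            # A None value leaves the site unobserved.
            msg["is_observed"] = msg["value"] is not None
```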
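The sketch below is a dependency-free toy of the effect-handler pattern that `_pyro_sample` plugs into. It is not Pyro's implementation. `HANDLER_STACK`, `ToyMessenger`, `ToyCondition` and `toy_sample` are hypothetical names, and the real Messenger stack also handles plates, `stop`/`done` flags and post-processing. The toy mirrors the same two-step logic: a handler overwrites `msg["value"]` and `msg["is_observed"]` for matching names, and the default behaviour samples only when no value has been set.

```python
import random
from typing import Any, Callable, Dict, List

# Hypothetical global handler stack; the innermost handler is last.
HANDLER_STACK: List["ToyMessenger"] = []


class ToyMessenger:
    """Minimal stand-in for a Messenger: pushes itself on entry, pops on exit."""

    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *exc):
        HANDLER_STACK.remove(self)
        return False

    def process(self, msg: Dict[str, Any]) -> None:
        pass


class ToyCondition(ToyMessenger):
    """Mirrors ConditionMessenger._pyro_sample for a plain dict."""

    def __init__(self, data: Dict[str, Any]) -> None:
        self.data = data

    def process(self, msg: Dict[str, Any]) -> None:
        name = msg["name"]
        if name in self.data:
            msg["value"] = self.data[name]
            msg["is_observed"] = msg["value"] is not None


def toy_sample(name: str, sampler: Callable[[], Any]) -> Dict[str, Any]:
    msg: Dict[str, Any] = {"name": name, "value": None, "is_observed": False}
    # Apply handlers from innermost to outermost.
    for handler in reversed(HANDLER_STACK):
        handler.process(msg)
    # Default behaviour: draw a value only if no handler supplied one.
    if msg["value"] is None:
        msg["value"] = sampler()
    return msg


with ToyCondition({"z": 1.0}):
    print(toy_sample("z", random.random))  # {'name': 'z', 'value': 1.0, 'is_observed': True}
    print(toy_sample("w", random.random)["is_observed"])  # False
```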
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| data | Dict[str, torch.Tensor] or Trace | The observations to condition on, keyed by site name |
| Message Effect | Description |
|---|---|
| msg["value"] | Set to the value from data for matching sites |
| msg["is_observed"] | Set to True when that value is not None, otherwise False |
Sites not present in data are unaffected (default sampling behavior).
Usage Examples
Conditioning with a Dictionary
```python
import torch
import pyro
import pyro.distributions as dist

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})
```
Conditioning with a Trace
```python
old_trace = pyro.poutine.trace(model).get_trace(0.0)
conditioned_model = pyro.poutine.condition(model, data=old_trace)
```
Using as Context Manager
```python
with pyro.poutine.condition(data={"z": torch.tensor(1.0)}):
    result = model(0.0)
```
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Handlers -- The `poutine.condition()` factory function
- Pyro_ppl_Pyro_UnconditionMessenger -- The inverse operation: removes observations
- Pyro_ppl_Pyro_DoMessenger -- Related: interventions vs conditioning
- Pyro_ppl_Pyro_Trace_Struct -- Can be used as the data source for conditioning
- Pyro_ppl_Pyro_SubstituteMessenger -- Similar, but for `pyro.param` sites