Implementation:Pyro ppl Pyro DoMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/do_messenger.py |
| Module | pyro.poutine.do_messenger |
| Lines | 95 |
| Parent Class | Messenger |
| Purpose | Implement causal interventions (do-calculus) on sample sites |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
DoMessenger implements causal interventions following the do-calculus framework and Single World Intervention Graphs (SWIG) [1]. When a sample site's name matches a key in the `data` dictionary whose value is not None, and this handler has not already intervened on the site, the handler:
- Copies the message and sends the copy, under the original name, through the handler stack with `apply_stack`, creating a fresh sample site that records the natural value that would have been sampled.
- Mangles the original site name by appending `"__CF"` (counterfactual).
- Sets the value of the original site to the intervention value, so this is what `pyro.sample` returns to the model.
- Marks the site as observed and sets `stop`, so no further handlers process it.
This "node splitting" approach means downstream code receives the intervention value, while a fresh sample from the original distribution is still recorded in the trace under the site's original name. It composes freely with `condition` to represent counterfactual distributions over potential outcomes.
References: [1] Thomas Richardson, James Robins, "Single World Intervention Graphs: A Primer"
Code Reference
```python
import numbers
import warnings
from typing import TYPE_CHECKING, Dict, Union

import torch

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import apply_stack

if TYPE_CHECKING:
    from pyro.poutine.runtime import Message


class DoMessenger(Messenger):
    def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]]) -> None:
        super().__init__()
        self.data = data
        self._intervener_id = str(id(self))

    def _pyro_sample(self, msg: "Message") -> None:
        if (msg.get("_intervener_id") != self._intervener_id
                and self.data.get(msg["name"]) is not None):
            if msg.get("_intervener_id") is not None:
                warnings.warn(
                    "Attempting to intervene on variable {} multiple times, ...".format(msg["name"]),
                    RuntimeWarning,
                )
            msg["_intervener_id"] = self._intervener_id
            # Split node: create a fresh sample site with the original name and
            # distribution; cond_indep_stack is reset to avoid entering plates twice
            new_msg = msg.copy()
            new_msg["cond_indep_stack"] = ()
            apply_stack(new_msg)
            # Apply intervention to the original message
            intervention = self.data[msg["name"]]
            msg["name"] = msg["name"] + "__CF"  # mangle old name
            if isinstance(intervention, (numbers.Number, torch.Tensor)):
                msg["value"] = torch.tensor(intervention) if isinstance(intervention, numbers.Number) else intervention
                msg["is_observed"] = True
                msg["stop"] = True  # remaining handlers on the stack skip this site
            else:
                raise NotImplementedError(...)
```
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| data | Dict[str, Union[torch.Tensor, numbers.Number]] | A dict mapping sample site names to intervention values; entries whose value is None are ignored |
| Message Effect | Description |
|---|---|
| Node splitting | A copy of the message, keeping the original name and distribution but with an empty `cond_indep_stack`, is sent through `apply_stack`, creating a fresh sample site |
| msg["name"] | Original name is mangled with the "__CF" suffix |
| msg["value"] | Set to the intervention value (a plain number is converted to a tensor) |
| msg["is_observed"] | Set to True |
| msg["stop"] | Set to True so that `apply_stack` passes the message to no further handlers |
| msg["_intervener_id"] | Set to this handler's id so that the split copy is not intervened on again by the same handler |
Usage Examples
Basic Intervention
```python
import torch
import pyro
import pyro.distributions as dist

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

# Intervene: set z = 1.0 regardless of the distribution
intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})
# This is equivalent to replacing z = pyro.sample("z", ...) with z = tensor(1.)
# while recording a fresh sample from Normal(x, s) under the name "z"
```
Counterfactual Reasoning
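```python
import torch
import pyro
# Compose do() with condition() for counterfactual analysis
# (assumes `model` has sites "treatment" and "outcome"; `observed_outcome` is observed data)
counterfactual_model = pyro.poutine.do(
    pyro.poutine.condition(model, data={"outcome": observed_outcome}),
    data={"treatment": torch.tensor(1.0)}
)
```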
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Handlers -- The `poutine.do()` factory function
- Pyro_ppl_Pyro_ConditionMessenger -- Conditioning (observation) vs intervention
- Pyro_ppl_Pyro_Poutine_Runtime -- `apply_stack` is called directly for node splitting