Implementation:Pyro ppl Pyro DoMessenger

From Leeroopedia


| Attribute | Value |
|---|---|
| File | pyro/poutine/do_messenger.py |
| Module | pyro.poutine.do_messenger |
| Lines | 95 |
| Parent Class | Messenger |
| Purpose | Implement causal interventions (do-calculus) on sample sites |
| License | Apache-2.0 (Uber Technologies, Inc.) |

Overview

DoMessenger implements causal interventions (the do operator from do-calculus) using the node-splitting construction of Single World Intervention Graphs (SWIGs) [1]. When a sample site's name is a key in the provided data dictionary and the mapped value is not None, the handler:

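  1. Copies the message and sends the copy through the handler stack (via apply_stack), creating a fresh sample site under the original name that records the natural value the model would have sampled.
  2. Mangles the name of the original message by appending "__CF" (counterfactual).
  3. Sets the value of the original message to the intervention value, which is what the model code receives from pyro.sample.
  4. Marks the original message as observed and sets its stop flag so that the remaining handlers in the stack do not process it.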

With this "node splitting" approach, the intervention value propagates to everything downstream of the site, while a sample from the original distribution is still recorded under the original name and is not used elsewhere in the model. DoMessenger composes freely with condition to represent counterfactual distributions over potential outcomes.
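
The four steps above can be sketched with plain dictionaries standing in for Pyro messages. This is a toy illustration rather than Pyro's runtime: `split_node` and `draw` are hypothetical names, and the real handler routes the copied message through apply_stack instead of calling a sampler directly.

```python
def split_node(msg, data, draw):
    """Toy version of node splitting on a dict-based message.

    `draw` is a hypothetical zero-argument sampler standing in for the
    site's distribution; Pyro instead sends the copy through apply_stack.
    """
    # Step 1: a copy keeps the original name and records the natural value.
    natural = dict(msg)
    natural["value"] = draw()

    # Step 2: the original message gets a mangled name.
    intervention = data[msg["name"]]
    msg["name"] = msg["name"] + "__CF"

    # Steps 3 and 4: downstream code sees the intervention value, and the
    # message is flagged as observed and stopped.
    msg["value"] = intervention
    msg["is_observed"] = True
    msg["stop"] = True
    return natural, msg


message = {"name": "z", "value": None, "is_observed": False, "stop": False}
natural, intervened = split_node(message, {"z": 1.0}, draw=lambda: 0.25)
```

After the call, `natural` still carries the name "z" with the sampled value, while `intervened` is the same dictionary object as `message`, now renamed to "z__CF" and holding the intervention value.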

References: [1] Thomas Richardson, James Robins, "Single World Intervention Graphs: A Primer"

Code Reference

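import numbers
import warnings
from typing import Dict, Union

import torch

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message, apply_stack


class DoMessenger(Messenger):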
    def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]]) -> None:
        super().__init__()
        self.data = data
        self._intervener_id = str(id(self))

    def _pyro_sample(self, msg: Message) -> None:
        if (msg.get("_intervener_id") != self._intervener_id
                and self.data.get(msg["name"]) is not None):
            if msg.get("_intervener_id") is not None:
                warnings.warn("Attempting to intervene on variable {} multiple times, ...")

            msg["_intervener_id"] = self._intervener_id

            # Split node: create fresh sample site with original distribution
            new_msg = msg.copy()
            new_msg["cond_indep_stack"] = ()
            apply_stack(new_msg)

            # Apply intervention
            intervention = self.data[msg["name"]]
            msg["name"] = msg["name"] + "__CF"  # mangle old name

            if isinstance(intervention, (numbers.Number, torch.Tensor)):
                msg["value"] = torch.tensor(intervention) if isinstance(intervention, numbers.Number) else intervention
                msg["is_observed"] = True
                msg["stop"] = True
            else:
                raise NotImplementedError(...)
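
The guard at the top of _pyro_sample can be isolated as a small predicate to make its edge cases visible. The sketch below uses plain dictionaries and a hypothetical helper name, `should_intervene`; it mirrors only the condition shown in the excerpt and does not call into Pyro.

```python
def should_intervene(msg, data, intervener_id):
    """Hypothetical helper mirroring the condition in _pyro_sample."""
    # The copy created during node splitting already carries this handler's
    # id, so it is skipped when it comes back through the stack.
    already_handled_here = msg.get("_intervener_id") == intervener_id
    # The check is `is not None`, not truthiness, so 0.0 is a valid value.
    has_intervention = data.get(msg["name"]) is not None
    return (not already_handled_here) and has_intervention


fresh = should_intervene({"name": "z"}, {"z": 1.0}, "a")
zero_value = should_intervene({"name": "z"}, {"z": 0.0}, "a")
none_value = should_intervene({"name": "z"}, {"z": None}, "a")
other_site = should_intervene({"name": "y"}, {"z": 1.0}, "a")
own_copy = should_intervene({"name": "z", "_intervener_id": "a"}, {"z": 1.0}, "a")
# A message already marked by a different handler still passes the guard;
# this is the case in which the excerpt above emits its warning.
other_handler = should_intervene({"name": "z", "_intervener_id": "b"}, {"z": 1.0}, "a")
```

One consequence of the `is not None` check is that a data entry mapped to None is silently ignored rather than treated as an intervention.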

I/O Contract

| Parameter | Type | Description |
|---|---|---|
| data | Dict[str, Union[torch.Tensor, numbers.Number]] | A dict mapping sample site names to intervention values |

| Message Effect | Description |
|---|---|
| Node splitting | A copy of the message is passed through apply_stack, creating a new sample from the original distribution under the original name |
| msg["name"] | Original name is mangled with the "__CF" suffix |
| msg["value"] | Set to the intervention value (a numbers.Number is converted to a tensor) |
| msg["is_observed"] | Set to True |
| msg["stop"] | Set to True so that the remaining handlers in the stack do not process the message |
| msg["_intervener_id"] | Set to this handler's id string so that it does not re-intervene on the copy it creates |

Usage Examples

Basic Intervention

import torch
import pyro
import pyro.distributions as dist

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

# Intervene: set z = 1.0 regardless of the distribution
intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})
# This is equivalent to replacing z = pyro.sample("z", ...) with z = torch.tensor(1.)
# while still drawing a fresh sample from Normal(x, s) under the name "z",
# whose value is not used downstream.

Counterfactual Reasoning

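# Compose do() with condition() for counterfactual analysis.
# Assumes `model` defines sample sites named "treatment" and "outcome"
# (the model in Basic Intervention does not) and that `observed_outcome`
# is a tensor of observed data; both are placeholders here.
counterfactual_model = pyro.poutine.do(
    pyro.poutine.condition(model, data={"outcome": observed_outcome}),
    data={"treatment": torch.tensor(1.0)}
)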
