

Implementation:Pyro ppl Pyro ConditionMessenger

From Leeroopedia
Revision as of 16:23, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Pyro_ppl_Pyro_ConditionMessenger.md)


Attribute       Value
File            pyro/poutine/condition_messenger.py
Module          pyro.poutine.condition_messenger
Lines           72
Parent Class    Messenger
Purpose         Convert sample sites into observed sites with given values
License         Apache-2.0 (Uber Technologies, Inc.)

Overview

ConditionMessenger converts sample sites into observed sites by injecting fixed values from a data dictionary or a Trace object. When a sample site's name matches a key in the provided data, the handler sets the site's value to the corresponding data value and sets is_observed to True (or to False if that value is None).

This is equivalent to passing obs=value to pyro.sample(), but applied from outside the model, so the model code does not need to change. This makes the handler fundamental to Bayesian inference: it specifies which variables are observed, and therefore what the posterior is conditioned on.

The data source can be either:

  • A dict mapping site names to torch.Tensor values.
  • A Trace object, from which values are extracted via trace.nodes[name]["value"].

Code Reference

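from typing import TYPE_CHECKING, Dict, Union

import torch

from pyro.poutine.messenger import Messenger
from pyro.poutine.trace_struct import Trace

if TYPE_CHECKING:
    from pyro.poutine.runtime import Message


class ConditionMessenger(Messenger):
    def __init__(self, data: Union[Dict[str, torch.Tensor], Trace]) -> None:
        super().__init__()
        # Observations keyed by sample-site name (a dict or a recorded Trace)
        self.data = data

    def _pyro_sample(self, msg: "Message") -> None:
        name = msg["name"]
        # Only sites whose name appears in the data are affected
        if name in self.data:
            if isinstance(self.data, Trace):
                # Read the value recorded at the same site in the Trace
                msg["value"] = self.data.nodes[name]["value"]
            else:
                msg["value"] = self.data[name]
            # A None value leaves the site unobserved, so it is sampled normally
            msg["is_observed"] = msg["value"] is not None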

I/O Contract

Parameter   Type                                Description
data        Dict[str, torch.Tensor] or Trace    The observations to condition on, keyed by site name

Message Field        Effect
msg["value"]         Set to the value from data for matching sites
msg["is_observed"]   Set to msg["value"] is not None (True for any non-None value)

Sites not present in data are unaffected (default sampling behavior).

Usage Examples

Conditioning with a Dictionary

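import torch
import pyro
import pyro.distributions as dist

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

# Every call to conditioned_model treats "z" as observed with value 1.0
conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})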

Conditioning with a Trace

# Record one run of the model, then reuse its sampled values as observations
old_trace = pyro.poutine.trace(model).get_trace(0.0)
conditioned_model = pyro.poutine.condition(model, data=old_trace)

Using as Context Manager

with pyro.poutine.condition(data={"z": torch.tensor(1.0)}):
    result = model(0.0)
