

Implementation:Pyro ppl Pyro SubstituteMessenger

From Leeroopedia


File: pyro/poutine/substitute_messenger.py
Module: pyro.poutine.substitute_messenger
Lines: 95
Parent class: Messenger
Purpose: Substitute fixed values at param sites without converting them to sample sites
License: Apache-2.0 (Uber Technologies, Inc.)

Overview

SubstituteMessenger replaces the values of pyro.param sites with fixed values from a data dictionary without converting them into sample sites (unlike LiftMessenger, which turns param sites into sample sites).
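The contrast with lifting can be sketched with a toy model of Pyro's site messages. This is a pure-Python illustration, not Pyro code: sites are plain dicts, toy_substitute and toy_lift are hypothetical helpers, and the zero-argument lambda passed to toy_lift is a deterministic stand-in for drawing from a prior distribution.

```python
# Toy model of Pyro-style site messages: plain dicts with "type", "name", "value".
# toy_substitute and toy_lift are hypothetical stand-ins, not Pyro APIs.
from typing import Callable, Dict


def toy_substitute(msg: dict, data: Dict[str, float]) -> dict:
    # Substitution only overwrites the value; the site stays a "param" site.
    if msg["type"] == "param" and msg["name"] in data:
        msg["value"] = data[msg["name"]]
    return msg


def toy_lift(msg: dict, prior: Dict[str, Callable[[], float]]) -> dict:
    # Lifting turns the param site into a sample site whose value comes from a prior.
    # The "prior" here is a deterministic callable standing in for a distribution.
    if msg["type"] == "param" and msg["name"] in prior:
        msg["type"] = "sample"
        msg["value"] = prior[msg["name"]]()
    return msg


sub = toy_substitute({"type": "param", "name": "a", "value": 0.5}, {"a": 0.3})
lifted = toy_lift({"type": "param", "name": "a", "value": 0.5}, {"a": lambda: 0.9})
print(sub)     # {'type': 'param', 'name': 'a', 'value': 0.3}
print(lifted)  # {'type': 'sample', 'name': 'a', 'value': 0.9}
```

In both cases the value changes, but only lifting changes the site's type, which is what makes a lifted param visible to inference as a random variable.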

Key features:

  • Matches param site names against keys in the provided data dict.
  • Uses params.user_param_name(name) to strip module namespace prefixes when matching.
  • Implements a _data_cache to ensure that multiple pyro.param calls with the same name receive the same value.
  • When validation is enabled, tracks _param_hits and _param_misses and warns if data keys do not match any encountered param names.
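A toy, pure-Python version of the prefix-stripping match makes the first two points concrete. toy_user_param_name is a hypothetical stand-in for params.user_param_name, and the "$$$" divider is an assumption based on how pyro.module namespaces parameter names; the pyro.params source is the authority on the real behavior.

```python
# Hypothetical mirror of the name matching described above, for illustration only.
# Assumes the "$$$" divider that pyro.module uses to prefix param names.
DIVIDER = "$$$"


def toy_user_param_name(name: str) -> str:
    # Drop a "module$$$" prefix if present; plain names pass through unchanged.
    if DIVIDER in name:
        return name.split(DIVIDER)[1]
    return name


def toy_substitute_values(param_names, data):
    # Map each full param name to its substitute, matching on the stripped name.
    return {
        full: data[toy_user_param_name(full)]
        for full in param_names
        if toy_user_param_name(full) in data
    }


result = toy_substitute_values(
    ["net$$$weight", "net$$$bias", "scale"], {"weight": 1.0, "scale": 2.0}
)
print(result)  # {'net$$$weight': 1.0, 'scale': 2.0}
```

Because matching happens on the stripped name, the data dict can use the short names a user wrote, even when the param store holds module-prefixed names.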

The handler has no effect on pyro.sample sites (its _pyro_sample returns None).

Code Reference

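import warnings
from typing import TYPE_CHECKING, Dict

from pyro import params
from pyro.poutine.messenger import Messenger
from pyro.poutine.util import is_validation_enabled

if TYPE_CHECKING:
    import torch


class SubstituteMessenger(Messenger):
    def __init__(self, data: Dict[str, "torch.Tensor"]) -> None:
        super().__init__()
        self.data = data
        self._data_cache = {}

    def __enter__(self):
        # Start each `with` block with an empty cache and fresh hit/miss sets.
        self._data_cache = {}
        if is_validation_enabled() and isinstance(self.data, dict):
            self._param_hits = set()
            self._param_misses = set()
        return super().__enter__()

    def __exit__(self, *args, **kwargs):
        self._data_cache = {}
        if is_validation_enabled() and isinstance(self.data, dict):
            # Keys in `data` that never matched a param site seen in this block.
            extra = set(self.data) - self._param_hits
            if extra:
                # (warning message abbreviated in this excerpt)
                warnings.warn("pyro.module data did not find params ['{}']. ...".format(...))
        return super().__exit__(*args, **kwargs)

    def _pyro_sample(self, msg):
        # Sample sites pass through untouched.
        return None

    def _pyro_param(self, msg):
        name = msg["name"]
        # Strip any module namespace prefix before looking the name up in `data`.
        param_name = params.user_param_name(name)
        if param_name in self.data.keys():
            msg["value"] = self.data[param_name]
            if is_validation_enabled():
                self._param_hits.add(param_name)
        else:
            if is_validation_enabled():
                self._param_misses.add(param_name)
            # Unmatched params keep their normal value.
            return None
        if name in self._data_cache:
            # Repeated pyro.param call with the same name: reuse the first message's value.
            msg["value"] = self._data_cache[name]["value"]
        else:
            self._data_cache[name] = msg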
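One detail of _pyro_param is easy to miss: the cache stores the first message object itself, not a copy of its value, so a repeated pyro.param call with the same name reads back whatever value that first message currently holds. The following pure-Python sketch mirrors that branch under simplifying assumptions: ToySubstitute is hypothetical, names are matched without prefix stripping, and the manual assignment to first["value"] exists only to make the aliasing visible.

```python
class ToySubstitute:
    """Hypothetical pure-Python mirror of the _pyro_param branch (not Pyro code)."""

    def __init__(self, data):
        self.data = data
        self._data_cache = {}

    def process_param(self, msg):
        name = msg["name"]
        if name not in self.data:
            return  # unmatched params pass through unchanged
        msg["value"] = self.data[name]
        if name in self._data_cache:
            # Repeat site: reuse whatever value the first message currently holds.
            msg["value"] = self._data_cache[name]["value"]
        else:
            # First site: cache the message object itself, not a copy.
            self._data_cache[name] = msg


handler = ToySubstitute({"a": 0.3})
first = {"name": "a", "value": 0.5}
handler.process_param(first)
first["value"] = 0.35  # mutate the cached message to expose the aliasing
second = {"name": "a", "value": 0.5}
handler.process_param(second)
print(first["value"], second["value"])  # 0.35 0.35
```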

I/O Contract

Parameter:
  • data (Dict[str, torch.Tensor]): a dict mapping param names to substitute values.

Message effects:
  • msg["value"] at param sites: set to the value from data when the param name (after prefix stripping) matches a key; unmatched param sites are left unchanged.
  • sample sites: no effect (_pyro_sample returns None).
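Both halves of this contract, including the easily overlooked case of a data key that names a sample site, can be shown with toy message dicts. The helper toy_substitute_site is hypothetical and skips prefix stripping; in Pyro itself, fixing the value of a sample site is the job of pyro.poutine.condition rather than substitute.

```python
def toy_substitute_site(msg, data):
    """Hypothetical helper: apply the I/O contract to one toy message dict."""
    if msg["type"] == "sample":
        return msg  # sample sites are never touched, even if their name is a data key
    if msg["name"] in data:
        msg["value"] = data[msg["name"]]
    return msg  # unmatched param sites keep their original value


data = {"a": 0.3, "x": 1.0}
sites = [
    {"type": "param", "name": "a", "value": 0.5},
    {"type": "param", "name": "b", "value": 2.0},
    {"type": "sample", "name": "x", "value": 0.0},
]
for site in sites:
    toy_substitute_site(site, data)
print([s["value"] for s in sites])  # [0.3, 2.0, 0.0]
```

Here the key "x" is simply never used, which is exactly the situation the validation check is meant to flag.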

Validation: when validation is enabled, the handler warns on exit if any key in data never matched a param site encountered during execution.
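The exit-time check reduces to a set difference between the data keys and the param names that were hit. The sketch below mirrors it in plain Python; toy_check_unused is hypothetical and its warning text is illustrative, not Pyro's actual message.

```python
import warnings


def toy_check_unused(data, param_hits):
    # Hypothetical mirror of the __exit__ check: keys never matched by a param site.
    extra = set(data) - param_hits
    if extra:
        warnings.warn("data keys matched no param site: {}".format(sorted(extra)))
    return extra


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is recorded
    unused = toy_check_unused({"a": 0.3, "weight": 1.0}, param_hits={"a"})

print(unused)       # {'weight'}
print(len(caught))  # 1
```

A typo in a data key (or a key naming a sample site) therefore surfaces as a warning rather than being silently ignored, provided validation is on.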

Usage Examples

Substituting Parameter Values

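import torch
import pyro
import pyro.distributions as dist

def model(x):
    a = pyro.param("a", torch.tensor(0.5))
    x = pyro.sample("x", dist.Bernoulli(probs=a))
    return x

substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})
# Inside substituted_model, pyro.param("a", ...) returns torch.tensor(0.3)
# instead of the param-store value; "a" remains a param site, not a sample site.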

Using as Context Manager

# Assumes `model` declares a param site named "weight" and `input_data` is defined elsewhere.
with pyro.poutine.substitute(data={"weight": torch.randn(10)}):
    result = model(input_data)
