Implementation:Pyro ppl Pyro SubstituteMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/substitute_messenger.py |
| Module | pyro.poutine.substitute_messenger |
| Lines | 95 |
| Parent Class | Messenger |
| Purpose | Substitute fixed values at param sites without converting them to sample sites |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
SubstituteMessenger replaces parameter values at pyro.param sites with fixed values from a data dictionary, without converting them into sample sites (unlike LiftMessenger which turns params into samples).
Key features:
- Matches param site names against keys in the provided data dict.
- Uses params.user_param_name(name) to strip module namespace prefixes when matching.
- Implements a _data_cache to ensure that multiple pyro.param calls with the same name receive the same value.
- When validation is enabled, tracks _param_hits and _param_misses and warns if data keys do not match any encountered param names.
The handler has no effect on pyro.sample sites (its _pyro_sample returns None).
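The matching and caching steps listed above can be sketched without Pyro. The following is a minimal stand-in, not Pyro's code: `handle_param`, the plain-dict message, and the `"$$$"` divider used to mimic a module namespace prefix are all assumptions made for illustration.

```python
# Toy stand-in for the param-site logic; not Pyro's implementation.
DIVIDER = "$$$"  # assumed namespace divider, for illustration only


def user_param_name(name):
    """Strip a hypothetical 'module$$$' prefix from a param name."""
    if DIVIDER in name:
        return name.split(DIVIDER)[1]
    return name


def handle_param(msg, data, cache):
    """Overwrite msg['value'] when the stripped name is a key in data."""
    key = user_param_name(msg["name"])
    if key not in data:
        return msg  # unmatched sites pass through unchanged
    msg["value"] = data[key]
    if msg["name"] in cache:
        # A repeated site reuses the value recorded for the first one.
        msg["value"] = cache[msg["name"]]["value"]
    else:
        cache[msg["name"]] = msg
    return msg


data = {"weight": 0.3}
cache = {}
first = handle_param({"name": "net$$$weight", "value": None}, data, cache)
second = handle_param({"name": "net$$$weight", "value": None}, data, cache)
other = handle_param({"name": "bias", "value": None}, data, cache)
print(first["value"], second["value"], other["value"])
```

Both calls for the prefixed name receive the substituted value, while the site with no matching key keeps whatever value it already had.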
Code Reference
    import warnings
    from typing import Dict

    from pyro import params
    from pyro.poutine.messenger import Messenger
    from pyro.poutine.util import is_validation_enabled


    class SubstituteMessenger(Messenger):
        def __init__(self, data: Dict[str, "torch.Tensor"]) -> None:
            super().__init__()
            self.data = data
            self._data_cache = {}

        def __enter__(self):
            self._data_cache = {}
            if is_validation_enabled() and isinstance(self.data, dict):
                self._param_hits = set()
                self._param_misses = set()
            return super().__enter__()

        def __exit__(self, *args, **kwargs):
            self._data_cache = {}
            if is_validation_enabled() and isinstance(self.data, dict):
                # Keys in data that no param site matched.
                extra = set(self.data) - self._param_hits
                if extra:
                    warnings.warn("pyro.module data did not find params ['{}']. ...".format(...))
            return super().__exit__(*args, **kwargs)

        def _pyro_sample(self, msg):
            # Sample sites are left untouched.
            return None

        def _pyro_param(self, msg):
            name = msg["name"]
            param_name = params.user_param_name(name)
            if param_name in self.data.keys():
                msg["value"] = self.data[param_name]
                if is_validation_enabled():
                    self._param_hits.add(param_name)
            else:
                if is_validation_enabled():
                    self._param_misses.add(param_name)
                return None
            if name in self._data_cache:
                # Repeated pyro.param call: reuse the first site's value.
                msg["value"] = self._data_cache[name]["value"]
            else:
                self._data_cache[name] = msg
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| data | Dict[str, torch.Tensor] | A dict mapping param names to substitute values |

| Message Effect | Description |
|---|---|
| msg["value"] (param sites) | Set to the value from data for matching param names |
| msg (sample sites) | No effect (returns None) |
Validation: When validation is enabled, the handler warns on exit if any key in data was never matched by a param site encountered during execution.
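The validation check amounts to a set difference between the data keys and the names that were actually matched. Below is a minimal stdlib-only sketch of that idea; `check_unused_keys` and its message text are hypothetical and do not reproduce Pyro's exact warning.

```python
import warnings


def check_unused_keys(data, param_hits, param_misses):
    """Warn about keys in data that no param site matched (illustrative)."""
    extra = set(data) - param_hits
    if extra:
        warnings.warn(
            "data keys never matched: {}; unmatched param names seen: {}".format(
                sorted(extra), sorted(param_misses)
            )
        )
    return extra


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # "wieght" is a typo, so it never matches the site named "weight".
    extra = check_unused_keys(
        data={"wieght": 1.0, "bias": 0.0},
        param_hits={"bias"},
        param_misses={"weight"},
    )

print(extra, len(caught))
```

Reporting the missed param names next to the unused keys makes a misspelled key easy to spot, which is the typical reason such a warning fires.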
Usage Examples
Substituting Parameter Values
    import torch
    import pyro
    import pyro.distributions as dist

    def model(x):
        a = pyro.param("a", torch.tensor(0.5))
        x = pyro.sample("x", dist.Bernoulli(probs=a))
        return x

    substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})
    # Now site "a" will have value torch.tensor(0.3)
Using as Context Manager
    with pyro.poutine.substitute(data={"weight": torch.randn(10)}):
        result = model(input_data)
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Handlers -- The poutine.substitute() factory function
- Pyro_ppl_Pyro_LiftMessenger -- Related: converts params to samples (substitute keeps them as params)
- Pyro_ppl_Pyro_ConditionMessenger -- Related: conditions sample sites on fixed values (substitute targets param sites)
- Pyro_ppl_Pyro_ReplayMessenger -- Related: replays values from traces at both sample and param sites