Implementation:Pyro ppl Pyro ReparamMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/reparam_messenger.py |
| Module | pyro.poutine.reparam_messenger |
| Lines | 166 |
| Parent Class | Messenger |
| Purpose | Apply reparameterization strategies to transform sample sites |
| Architecture Role | Bridge between the Poutine system and the pyro.infer.reparam module |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
ReparamMessenger reparameterizes sample sites by transforming them into one or more auxiliary sample sites followed by a deterministic transformation. This follows the approach described in Gorinova, Moore, and Hoffman (2019), "Automatic Reparameterisation of Probabilistic Programs."
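The handler accepts a config that selects a Reparam instance for each sample site. Selection works either by site name (a dict) or by inspecting the site's message (a callable). When a site matches, the reparameterizer's apply() method is called to transform the site's distribution, value, and observation status. Two kinds of site are left unchanged: sites for which the config yields None, and the internal _Subsample sites created by pyro.plate.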
Key features:
- Dict or callable config -- The config can be a Dict[str, Reparam] for name-based lookup, or a Callable[[Message], Optional[Reparam]] for dynamic selection.
- ReparamHandler wrapper -- When the messenger wraps a function (__call__), it returns a ReparamHandler. The handler passes the model's arguments to reparameterizers via a side channel (_args_kwargs), which each reparameterizer sees as its args_kwargs attribute while apply() runs.
- Init messenger integration -- The handler applies enclosing InitMessenger instances to a site before reparameterizing it. As a result, the reparameterizer receives the initial value chosen by the initialization strategy.
Code Reference
ReparamMessenger
```python
class ReparamMessenger(Messenger):
    def __init__(self, config):
        super().__init__()
        # config is either a dict {site name: Reparam} or a callable msg -> Reparam or None
        assert isinstance(config, dict) or callable(config)
        self.config = config
        self._args_kwargs = None  # set by ReparamHandler while the wrapped model runs

    def __call__(self, fn):
        return ReparamHandler(self, fn)

    def _pyro_sample(self, msg):
        # Never reparameterize the internal subsample sites created by pyro.plate
        if type(msg["fn"]).__name__ == "_Subsample":
            return
        if isinstance(self.config, dict):
            reparam = self.config.get(msg["name"])
        else:
            reparam = self.config(msg)
        if reparam is None:
            return

        # Apply enclosing InitMessengers early for correct initialization
        for m in _get_init_messengers():
            m._process_message(msg)

        # Expose the model's (args, kwargs) to the reparameterizer only during apply()
        reparam.args_kwargs = self._args_kwargs
        try:
            new_msg = reparam.apply({
                "name": msg["name"],
                "fn": msg["fn"],
                "value": msg["value"],
                "is_observed": msg["is_observed"],
            })
        finally:
            reparam.args_kwargs = None

        # Write the reparameterized site back into the original message
        msg["fn"] = new_msg["fn"]
        msg["value"] = new_msg["value"]
        msg["is_observed"] = new_msg["is_observed"]
```
ReparamHandler
```python
class ReparamHandler(Generic[_P, _T]):
    """Reparameterization poutine that passes args_kwargs to reparameterizers."""

    def __init__(self, msngr, fn):
        self.msngr = msngr
        self.fn = fn

    def __call__(self, *args, **kwargs):
        # Save (args, kwargs) for optional use by reparameterizers; always cleared afterwards
        self.msngr._args_kwargs = args, kwargs
        try:
            with self.msngr:
                return self.fn(*args, **kwargs)
        finally:
            self.msngr._args_kwargs = None
```
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| config | Dict[str, Reparam] or Callable[[Message], Optional[Reparam]] | Reparameterization configuration mapping sites to reparameterizers |
| Message Effect | Description |
|---|---|
| msg["fn"] | Replaced with the reparameterized distribution |
| msg["value"] | Replaced with the reparameterized value; if this is None, the site is later sampled from the new msg["fn"] as usual |
| msg["is_observed"] | Replaced with the is_observed flag returned by the reparameterizer |
Usage Examples
Using with a Dict Config
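```python
import pyro
import pyro.distributions as dist
from pyro.infer.reparam import TransformReparam

def model():
    # TransformReparam expects the site's distribution to be a TransformedDistribution
    base = dist.Normal(0.0, 1.0)
    return pyro.sample("z", dist.TransformedDistribution(base, [dist.transforms.ExpTransform()]))

reparam_model = pyro.poutine.reparam(model, config={"z": TransformReparam()})
```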
Using with a Callable Config
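```python
import pyro
import pyro.distributions as dist
from pyro.infer.reparam import LocScaleReparam

def model():
    return pyro.sample("loc_a", dist.Normal(0.0, 2.0))

def my_config(msg):
    # Skip LocScaleReparam's own auxiliary "<name>_decentered" sites to avoid recursion
    if msg["name"].startswith("loc_") and not msg["name"].endswith("_decentered"):
        return LocScaleReparam()
    return None

reparam_model = pyro.poutine.reparam(model, config=my_config)
```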
As a Decorator (Passing Args to Reparameterizers)
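```python
import pyro
# Passing the model to poutine.reparam returns a ReparamHandler, which forwards each
# call's (args, kwargs) to the reparameterizers as their args_kwargs attribute.
# model, MyReparam (a user-defined Reparam subclass) and input_data are placeholders.
reparam_model = pyro.poutine.reparam(model, config={"z": MyReparam()})
result = reparam_model(input_data)  # MyReparam.apply() can read self.args_kwargs
```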
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Handlers -- The poutine.reparam() factory function
- Pyro_ppl_Pyro_Poutine_Runtime -- The effectful decorator used to define _get_init_messengers
- Pyro_ppl_Pyro_SubstituteMessenger -- Related: both modify site behavior