

Implementation:Pyro ppl Pyro ReparamMessenger

From Leeroopedia


Attribute | Value
File | pyro/poutine/reparam_messenger.py
Module | pyro.poutine.reparam_messenger
Lines | 166
Parent class | Messenger
Purpose | Apply reparameterization strategies to transform sample sites
Architecture role | Bridge between the Poutine system and the pyro.infer.reparam module
License | Apache-2.0 (Uber Technologies, Inc.)

Overview

ReparamMessenger reparameterizes sample sites by transforming them into one or more auxiliary sample sites followed by a deterministic transformation. This follows the approach described in Gorinova, Moore, and Hoffman (2019), "Automatic Reparameterisation of Probabilistic Programs."
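As a concrete instance, the fully non-centered loc-scale case (what LocScaleReparam does with centered=0) replaces a single Normal site with a standard-normal auxiliary site followed by a deterministic shift and scale. Here $\mu$ is the location and $\sigma$ the standard deviation:

```latex
x \sim \mathcal{N}(\mu, \sigma^2)
\quad\Longleftrightarrow\quad
z \sim \mathcal{N}(0, 1), \qquad x = \mu + \sigma z
```

The marginal distribution of $x$ is unchanged. What changes is the space the inference algorithm explores, which can remove funnel-shaped geometry in hierarchical models.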

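The handler accepts a config that selects a Reparam instance for each sample site, either by site name or through a callable. When a site matches, the reparameterizer's apply() method receives the site's name, distribution, value, and observation status, and returns a replacement distribution, value, and observation status that are written back into the message.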

Key features:

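  • Dict or callable config -- The config can be a Dict[str, Reparam] for name-based lookup or a Callable[[Message], Optional[Reparam]] for dynamic selection. A lookup that yields None leaves the site unchanged.
  • ReparamHandler wrapper -- Calling the messenger on a function (__call__), as happens when it wraps a model or is used as a decorator, returns a ReparamHandler. For the duration of each call, the handler stores the model's positional and keyword arguments in a side channel (_args_kwargs), which reparameterizers read as their args_kwargs attribute.
  • Init messenger integration -- Enclosing InitMessenger instances would normally see a site only after this handler has rewritten it, so the handler runs them on the message before calling apply(). This lets reparameterized sites work with initialization strategies such as init_to_value.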
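The side channel is a set-then-clear pattern guarded by try/finally, applied at two levels: the handler sets the messenger's slot for one model call, and the messenger copies it onto the reparameterizer for one apply() call. The sketch below mirrors that in plain Python. ToyReparam, ToyMessenger, and wrap are hypothetical names, and the model calls process() directly instead of going through pyro.sample and the effect-handler stack.

```python
from typing import Any, Callable, Optional, Tuple


class ToyReparam:
    """Hypothetical reparameterizer that records the model args it saw."""

    def __init__(self) -> None:
        self.args_kwargs: Optional[Tuple[tuple, dict]] = None
        self.seen: Optional[Tuple[tuple, dict]] = None

    def apply(self, msg: dict) -> dict:
        # Read the side channel while it is populated.
        self.seen = self.args_kwargs
        return msg


class ToyMessenger:
    """Hypothetical messenger that owns the side-channel slot."""

    def __init__(self, reparam: ToyReparam) -> None:
        self.reparam = reparam
        self._args_kwargs: Optional[Tuple[tuple, dict]] = None

    def process(self, msg: dict) -> dict:
        # Hand the current model arguments to the reparameterizer, then clear them.
        self.reparam.args_kwargs = self._args_kwargs
        try:
            return self.reparam.apply(msg)
        finally:
            self.reparam.args_kwargs = None


def wrap(msngr: ToyMessenger, fn: Callable[..., Any]) -> Callable[..., Any]:
    """Mimic ReparamHandler.__call__: stash (args, kwargs) for one call only."""

    def handler(*args: Any, **kwargs: Any) -> Any:
        msngr._args_kwargs = args, kwargs
        try:
            return fn(*args, **kwargs)
        finally:
            # Cleared even if the model raises, so stale args never leak.
            msngr._args_kwargs = None

    return handler
```

Clearing in finally means a reparameterizer invoked outside a wrapped call, for example when the messenger is used as a plain context manager, sees args_kwargs as None rather than arguments left over from an earlier call.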

Code Reference

ReparamMessenger

class ReparamMessenger(Messenger):
    def __init__(self, config):
        super().__init__()
        assert isinstance(config, dict) or callable(config)
        self.config = config
        self._args_kwargs = None

    def __call__(self, fn):
        return ReparamHandler(self, fn)

    def _pyro_sample(self, msg):
        if type(msg["fn"]).__name__ == "_Subsample":
            return
        if isinstance(self.config, dict):
            reparam = self.config.get(msg["name"])
        else:
            reparam = self.config(msg)
        if reparam is None:
            return

        # Apply enclosing InitMessengers early for correct initialization
        for m in _get_init_messengers():
            m._process_message(msg)

        reparam.args_kwargs = self._args_kwargs
        try:
            new_msg = reparam.apply({
                "name": msg["name"],
                "fn": msg["fn"],
                "value": msg["value"],
                "is_observed": msg["is_observed"],
            })
        finally:
            reparam.args_kwargs = None

        msg["fn"] = new_msg["fn"]
        msg["value"] = new_msg["value"]
        msg["is_observed"] = new_msg["is_observed"]

ReparamHandler

class ReparamHandler(Generic[_P, _T]):
    """Reparameterization poutine that passes args_kwargs to reparameterizers."""

    def __init__(self, msngr, fn):
        self.msngr = msngr
        self.fn = fn

    def __call__(self, *args, **kwargs):
        self.msngr._args_kwargs = args, kwargs
        try:
            with self.msngr:
                return self.fn(*args, **kwargs)
        finally:
            self.msngr._args_kwargs = None

I/O Contract

Parameter | Type | Description
config | Dict[str, Reparam] or Callable[[Message], Optional[Reparam]] | Reparameterization configuration mapping sites to reparameterizers

Message field | Effect on each matched sample site
msg["fn"] | Replaced with the distribution returned by the reparameterizer
msg["value"] | Replaced with the value returned by the reparameterizer (may be None)
msg["is_observed"] | Set to the observation status returned by the reparameterizer

Usage Examples

Using with a Dict Config

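import pyro
from pyro.infer.reparam import TransformReparam

# `model` is an existing Pyro model whose site "z" uses a TransformedDistribution;
# TransformReparam splits it into a base sample site plus a deterministic transform.
reparam_model = pyro.poutine.reparam(
    model,
    config={"z": TransformReparam()},
)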

Using with a Callable Config

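import pyro
from pyro.infer.reparam import LocScaleReparam

def my_config(msg):
    # Reparameterize every sample site whose name starts with "loc_";
    # returning None leaves all other sites unchanged.
    if msg["name"].startswith("loc_"):
        return LocScaleReparam()
    return None

reparam_model = pyro.poutine.reparam(model, config=my_config)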

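As a Function Wrapper or Decorator (Passing Args to Reparameterizers)

import pyro

# Wrapping a model (or decorating it) returns a ReparamHandler, which stores each
# call's (args, kwargs) so reparameterizers can read them via args_kwargs.
# MyReparam is a placeholder for a user-defined Reparam subclass.
reparam_model = pyro.poutine.reparam(model, config={"z": MyReparam()})
result = reparam_model(input_data)  # MyReparam sees args_kwargs == ((input_data,), {})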
