Implementation:Pyro ppl Pyro LiftMessenger

From Leeroopedia


| Attribute | Value |
|---|---|
| File | pyro/poutine/lift_messenger.py |
| Module | pyro.poutine.lift_messenger |
| Lines | 135 |
| Parent Class | Messenger |
| Purpose | Replace param statements with sample statements using specified prior distributions |
| License | Apache-2.0 (Uber Technologies, Inc.) |

Overview

LiftMessenger converts pyro.param calls into pyro.sample calls by replacing the parameter's fixed value with a sample from a specified prior distribution. This is useful for turning deterministic parameters into random variables -- for example, when performing Bayesian inference over model parameters.

The prior can be specified as:

  • A Distribution instance -- applied to all param sites.
  • A dict mapping param names to Distributions or callables -- applied selectively by name; params whose names are not in the dict are left as ordinary params.
  • A callable -- called as a stochastic function for each param site.

The handler also implements a caching mechanism (_samples_cache) to ensure that multiple pyro.param calls with the same name receive the same sampled value within a single execution. When validation is enabled, it warns if prior dict keys do not match any param names encountered.

Code Reference

class LiftMessenger(Messenger):
    def __init__(self, prior):
        super().__init__()
        self.prior = prior
        self._samples_cache = {}

    def __enter__(self):
        self._samples_cache = {}
        if is_validation_enabled() and isinstance(self.prior, dict):
            self._param_hits = set()
            self._param_misses = set()
        return super().__enter__()

    def __exit__(self, *args, **kwargs):
        self._samples_cache = {}
        if is_validation_enabled() and isinstance(self.prior, dict):
            extra = set(self.prior) - self._param_hits
            if extra:
                warnings.warn("pyro.module prior did not find params ['{}']. ...".format(...))
        return super().__exit__(*args, **kwargs)

    def _pyro_param(self, msg):
        name = msg["name"]
        param_name = params.user_param_name(name)

        if isinstance(self.prior, dict):
            if param_name in self.prior.keys():
                msg["fn"] = self.prior[param_name]
                # ... adjust args/kwargs/infer
            else:
                return None  # no prior for this name: leave the site as a param
        elif isinstance(self.prior, Distribution):
            msg["fn"] = self.prior
            msg["args"] = ()
            msg["kwargs"] = {}
            msg["infer"] = {}
        elif callable(self.prior):
            msg["stop"] = True
            msg["fn"] = self.prior
            msg["args"] = msg["args"][1:]

        msg["type"] = "sample"
        if name in self._samples_cache:
            msg["value"] = self._samples_cache[name]["value"]
            msg["is_observed"] = True
            msg["stop"] = True
        else:
            self._samples_cache[name] = msg
            msg["is_observed"] = False
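The validation check in `__exit__` reduces to a set difference between the prior's keys and the param names that matched. The following is a pure-Python toy model of that bookkeeping, not Pyro's code: it assumes the elided dict branch records matched names in `_param_hits` and unmatched names in `_param_misses`, and the helper name `unmatched_prior_keys` is hypothetical.

```python
def unmatched_prior_keys(prior, encountered_param_names):
    """Toy model of the hit/miss bookkeeping; not Pyro's actual code."""
    param_hits, param_misses = set(), set()
    for param_name in encountered_param_names:
        if param_name in prior:
            param_hits.add(param_name)
        else:
            param_misses.add(param_name)
    # Same set difference as in __exit__: prior keys that never matched.
    extra = set(prior) - param_hits
    return extra, param_misses


# "scale" is a typo for the real param name "s", so it never matches.
extra, misses = unmatched_prior_keys(
    prior={"loc": "prior for loc", "scale": "prior for scale"},
    encountered_param_names=["loc", "s"],
)
assert extra == {"scale"}
assert misses == {"s"}
```

A non-empty `extra` set is what triggers the warning, which typically points to a misspelled key in the prior dict.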

I/O Contract

| Parameter | Type | Description |
|---|---|---|
| prior | Distribution, Dict[str, Distribution or Callable], or Callable | Prior distribution(s) to replace param values |

| Message field | Effect |
|---|---|
| msg["type"] | Changed from "param" to "sample" |
| msg["fn"] | Replaced with the prior distribution or callable |
| msg["is_observed"] | Set to False (first occurrence) or True (cached duplicate) |
| msg["value"] | Set to cached value for duplicate param names |

Usage Examples

Lifting Parameters with Named Priors

import torch
import pyro
import pyro.distributions as dist

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

# Only the param named "s" is lifted; its value is drawn from the Exponential prior
lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})

tr = pyro.poutine.trace(lifted_model).get_trace(0.0)
assert tr.nodes["s"]["type"] == "sample"  # param is now a sample site

Lifting with a Global Prior

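# Apply the same prior to all parameters. In the model above "s" is a scale,
# so a positive-support prior is used to keep dist.Normal(x, s) valid.
lifted_model = pyro.poutine.lift(model, prior=dist.LogNormal(0.0, 1.0))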

Lifting with a Callable

# Use a stochastic function as prior; it receives the param's arguments minus the name
lifted_model = pyro.poutine.lift(
    model, prior=lambda *args, **kwargs: dist.LogNormal(0.0, 1.0).sample()
)
