Implementation:Pyro ppl Pyro LiftMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/lift_messenger.py |
| Module | pyro.poutine.lift_messenger |
| Lines | 135 |
| Parent Class | Messenger |
| Purpose | Replace param statements with sample statements using specified prior distributions |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
LiftMessenger converts pyro.param calls into pyro.sample calls. Instead of returning the parameter's stored value, each lifted site draws a value from a specified prior. This turns deterministic parameters into random variables, which is useful, for example, when performing Bayesian inference over model parameters. The handler is normally applied through the poutine.lift() factory function.
The prior can be specified as:
- A Distribution instance -- applied to all param sites.
- A dict mapping param names to Distributions or callables -- applied selectively by name; params whose names are not keys of the dict are left unchanged.
- A callable -- called as a stochastic function for each param site.
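The dispatch on the type of prior can be illustrated without Pyro. The sketch below is a toy model, not Pyro's code: ToyDistribution, lift_site, and the plain-dict message are hypothetical stand-ins, and the toy matches on the raw name rather than calling params.user_param_name. The branch order (dict, then Distribution, then generic callable) follows the _pyro_param excerpt in Code Reference; the argument handling inside the dict branch is an assumption modeled on the other two branches, because the excerpt elides it.

```python
from typing import Any, Callable, Dict, Optional, Union


class ToyDistribution:
    """Hypothetical stand-in for a Pyro Distribution; like Pyro's, it is callable."""

    def __init__(self, label: str) -> None:
        self.label = label

    def __call__(self, *args: Any, **kwargs: Any) -> str:
        # calling a distribution draws a (here: fake) sample
        return f"sample from {self.label}"


Prior = Union[ToyDistribution, Dict[str, Any], Callable[..., Any]]


def lift_site(msg: Dict[str, Any], prior: Prior) -> Optional[str]:
    """Rewrite a toy param message in place; return the branch taken, or None."""
    if isinstance(prior, dict):
        if msg["name"] not in prior:
            return None  # unmatched names stay ordinary param sites
        msg["fn"] = prior[msg["name"]]
        # assumption: per-entry handling mirrors the Distribution/callable branches
        msg["args"] = msg["args"][1:]  # drop the leading name argument
        if isinstance(msg["fn"], ToyDistribution):
            msg["args"], msg["kwargs"] = (), {}
        branch = "dict"
    elif isinstance(prior, ToyDistribution):
        # must come before callable(): distributions are callable too
        msg["fn"], msg["args"], msg["kwargs"] = prior, (), {}
        branch = "distribution"
    elif callable(prior):
        msg["fn"] = prior
        msg["args"] = msg["args"][1:]  # drop the leading name argument
        msg["stop"] = True  # the site is blocked, as in the callable branch
        branch = "callable"
    else:
        return None  # unsupported prior: leave the site alone
    msg["type"] = "sample"
    return branch
```

Under these assumptions, a dict prior that names only some params leaves the others as ordinary param sites, while a bare distribution lifts every param it sees.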
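The handler also keeps a cache (_samples_cache) so that multiple pyro.param calls with the same name receive the same sampled value within a single run; the cache is reset on both __enter__ and __exit__. The first occurrence of a name becomes an ordinary sample site, while later occurrences reuse its value and are marked as observed and stopped. When validation is enabled and the prior is a dict, the handler warns on exit if any prior key did not match a param name encountered during the run.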
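The caching rule can be sketched in isolation. ToyLift, run, and the counter used as a stand-in for sampling are hypothetical; the sketch only mirrors the behavior described above. One simplification: the toy assigns a value immediately, whereas the excerpt caches the message itself, whose "value" is filled in once the first site has been sampled.

```python
import itertools
from typing import Any, Dict, List


class ToyLift:
    """Hypothetical context manager mirroring LiftMessenger's _samples_cache."""

    def __init__(self) -> None:
        self._samples_cache: Dict[str, Dict[str, Any]] = {}
        self._draws = itertools.count()  # deterministic stand-in for sampling

    def __enter__(self) -> "ToyLift":
        self._samples_cache = {}  # every run starts with an empty cache
        return self

    def __exit__(self, *exc: Any) -> None:
        self._samples_cache = {}  # and leaves nothing behind

    def lift(self, name: str) -> Dict[str, Any]:
        msg: Dict[str, Any] = {"name": name, "type": "sample", "stop": False}
        if name in self._samples_cache:
            # duplicate name: reuse the first value and block the site
            msg["value"] = self._samples_cache[name]["value"]
            msg["is_observed"] = True
            msg["stop"] = True
        else:
            msg["value"] = next(self._draws)  # "sample" a fresh value
            msg["is_observed"] = False
            self._samples_cache[name] = msg
        return msg


def run(handler: ToyLift, names: List[str]) -> List[Dict[str, Any]]:
    """Simulate one execution of a model that calls param for each name."""
    with handler:
        return [handler.lift(n) for n in names]
```

Within one run, a repeated name yields a single random variable; a second run starts fresh and draws a new value.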
Code Reference
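```python
import warnings

from pyro import params
from pyro.distributions.distribution import Distribution
from pyro.poutine.messenger import Messenger
from pyro.poutine.util import is_validation_enabled


class LiftMessenger(Messenger):
    def __init__(self, prior):
        super().__init__()
        self.prior = prior
        self._samples_cache = {}

    def __enter__(self):
        # start every run with an empty cache
        self._samples_cache = {}
        if is_validation_enabled() and isinstance(self.prior, dict):
            self._param_hits = set()
            self._param_misses = set()
        return super().__enter__()

    def __exit__(self, *args, **kwargs):
        self._samples_cache = {}
        if is_validation_enabled() and isinstance(self.prior, dict):
            # prior keys that never matched a param name during this run
            extra = set(self.prior) - self._param_hits
            if extra:
                warnings.warn("pyro.module prior did not find params ['{}']. ...".format(...))
        return super().__exit__(*args, **kwargs)

    def _pyro_param(self, msg):
        name = msg["name"]
        param_name = params.user_param_name(name)
        if isinstance(self.prior, dict):
            if param_name in self.prior.keys():
                msg["fn"] = self.prior[param_name]
                # ... adjust args/kwargs/infer
            else:
                # name not in the prior dict: leave this param site unchanged
                # (validation bookkeeping of _param_hits/_param_misses is omitted here)
                return None
        elif isinstance(self.prior, Distribution):
            # checked before callable(): Distribution instances are also callable
            msg["fn"] = self.prior
            msg["args"] = ()
            msg["kwargs"] = {}
            msg["infer"] = {}
        elif callable(self.prior):
            # prior is a stochastic function: block the site from further handlers
            msg["stop"] = True
            msg["fn"] = self.prior
            msg["args"] = msg["args"][1:]
        msg["type"] = "sample"
        if name in self._samples_cache:
            # repeated param name: reuse the first sampled value and block the site
            msg["value"] = self._samples_cache[name]["value"]
            msg["is_observed"] = True
            msg["stop"] = True
        else:
            self._samples_cache[name] = msg
            msg["is_observed"] = False
```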
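The validation check in __enter__ and __exit__ amounts to a set difference between the prior dict's keys and the param names that matched. The sketch below is a hypothetical, Pyro-free restatement: check_prior_keys and its warning text are illustrative and do not reproduce Pyro's exact message.

```python
import warnings
from typing import Dict, Iterable, Set


def check_prior_keys(prior: Dict[str, object], seen_params: Iterable[str]) -> Set[str]:
    """Warn about prior keys that never matched a param name; return them."""
    hits: Set[str] = set()
    misses: Set[str] = set()
    for name in seen_params:
        # mirrors the _param_hits / _param_misses bookkeeping
        (hits if name in prior else misses).add(name)
    extra = set(prior) - hits
    if extra:
        warnings.warn(
            "prior did not find params {}; params seen but not lifted: {}".format(
                sorted(extra), sorted(misses)
            )
        )
    return extra
```

Reporting the unmatched params next to the unused keys makes a typo (for example, a key "scale" when the model calls pyro.param("loc", ...)) easy to spot.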
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| prior | Distribution, Dict[str, Distribution or Callable], or Callable | Prior distribution(s) to replace param values |
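| Message Effect | Description |
|---|---|
| msg["type"] | Changed from "param" to "sample" for every lifted site |
| msg["fn"] | Replaced with the prior distribution or callable |
| msg["args"], msg["kwargs"] | Cleared for a global Distribution prior; for a callable prior, the first positional argument is dropped from msg["args"] |
| msg["is_observed"] | Set to False (first occurrence) or True (cached duplicate) |
| msg["stop"] | Set to True for callable priors and for cached duplicates |
| msg["value"] | Set to the cached value for duplicate param names |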
Usage Examples
Lifting Parameters with Named Priors
```python
import torch
import pyro
import pyro.distributions as dist


def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2


lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})
tr = pyro.poutine.trace(lifted_model).get_trace(0.0)
assert tr.nodes["s"]["type"] == "sample"  # param is now a sample site
```
Lifting with a Global Prior
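```python
# Apply the same prior to all parameters (reuses model from the previous example).
# A positive-support prior is used because s serves as the Normal scale in model.
lifted_model = pyro.poutine.lift(model, prior=dist.LogNormal(0.0, 1.0))
```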
Lifting with a Callable
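```python
# Use a stochastic function as the prior. It is called with the param's remaining
# positional arguments and its keyword arguments, so it accepts *args and **kwargs.
lifted_model = pyro.poutine.lift(
    model, prior=lambda *args, **kwargs: dist.LogNormal(0.0, 1.0).sample()
)
```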
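Because the callable branch keeps msg["args"][1:] and leaves msg["kwargs"] untouched, a callable prior should accept arbitrary positional and keyword arguments. The sketch below simulates that call in plain Python. The message contents (the param name followed by an initial value, plus a constraint keyword) are an assumption about what pyro.param forwards, and call_lifted_fn, flexible_prior, and strict_prior are hypothetical helpers.

```python
import random
from typing import Any, Callable, Dict


def call_lifted_fn(msg: Dict[str, Any], prior: Callable[..., Any]) -> Any:
    """Apply the callable-prior rewrite, then invoke fn the way a sample site would."""
    msg["fn"] = prior
    msg["args"] = msg["args"][1:]  # the leading name is dropped
    return msg["fn"](*msg["args"], **msg["kwargs"])


def flexible_prior(*args: Any, **kwargs: Any) -> float:
    # ignores the forwarded initial value and options; draws a positive value
    return random.lognormvariate(0.0, 1.0)


def strict_prior(name: Any, *args: Any) -> float:
    # same shape as `lambda name, *args: ...`: no **kwargs, and `name`
    # actually receives the initial value rather than the param name
    return 1.0
```

Under these assumptions, strict_prior fails with a TypeError on the forwarded keyword argument, while flexible_prior works.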
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Handlers -- The poutine.lift() factory function
- Pyro_ppl_Pyro_SubstituteMessenger -- Related: substitutes param values (without making them random)
- Pyro_ppl_Pyro_ConditionMessenger -- Related: conditions sample sites on fixed values