Implementation:Pyro ppl Pyro EqualizeMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/equalize_messenger.py |
| Module | pyro.poutine.equalize_messenger |
| Lines | 105 |
| Parent Class | Messenger |
| Purpose | Force matching primitive sites to share the same value |
| License | Apache-2.0 (Pyro project contributors) |
Overview
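EqualizeMessenger forces multiple primitive sites to share the same value. During each run of the wrapped model it captures the value of the first matching site and replays that value at every subsequent matching site; the captured value is reset each time the handler is entered. Site names are matched against regular expressions with `re.fullmatch`, so a pattern must match the entire site name.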
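The full-match rule can be reproduced with the standard `re` module. The helper `matches_any` below is hypothetical: it only mirrors the test performed in `_is_matching` and is not part of Pyro.

```python
import re
from typing import List


def matches_any(name: str, patterns: List[str]) -> bool:
    """Hypothetical helper mirroring the fullmatch test in EqualizeMessenger._is_matching."""
    return any(re.compile(p).fullmatch(name) is not None for p in patterns)


# '.+_std' must cover the entire name, so a trailing suffix breaks the match.
print(matches_any("a_std", [".+_std"]))      # True
print(matches_any("a_std_raw", [".+_std"]))  # False
# A plain site name is still a regex: 'x' matches only 'x', not 'x2'.
print(matches_any("x2", ["x", "y"]))         # False
# Metacharacters are live: '.' matches any single character.
print(matches_any("a_b", ["a.b"]))           # True
```

Because of this, a literal site name that contains regex metacharacters may need `re.escape` before being passed as a pattern.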
The handler supports two modes for subsequent matching sample sites:
- Default mode (`keep_dist=False`): each subsequent matching site is replaced with a masked `Delta` distribution at the captured value and marked deterministic. Its log probability therefore does not contribute to the model's log joint.
- Conditioning mode (`keep_dist=True`): each subsequent matching site keeps its original distribution but is observed at the captured value. This is equivalent to conditioning the model on all matching sites having the same value, and gives the correct unnormalized log probability density.
The handler matches either sample or param sites, selected by the `type` parameter (default `"sample"`). For param sites only the value is replaced; `keep_dist` has no effect on them.
Code Reference
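```python
import re
from typing import List, Optional, Union

from pyro.distributions import Delta
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message


class EqualizeMessenger(Messenger):
    def __init__(
        self,
        sites: Union[str, List[str]],
        type: Optional[str] = "sample",
        keep_dist: Optional[bool] = False,
    ) -> None:
        super().__init__()
        # Accept a single pattern or a list of patterns.
        self.sites = [sites] if isinstance(sites, str) else sites
        self.type = type
        self.keep_dist = keep_dist

    def __enter__(self):
        # Forget any value captured during a previous run.
        self.value = None
        return super().__enter__()

    def _is_matching(self, msg: Message) -> bool:
        # A site matches if its type agrees and any pattern matches the full name.
        if msg["type"] == self.type:
            for site in self.sites:
                if re.compile(site).fullmatch(msg["name"]) is not None:
                    return True
        return False

    def _postprocess_message(self, msg: Message) -> None:
        # After the first matching site has produced its value, capture it.
        if self.value is None and self._is_matching(msg):
            self.value = msg["value"]

    def _process_message(self, msg: Message) -> None:
        # Before a later matching site runs, replay the captured value.
        if self.value is not None and self._is_matching(msg):
            msg["value"] = self.value
            if msg["type"] == "sample":
                msg["is_observed"] = True
                if not self.keep_dist:
                    # Swap in a masked Delta so the site adds nothing to the log joint.
                    msg["infer"] = {"_deterministic": True}
                    msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim).mask(False)
```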
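The capture-then-replay scheme relies on the order of the two hooks: `_postprocess_message` sees the first matching site after its value exists, and `_process_message` rewrites later matching sites before any value is drawn. The pure-Python toy below imitates that flow on plain dict messages. `ToyEqualize` and `toy_run` are hypothetical names; this is a simplification of Pyro's effect-handler stack, not its implementation, and it omits the `Delta`, `infer`, and `is_observed` handling.

```python
import re
from typing import Any, Callable, Dict, List, Optional


class ToyEqualize:
    """Hypothetical stand-in for EqualizeMessenger that works on plain dict messages."""

    def __init__(self, sites: List[str]) -> None:
        self.sites = sites
        self.value: Optional[Any] = None

    def _is_matching(self, msg: Dict[str, Any]) -> bool:
        return any(re.fullmatch(p, msg["name"]) for p in self.sites)

    def process(self, msg: Dict[str, Any]) -> None:
        # Runs before a value is drawn: later matches receive the captured value.
        if self.value is not None and self._is_matching(msg):
            msg["value"] = self.value

    def postprocess(self, msg: Dict[str, Any]) -> None:
        # Runs after a value exists: the first match's value is captured.
        if self.value is None and self._is_matching(msg):
            self.value = msg["value"]


def toy_run(handler: ToyEqualize, names: List[str], draw: Callable[[], Any]) -> Dict[str, Any]:
    """Hypothetical driver: process, draw only if no value was supplied, then postprocess."""
    values = {}
    for name in names:
        msg = {"name": name, "value": None}
        handler.process(msg)
        if msg["value"] is None:
            msg["value"] = draw()
        handler.postprocess(msg)
        values[name] = msg["value"]
    return values


counter = iter(range(100))
print(toy_run(ToyEqualize(["x", "y"]), ["x", "y", "z"], lambda: next(counter)))
# {'x': 0, 'y': 0, 'z': 1}
```

Only two draws happen for three sites: "y" never reaches the sampler because its value is filled in before the draw step, which is also why the real handler can skip sampling at subsequent matching sites.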
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| sites | str or List[str] | Site names or regex patterns to match |
| type | str (default "sample") | The site type to match ("sample" or "param") |
| keep_dist | bool (default False) | If True, keep original distributions (conditioning mode) |
| Message Effect | Description |
|---|---|
| First matching site | Value is captured in _postprocess_message; the message itself is not modified |
| Subsequent matching sites | msg["value"] is set to the captured value. For sample sites, msg["is_observed"] = True; if keep_dist=False, msg["infer"] = {"_deterministic": True} and msg["fn"] is replaced with a masked Delta |
Usage Examples
Equalizing Sample Sites Across Categories
```python
import torch
import pyro
import pyro.distributions as dist


def per_category_model(category):
    shift = pyro.param(f'{category}_shift', torch.randn(1))
    mean = pyro.sample(f'{category}_mean', dist.Normal(0, 1))
    std = pyro.sample(f'{category}_std', dist.LogNormal(0, 1))
    return pyro.sample(f'{category}_values', dist.Normal(mean + shift, std))


def model(categories):
    return {cat: per_category_model(cat) for cat in categories}


# Make all *_std sites share the same value
equal_std_model = pyro.poutine.equalize(model, '.+_std')
# Also equalize the *_shift parameters
equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param')
```
Conditioning on Equal Values
```python
import pyro
import pyro.distributions as dist


def model():
    x = pyro.sample('x', dist.Normal(0, 1))
    y = pyro.sample('y', dist.Normal(5, 3))
    return x, y


# Condition on x == y (correct unnormalized log density)
conditioned_model = pyro.poutine.equalize(model, ['x', 'y'], keep_dist=True)
```
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Handlers -- The poutine.equalize() factory function
- Pyro_ppl_Pyro_ConditionMessenger -- Related: conditioning on observed values
- Pyro_ppl_Pyro_ReplayMessenger -- Related: replaying values from traces