

Implementation:Pyro ppl Pyro EqualizeMessenger

From Leeroopedia
Revision as of 16:23, 16 February 2026 by Admin (Auto-imported from implementations/Pyro_ppl_Pyro_EqualizeMessenger.md)


  • File: pyro/poutine/equalize_messenger.py
  • Module: pyro.poutine.equalize_messenger
  • Lines: 105
  • Parent class: Messenger
  • Purpose: force matching primitive sites to share the same value
  • License: Apache-2.0 (Pyro project contributors)

Overview

EqualizeMessenger forces multiple primitive sites to share the same value. It captures the value produced at the first matching site and replays that value at every subsequent matching site. Each entry in sites is compiled as a regular expression, and a site matches only if its whole name matches the pattern (re.fullmatch), not merely a substring of it.
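
Because the test uses re.fullmatch rather than re.search, a pattern has to cover the entire site name. The stand-alone sketch below reproduces only that name test with the standard re module; matches is a hypothetical helper (not part of Pyro) and it deliberately ignores the site-type check that _is_matching also performs.

```python
import re
from typing import List, Union


def matches(patterns: Union[str, List[str]], name: str) -> bool:
    """Hypothetical helper mirroring the name test in EqualizeMessenger._is_matching."""
    # A single string is treated as a one-element list, as in __init__.
    patterns = [patterns] if isinstance(patterns, str) else patterns
    # fullmatch succeeds only if the pattern consumes the whole site name.
    return any(re.compile(p).fullmatch(name) is not None for p in patterns)


print(matches(".+_std", "a_std"))        # True
print(matches(".+_std", "a_std_scale"))  # False: extra trailing characters
print(matches("_std", "a_std"))          # False, although re.search would succeed
print(matches(["x", "y"], "y"))          # True
```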

The handler supports two modes:

  • Default mode (keep_dist=False): at each subsequent matching sample site, the distribution is replaced with a masked Delta at the captured value and the site is flagged as deterministic. Because the Delta is masked, the site's log probability does not contribute to the model's log joint.
  • Conditioning mode (keep_dist=True): Subsequent matching sites keep their original distributions but are observed at the captured value. This is equivalent to conditioning the model on all matching sites having the same value, giving the correct unnormalized log probability density.

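The type parameter selects which kind of site one handler matches: "sample" (the default) or "param". For param sites only the value is replaced; setting is_observed and swapping in the masked Delta apply to sample sites only. To equalize both kinds of site, apply two handlers.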

Code Reference

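import re
from typing import List, Optional, Union

from pyro.distributions import Delta
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message


class EqualizeMessenger(Messenger):
    def __init__(
        self,
        sites: Union[str, List[str]],
        type: Optional[str] = "sample",
        keep_dist: Optional[bool] = False,
    ) -> None:
        super().__init__()
        # Accept a single pattern or a list of patterns.
        self.sites = [sites] if isinstance(sites, str) else sites
        self.type = type
        self.keep_dist = keep_dist

    def __enter__(self):
        # Forget any value captured during a previous run.
        self.value = None
        return super().__enter__()

    def _is_matching(self, msg: Message) -> bool:
        # A site matches if it has the requested type and its name
        # fully matches at least one of the patterns.
        if msg["type"] == self.type:
            for site in self.sites:
                if re.compile(site).fullmatch(msg["name"]) is not None:
                    return True
        return False

    def _postprocess_message(self, msg: Message) -> None:
        # Capture the value produced at the first matching site.
        if self.value is None and self._is_matching(msg):
            self.value = msg["value"]

    def _process_message(self, msg: Message) -> None:
        # Replay the captured value at every later matching site.
        if self.value is not None and self._is_matching(msg):
            msg["value"] = self.value
            if msg["type"] == "sample":
                # Treat the sample site as observed at the shared value.
                msg["is_observed"] = True
                if not self.keep_dist:
                    # Default mode: swap in a masked Delta so the site
                    # adds nothing to the log joint.
                    msg["infer"] = {"_deterministic": True}
                    msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim).mask(False)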

I/O Contract

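Parameters
  • sites (str or List[str]): a site name or regular-expression pattern, or a list of them; a site matches if its name fully matches any pattern.
  • type (str, default "sample"): the site type to match, either "sample" or "param".
  • keep_dist (bool, default False): if True, keep the original distributions at the replayed sites (conditioning mode).

Message effects
  • First matching site: its value is captured in _postprocess_message; the message itself is not modified.
  • Each subsequent matching site: msg["value"] is set to the captured value. For sample sites, msg["is_observed"] is also set to True, and when keep_dist=False, msg["fn"] is replaced with a masked Delta at that value and msg["infer"] is set to {"_deterministic": True}.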

Usage Examples

Equalizing Sample Sites Across Categories

import torch
import pyro
import pyro.distributions as dist

def per_category_model(category):
    shift = pyro.param(f'{category}_shift', torch.randn(1))
    mean = pyro.sample(f'{category}_mean', dist.Normal(0, 1))
    std = pyro.sample(f'{category}_std', dist.LogNormal(0, 1))
    return pyro.sample(f'{category}_values', dist.Normal(mean + shift, std))

def model(categories):
    return {cat: per_category_model(cat) for cat in categories}

# Make all *_std sites share the same value
equal_std_model = pyro.poutine.equalize(model, '.+_std')

# Also equalize the *_shift parameters
equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param')

Conditioning on Equal Values

import pyro
import pyro.distributions as dist

def model():
    x = pyro.sample('x', dist.Normal(0, 1))
    y = pyro.sample('y', dist.Normal(5, 3))
    return x, y

# Condition on x == y (correct unnormalized log density)
conditioned_model = pyro.poutine.equalize(model, ['x', 'y'], keep_dist=True)
