

Implementation:Pyro ppl Pyro MaskMessenger

From Leeroopedia


File: pyro/poutine/mask_messenger.py
Module: pyro.poutine.mask_messenger
Length: 41 lines
Parent class: Messenger
Purpose: Elementwise masking of log probabilities at sample sites
License: Apache-2.0 (Uber Technologies, Inc.)

Overview

MaskMessenger provides elementwise masking of log probabilities at sample sites. Given a boolean mask (a Python bool or a torch.BoolTensor), it sets the "mask" field of every message that passes through it. Elements where the mask is True are included in log-probability computations; elements where it is False are excluded, meaning they contribute zero to the log joint density.

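Unlike ScaleMessenger, which multiplies log probabilities by a scale factor, MaskMessenger performs binary inclusion or exclusion. If a message already carries a mask (for example, from an inner mask handler), the two masks are combined with bitwise AND (&), so an element is included only if every enclosing mask includes it.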

This handler acts on every message type, not just sample sites, because it overrides _process_message rather than _pyro_sample. The mask itself only affects log-probability computation, which happens at sample sites.

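The runtime function poutine.get_mask() can be used inside a model to read the mask contributed by enclosing mask handlers and to skip expensive computations when that mask is False. It returns None when no mask handler is active. It depends on MaskMessenger processing all message types, since it works by sending an internal, non-sample inspection message through the handler stack and reading back the accumulated mask.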

Code Reference

from typing import Union

import torch
from pyro.poutine.messenger import Messenger

class MaskMessenger(Messenger):
    def __init__(self, mask: Union[bool, torch.BoolTensor]) -> None:
        if isinstance(mask, torch.Tensor):
            if mask.dtype != torch.bool:
                raise ValueError("Expected mask to be a BoolTensor but got {}".format(type(mask)))
        elif mask not in (True, False):
            raise ValueError("Expected mask to be a boolean but got {}".format(type(mask)))
        super().__init__()
        self.mask = mask

    def _process_message(self, msg: "Message") -> None:
        msg["mask"] = self.mask if msg["mask"] is None else msg["mask"] & self.mask

I/O Contract

Parameters
mask (bool or torch.BoolTensor): the masking value. True includes an element and False excludes it; a tensor must have dtype torch.bool.

Message effects
msg["mask"]: set to mask if it was previously None; otherwise combined via bitwise AND (msg["mask"] & mask).

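Validation: the constructor raises ValueError if mask is a tensor whose dtype is not torch.bool (so a {0, 1} integer tensor is rejected), or if mask is not a tensor and is neither True nor False. Because the scalar check is `mask not in (True, False)`, the Python integers 0 and 1 also pass. The error message for a wrong-dtype tensor reports type(mask), i.e. torch.Tensor, rather than the offending dtype.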

Usage Examples

Masking Out Missing Data

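import pyro
import pyro.distributions as dist

def model(data, mask):  # mask: BoolTensor, True where data is observed
    with pyro.plate("data", len(data)), pyro.poutine.mask(mask=mask):
        pyro.sample("obs", dist.Normal(0.0, 1.0), obs=data)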

Boolean Mask for Prediction

# mask=False zeroes the log-prob contribution of every enclosed site and lets
# code guarded by poutine.get_mask() skip work; model and x_new are placeholders
with pyro.poutine.mask(mask=False):
    predictions = model(x_new)

Efficient Factor Computation

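import pyro

def model():
    # get_mask() returns None (no enclosing mask), a bool, or a BoolTensor,
    # so compare with `is not False` instead of testing truthiness
    if pyro.poutine.get_mask() is not False:
        log_density = my_expensive_computation()  # user-defined placeholder
        pyro.factor("foo", log_density)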
