Implementation:Pyro ppl Pyro MaskMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/mask_messenger.py |
| Module | pyro.poutine.mask_messenger |
| Lines | 41 |
| Parent Class | Messenger |
| Purpose | Elementwise masking of log probabilities at sample sites |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
MaskMessenger provides elementwise masking of log probabilities at sample sites. Given a boolean mask (a True/False value or a torch.BoolTensor), it sets the mask field on all messages passing through it. Elements where the mask is True (or 1) are included in log probability computations; elements where the mask is False (or 0) are excluded.
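The following plain-Python sketch (no Pyro or PyTorch) makes the include/exclude rule concrete. It sums per-element Normal(0, 1) log densities and uses a mask to drop elements. For illustration it assumes that an excluded element contributes exactly zero to the total. normal_log_prob and masked_log_prob_sum are hypothetical helpers, and a list of bools stands in for a BoolTensor.

```python
import math
from typing import Sequence, Union


def normal_log_prob(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    # Log density of a univariate Normal(loc, scale) evaluated at x.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def masked_log_prob_sum(data: Sequence[float], mask: Union[bool, Sequence[bool]]) -> float:
    # A plain bool applies to every element; a sequence is applied elementwise.
    if isinstance(mask, bool):
        mask = [mask] * len(data)
    if len(mask) != len(data):
        raise ValueError("mask and data must have the same length")
    # Excluded elements (mask False) contribute 0 instead of their log density.
    return sum(normal_log_prob(x) if keep else 0.0 for x, keep in zip(data, mask))


data = [0.0, 1.0, 2.0]
full = masked_log_prob_sum(data, True)
partial = masked_log_prob_sum(data, [True, False, True])
nothing = masked_log_prob_sum(data, False)
```

With mask=False, the total is zero regardless of the data. A partial mask keeps only the selected terms.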
Unlike ScaleMessenger, which multiplicatively scales log probabilities, MaskMessenger performs a binary inclusion/exclusion. The mask is combined with any existing mask via bitwise AND (&).
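The combination rule can be sketched in plain Python. Lists of bools stand in for BoolTensors, an assumption made so the sketch runs without PyTorch. combine_masks, apply_masks and apply_scales are hypothetical helpers, not Pyro functions. For contrast, the sketch assumes that nested scales compose by multiplication, as the multiplicative description of ScaleMessenger implies. The point it illustrates is that nested masks AND together, while nested scales multiply.

```python
from typing import List, Optional, Union

# Lists of bools stand in for torch.BoolTensor so this runs without PyTorch.
Mask = Union[bool, List[bool]]


def combine_masks(existing: Optional[Mask], new: Mask) -> Mask:
    # Same rule as: msg["mask"] = new if msg["mask"] is None else msg["mask"] & new
    if existing is None:
        return new
    if isinstance(existing, bool) and isinstance(new, bool):
        return existing and new
    # Broadcast a plain bool against a list, then AND elementwise.
    n = len(existing) if isinstance(existing, list) else len(new)
    a = existing if isinstance(existing, list) else [existing] * n
    b = new if isinstance(new, list) else [new] * n
    return [x and y for x, y in zip(a, b)]


def apply_masks(masks: List[Mask]) -> Optional[Mask]:
    # A fresh message starts with mask None; each handler then updates it in turn.
    msg_mask: Optional[Mask] = None
    for m in masks:
        msg_mask = combine_masks(msg_mask, m)
    return msg_mask


def apply_scales(scales: List[float]) -> float:
    # For contrast: nested scale handlers are assumed to compose by multiplication.
    total = 1.0
    for s in scales:
        total *= s
    return total


print(apply_masks([[True, False, True], [True, True, False]]))  # [True, False, False]
print(apply_masks([True, False]))  # False
print(apply_scales([0.5, 4.0]))  # 2.0
```

Because AND is commutative and associative, the order in which nested handlers apply their masks does not change the result.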
This handler operates on all message types (not just sample sites), as it overrides _process_message rather than _pyro_sample.
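Toy classes in plain Python show why overriding _process_message matters. The sketch assumes that the base class's _process_message looks up a method named _pyro_<type> for each message and calls it if one is present. ToyMessenger, SampleOnlyMask, AllMessagesMask and mask_after are hypothetical names, and plain bools stand in for tensors.

```python
from typing import Any, Dict, Optional

Message = Dict[str, Any]


class ToyMessenger:
    # Simplified stand-in for the base class: route each message to a
    # _pyro_<type> method if the subclass defines one.
    def _process_message(self, msg: Message) -> None:
        method = getattr(self, "_pyro_{}".format(msg["type"]), None)
        if method is not None:
            method(msg)


class SampleOnlyMask(ToyMessenger):
    # Hypothetical variant that only hooks "sample" messages.
    def __init__(self, mask: bool) -> None:
        self.mask = mask

    def _pyro_sample(self, msg: Message) -> None:
        msg["mask"] = self.mask if msg["mask"] is None else msg["mask"] and self.mask


class AllMessagesMask(ToyMessenger):
    # Overrides _process_message itself, so no per-type dispatch happens.
    def __init__(self, mask: bool) -> None:
        self.mask = mask

    def _process_message(self, msg: Message) -> None:
        msg["mask"] = self.mask if msg["mask"] is None else msg["mask"] and self.mask


def mask_after(handler: ToyMessenger, msg_type: str) -> Optional[bool]:
    # Send one fresh message of the given type through the handler.
    msg: Message = {"type": msg_type, "mask": None}
    handler._process_message(msg)
    return msg["mask"]


print(mask_after(SampleOnlyMask(False), "param"))  # None: no _pyro_param method
print(mask_after(AllMessagesMask(False), "param"))  # False
```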
The runtime function poutine.get_mask() can be used inside models to check the current mask state and skip expensive computations when the mask is False.
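A toy handler stack can imitate a get_mask()-style query. This sketch assumes the query sends a fresh message (mask initially None) through the active handlers and returns the resulting mask. ToyMask, STACK, toy_get_mask and expensive_log_density are hypothetical names, not Pyro's implementation, and plain bools stand in for tensors.

```python
from typing import Any, Dict, List, Optional


class ToyMask:
    # Minimal stand-in for MaskMessenger with plain-bool masks.
    def __init__(self, mask: bool) -> None:
        self.mask = mask

    def __enter__(self) -> "ToyMask":
        STACK.append(self)
        return self

    def __exit__(self, *exc: Any) -> None:
        STACK.pop()

    def _process_message(self, msg: Dict[str, Any]) -> None:
        msg["mask"] = self.mask if msg["mask"] is None else msg["mask"] and self.mask


# Active handlers; the last entry is the innermost one.
STACK: List[ToyMask] = []
calls: List[str] = []


def toy_get_mask() -> Optional[bool]:
    # Query message with an assumed non-sample type; it starts with no mask.
    msg: Dict[str, Any] = {"type": "mask", "mask": None}
    for handler in reversed(STACK):
        handler._process_message(msg)
    return msg["mask"]


def expensive_log_density() -> float:
    calls.append("computed")  # record that the expensive work actually ran
    return -1.25


def model() -> Optional[float]:
    # Skip the expensive work only when the combined mask is exactly False.
    if toy_get_mask() is not False:
        return expensive_log_density()
    return None


with ToyMask(False):
    print(model())  # None, and expensive_log_density is never called
print(model())  # -1.25
```

Outside any handler the query returns None, which is not False, so the computation still runs. It is skipped only when the combined mask is False.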
Code Reference
```python
from typing import Union

import torch
from pyro.poutine.messenger import Messenger


class MaskMessenger(Messenger):
    def __init__(self, mask: Union[bool, torch.BoolTensor]) -> None:
        # Tensors must have dtype torch.bool; other values must equal True or False.
        if isinstance(mask, torch.Tensor):
            if mask.dtype != torch.bool:
                raise ValueError("Expected mask to be a BoolTensor but got {}".format(type(mask)))
        elif mask not in (True, False):
            raise ValueError("Expected mask to be a boolean but got {}".format(type(mask)))
        super().__init__()
        self.mask = mask

    def _process_message(self, msg: "Message") -> None:
        # Runs for every message type; ANDs with any mask already set on msg.
        msg["mask"] = self.mask if msg["mask"] is None else msg["mask"] & self.mask
```
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| mask | bool or torch.BoolTensor | A masking value: True/1 includes, False/0 excludes |

| Message Effect | Description |
|---|---|
| msg["mask"] | Set to mask if previously None; otherwise combined via bitwise AND (msg["mask"] & mask) |
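Validation: the constructor raises ValueError if mask is a tensor whose dtype is not torch.bool, or if a non-tensor mask is not equal to True or False. Because the check is mask not in (True, False), which compares with ==, the integers 1 and 0 also pass.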
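The non-tensor check can be reproduced in plain Python, outside Pyro. Tuple membership compares with ==, so 1, 0 and even 1.0 pass, while values such as 2, None or the string "True" are rejected. accepted_as_bool_mask is a hypothetical helper that isolates only this test. It does not reproduce the separate BoolTensor dtype check.

```python
def accepted_as_bool_mask(mask: object) -> bool:
    # The constructor's non-tensor branch rejects mask when `mask not in (True, False)`.
    # Tuple membership compares with ==, so anything equal to True or False passes.
    return mask in (True, False)


for candidate in (True, False, 1, 0, 1.0, 2, None, "True"):
    print(repr(candidate), accepted_as_bool_mask(candidate))
```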
Usage Examples
Masking Out Missing Data
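```python
import pyro
import pyro.distributions as dist

def model(data, mask):
    with pyro.plate("data", len(data)), pyro.poutine.mask(mask=mask):  # mask: BoolTensor shaped like data
        pyro.sample("obs", dist.Normal(0.0, 1.0), obs=data)
```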
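The example above assumes the caller has already prepared data and mask. A common way to do this, assumed here and sketched in plain Python with lists standing in for tensors, is to mark NaN entries as missing and overwrite them with a finite placeholder. This keeps NaN values out of the log-density computation even at masked-out positions. mask_missing is a hypothetical helper.

```python
import math
from typing import List, Tuple


def mask_missing(values: List[float], fill: float = 0.0) -> Tuple[List[float], List[bool]]:
    # True where a value was observed, False where it is missing (NaN).
    mask = [not math.isnan(v) for v in values]
    # Overwrite missing entries with a finite placeholder; the mask excludes them anyway.
    filled = [v if observed else fill for v, observed in zip(values, mask)]
    return filled, mask


data, mask = mask_missing([1.5, float("nan"), -0.5])
print(data, mask)  # [1.5, 0.0, -0.5] [True, False, True]
```

With PyTorch available, the two lists would typically be converted with torch.tensor(data) and torch.tensor(mask) (the latter has dtype torch.bool) before being passed to model.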
Boolean Mask for Prediction
```python
import pyro

# During prediction, mask=False excludes every sample site from the log density.
# model and x_new are assumed to be defined elsewhere.
with pyro.poutine.mask(mask=False):
    predictions = model(x_new)
```
Efficient Factor Computation
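```python
import pyro


def model():
    # get_mask() is False only when the combined mask of the enclosing handlers is
    # the plain boolean False, so the expensive work is skipped only in that case.
    if pyro.poutine.get_mask() is not False:
        log_density = my_expensive_computation()  # placeholder for user code
        pyro.factor("foo", log_density)
```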
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Handlers -- The poutine.mask() factory function
- Pyro_ppl_Pyro_ScaleMessenger -- Related: multiplicative scaling of log probabilities
- Pyro_ppl_Pyro_Poutine_Runtime -- get_mask() function for checking the current mask
- Pyro_ppl_Pyro_Trace_Struct -- Uses mask in scale_and_mask() for log prob computation