Implementation:Pyro ppl Pyro UnconditionMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/uncondition_messenger.py |
| Module | pyro.poutine.uncondition_messenger |
| Lines | 35 |
| Parent Class | Messenger |
| Purpose | Remove observations from sample sites, forcing them to be sampled from their distributions |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
UnconditionMessenger is the inverse of ConditionMessenger. It forces observed sample sites to be treated as unobserved, causing them to be sampled from their distributions rather than using the observed values.
When an observed site is encountered, the handler rewrites its message as follows:
- `msg["is_observed"]` is set to `False`.
- The original observed value is saved in `msg["infer"]["obs"]`.
- `msg["infer"]["was_observed"]` is set to `True` for downstream use.
- `msg["value"]` is set to `None` so that a fresh sample will be drawn.
- `msg["done"]` is set to `False` to allow sampling.
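To make these field updates concrete, here is a minimal sketch that replays them on plain Python dictionaries. It is a simplified mock, not Pyro's API: each message carries only the fields the handler touches, `uncondition_sample` copies the rewrite performed by `_pyro_sample`, and `default_sample` is a hypothetical stand-in for the fallback step that draws a value when no handler has fixed one (loosely modeled on Pyro's default message processing). The distribution is assumed to be a zero-argument callable wrapping `random.gauss`.

```python
import random


def uncondition_sample(msg: dict) -> None:
    # Mirrors UnconditionMessenger._pyro_sample: only observed sites change
    if msg["is_observed"]:
        msg["is_observed"] = False
        msg["infer"]["was_observed"] = True
        msg["infer"]["obs"] = msg["value"]  # keep the original observation
        msg["value"] = None  # clear it so a fresh value can be drawn
        msg["done"] = False


def default_sample(msg: dict) -> None:
    # Hypothetical fallback: draw from fn only if nothing has fixed the value
    if not (msg["done"] or msg["is_observed"] or msg["value"] is not None):
        msg["value"] = msg["fn"]()
    msg["done"] = True


def make_msg(value, is_observed: bool) -> dict:
    # Minimal message holding only the fields used in this sketch
    return {
        "fn": lambda: random.gauss(0.0, 1.0),
        "value": value,
        "is_observed": is_observed,
        "done": False,
        "infer": {},
    }


observed = make_msg(2.5, is_observed=True)
latent = make_msg(None, is_observed=False)
for msg in (observed, latent):
    uncondition_sample(msg)  # the handler runs first
    default_sample(msg)      # then the fallback sampling step

print(observed["infer"])        # {'was_observed': True, 'obs': 2.5}
print(observed["is_observed"])  # False
print(latent["infer"])          # {}
```

The observed site ends up with a freshly drawn value while its original observation survives in `infer`; the latent site is left alone by the handler and is sampled as usual.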
This handler is useful for posterior predictive sampling, Reweighted Wake Sleep, and Compiled Sequential Importance Sampling, where the model needs to generate samples at sites that are normally observed.
Code Reference
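```python
from typing import TYPE_CHECKING

from pyro.poutine.messenger import Messenger

if TYPE_CHECKING:
    from pyro.poutine.runtime import Message


class UnconditionMessenger(Messenger):
    def __init__(self) -> None:
        super().__init__()

    def _pyro_sample(self, msg: "Message") -> None:
        # Only observed sites are rewritten; latent sites pass through unchanged
        if msg["is_observed"]:
            msg["is_observed"] = False
            msg["infer"]["was_observed"] = True
            msg["infer"]["obs"] = msg["value"]  # keep the original observation
            msg["value"] = None  # cleared so a fresh value is sampled
            msg["done"] = False
```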
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| (none) | n/a | This handler takes no parameters |
| Message field | Effect (observed sites only) |
|---|---|
| `msg["is_observed"]` | Set to `False` |
| `msg["infer"]["was_observed"]` | Set to `True` to record that the site was originally observed |
| `msg["infer"]["obs"]` | Stores the original observed value |
| `msg["value"]` | Set to `None` to trigger fresh sampling |
| `msg["done"]` | Set to `False` to allow sampling |
Sites that are not observed are unaffected.
Usage Examples
Posterior Predictive Sampling
```python
import torch
import pyro
import pyro.distributions as dist

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

# Generate predictions by unconditioning the observations
predictive_model = pyro.poutine.uncondition(model)
# Now "obs" is sampled from Normal(loc, 1) instead of being fixed to data
predictive_model(torch.tensor([0.5, 1.2, 0.9]))
```
Combining with Condition
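```python
import torch
import pyro

# uncondition is the outer handler, so it sees "obs" after condition has fixed it
observed_data = torch.tensor([0.5, 1.2, 0.9])
conditioned = pyro.poutine.condition(model, data={"obs": observed_data})
unconditioned = pyro.poutine.uncondition(conditioned)
# "obs" is now sampled, not conditioned
unconditioned(None)
```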
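Whether uncondition overrides condition depends on handler order. The toy sketch below is a simplification, not Pyro's runtime: it lists handlers from outermost to innermost and lets the innermost process the message first, which matches the order in which Pyro applies `_pyro_sample` hooks. `condition_sample` and `run_site` are hypothetical helpers; `condition_sample` only mimics ConditionMessenger's effect of fixing a value and marking the site observed.

```python
def condition_sample(data: dict):
    # Hypothetical stand-in for ConditionMessenger: fix the value, mark observed
    def handler(msg: dict) -> None:
        if msg["name"] in data:
            msg["value"] = data[msg["name"]]
            msg["is_observed"] = True
    return handler


def uncondition_sample(msg: dict) -> None:
    # Same rewrite as UnconditionMessenger._pyro_sample
    if msg["is_observed"]:
        msg["is_observed"] = False
        msg["infer"]["was_observed"] = True
        msg["infer"]["obs"] = msg["value"]
        msg["value"] = None
        msg["done"] = False


def run_site(outer_to_inner: list) -> dict:
    # Handlers are listed outermost first; the innermost one processes first
    msg = {"name": "obs", "value": None, "is_observed": False,
           "done": False, "infer": {}}
    for handler in reversed(outer_to_inner):
        handler(msg)
    return msg


cond = condition_sample({"obs": 1.7})
# uncondition(condition(model)): uncondition is outer, so it acts last
print(run_site([uncondition_sample, cond])["is_observed"])  # False
# condition(uncondition(model)): condition is outer and re-applies the data
print(run_site([cond, uncondition_sample])["is_observed"])  # True
```

In this toy, the site ends up unobserved only when uncondition wraps the conditioned model; swapping the order lets condition re-apply its data.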
Using as Context Manager
```python
import torch
import pyro

with pyro.poutine.uncondition():
    # All observed sites within this block will be sampled instead
    result = model(torch.tensor([0.5, 1.2, 0.9]))
```
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Handlers -- The `poutine.uncondition()` factory function
- Pyro_ppl_Pyro_ConditionMessenger -- The inverse operation: converts sample sites into observed sites
- Pyro_ppl_Pyro_ReplayMessenger -- Related: replays recorded values at sites