Implementation:Pyro ppl Pyro InferConfigMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/infer_config_messenger.py |
| Module | pyro.poutine.infer_config_messenger |
| Lines | 58 |
| Parent Class | Messenger |
| Purpose | Update inference configuration at sample and param sites |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
InferConfigMessenger updates the infer dictionary at sample and param sites by calling a user-provided config_fn on each site. The infer dictionary holds per-site inference configuration such as enumeration strategy, auxiliary flags, and other inference algorithm parameters.
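This handler intercepts both pyro.sample and pyro.param calls. At each site it calls config_fn(msg) and merges the returned InferDict into the existing msg["infer"] with dict.update, so keys returned by config_fn overwrite existing keys of the same name while all other keys already present are kept.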
Code Reference
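```python
from typing import TYPE_CHECKING, Callable
from pyro.poutine.messenger import Messenger
if TYPE_CHECKING:
    from pyro.poutine.runtime import InferDict, Message

class InferConfigMessenger(Messenger):
    def __init__(self, config_fn: Callable[["Message"], "InferDict"]) -> None:
        super().__init__()
        self.config_fn = config_fn

    def _pyro_sample(self, msg: "Message") -> None:
        # Merge config_fn's output into the site's existing infer dict.
        msg["infer"].update(self.config_fn(msg))

    def _pyro_param(self, msg: "Message") -> None:
        msg["infer"].update(self.config_fn(msg))
```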
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| config_fn | Callable[[Message], InferDict] | A function that takes a site message and returns an inference configuration dict |

| Message Effect | Description |
|---|---|
| msg["infer"] | Updated (merged) with the dict returned by config_fn(msg) at both sample and param sites |
Usage Examples
Enabling Parallel Enumeration
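```python
import pyro

def my_config(msg):
    # Enumerate unobserved sample sites in parallel (valid only for discrete
    # distributions); leave observed sites and param sites unchanged.
    if msg["type"] == "sample" and not msg["is_observed"]:
        return {"enumerate": "parallel"}
    return {}

configured_model = pyro.poutine.infer_config(model, config_fn=my_config)  # `model` defined elsewhere
```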
Marking Auxiliary Sites
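```python
import pyro

def mark_auxiliary(msg):
    # Flag every site whose name starts with "aux_" (param sites included).
    if msg["name"].startswith("aux_"):
        return {"is_auxiliary": True}
    return {}

configured_guide = pyro.poutine.infer_config(guide, config_fn=mark_auxiliary)  # `guide` defined elsewhere
```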
Using as Context Manager
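```python
# Assumes `import pyro`, `import pyro.distributions as dist`, and a `probs` tensor in scope.
with pyro.poutine.infer_config(config_fn=lambda msg: {"enumerate": "parallel"}):
    z = pyro.sample("z", dist.Categorical(probs))
```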
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Handlers -- The poutine.infer_config() factory function
- Pyro_ppl_Pyro_Poutine_Runtime -- Defines the InferDict type used by this messenger