
Implementation:Pyro ppl Pyro InferConfigMessenger

From Leeroopedia


| Attribute | Value |
|---|---|
| File | pyro/poutine/infer_config_messenger.py |
| Module | pyro.poutine.infer_config_messenger |
| Lines | 58 |
| Parent class | Messenger |
| Purpose | Update inference configuration at sample and param sites |
| License | Apache-2.0 (Uber Technologies, Inc.) |

Overview

InferConfigMessenger updates the infer dictionary at sample and param sites by calling a user-provided config_fn on each site. The infer dictionary holds per-site inference configuration, such as the enumeration strategy ("enumerate"), the auxiliary-variable flag ("is_auxiliary"), and other parameters read by inference algorithms.

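This handler intercepts both pyro.sample and pyro.param calls, applies config_fn to each message, and merges the returned InferDict into the existing msg["infer"] with dict.update. Keys in the returned dict therefore overwrite existing values of the same key, all other keys are kept, and returning {} leaves the site unchanged.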

Code Reference

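from typing import TYPE_CHECKING, Callable

from pyro.poutine.messenger import Messenger

if TYPE_CHECKING:
    from pyro.poutine.runtime import InferDict, Message

class InferConfigMessenger(Messenger):
    def __init__(self, config_fn: Callable[["Message"], "InferDict"]) -> None:
        super().__init__()
        self.config_fn = config_fn  # called once per intercepted sample/param site

    def _pyro_sample(self, msg: "Message") -> None:
        # Merge config_fn's result into this sample site's infer dict.
        msg["infer"].update(self.config_fn(msg))

    def _pyro_param(self, msg: "Message") -> None:
        # Apply the same merge at pyro.param sites.
        msg["infer"].update(self.config_fn(msg))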

I/O Contract

Parameters

| Parameter | Type | Description |
|---|---|---|
| config_fn | Callable[[Message], InferDict] | Function that takes a site message and returns an inference configuration dict |

Effects on the message

| Message field | Effect |
|---|---|
| msg["infer"] | Updated in place (merged via dict.update) with the dict returned by config_fn(msg), at both sample and param sites |
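The merge rule in this contract can be reproduced without Pyro. In the plain-Python sketch below, `apply_infer_config` is a hypothetical helper (not part of Pyro's API) that mimics the handler on a hand-built message reduced to the three fields it touches: sample and param messages get config_fn's dict merged in with dict.update, and messages of any other type pass through untouched, matching the fact that the messenger defines only _pyro_sample and _pyro_param.

```python
from typing import Any, Callable, Dict

Message = Dict[str, Any]  # stand-in for Pyro's Message type (assumption)

def apply_infer_config(msg: Message, config_fn: Callable[[Message], Dict[str, Any]]) -> Message:
    """Hypothetical helper mirroring InferConfigMessenger on a plain dict."""
    # Only "sample" and "param" messages are handled; other types are ignored.
    if msg["type"] in ("sample", "param"):
        msg["infer"].update(config_fn(msg))  # same dict.update merge as the handler
    return msg

site = {"type": "sample", "name": "z", "infer": {"enumerate": "sequential", "is_auxiliary": True}}
apply_infer_config(site, lambda m: {"enumerate": "parallel"})
print(site["infer"])  # {'enumerate': 'parallel', 'is_auxiliary': True}
```

The existing "enumerate" value is replaced while "is_auxiliary" survives, which is why a config_fn only needs to return the keys it wants to change.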

Usage Examples

Enabling Parallel Enumeration

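import pyro

def my_config(msg):
    # Enumerate only latent sample sites whose distribution has finite discrete support.
    is_latent = msg["type"] == "sample" and not msg["is_observed"]
    if is_latent and getattr(msg["fn"], "has_enumerate_support", False):
        return {"enumerate": "parallel"}
    return {}

configured_model = pyro.poutine.infer_config(model, config_fn=my_config)  # model: user-defined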

Marking Auxiliary Sites

import pyro

def mark_auxiliary(msg):
    # Flag every site whose name uses the "aux_" prefix.
    if msg["name"].startswith("aux_"):
        return {"is_auxiliary": True}
    return {}

configured_guide = pyro.poutine.infer_config(guide, config_fn=mark_auxiliary)  # guide: user-defined

Using as a Context Manager

import pyro
import pyro.distributions as dist
with pyro.poutine.infer_config(config_fn=lambda msg: {"enumerate": "parallel"}):
    z = pyro.sample("z", dist.Categorical(probs))  # probs: tensor of category probabilities
