Implementation:Pyro ppl Pyro SearchInference

From Leeroopedia


Property | Value
Implementation Type | Pattern Doc
Source File | examples/rsa/search_inference.py
Module | examples.rsa
Pyro Features | TracePosterior, poutine.trace, poutine.queue, poutine.escape, poutine.replay, NonlocalExit
Reference | Adapted from http://dippl.org/chapters/03-enumeration.html
Classes | HashingMarginal, Search, BestFirstSearch

Overview

This file provides exact and approximate search-based inference algorithms used by the RSA (Rational Speech Act) example models. It implements three key components:

  • HashingMarginal: Converts a TracePosterior into a Distribution over return values by building a marginal histogram. It hashes return values (tensors, dicts, or hashable objects) to aggregate probabilities across execution traces. Supports sample(), log_prob(), and enumerate_support().
  • Search: Exact inference by exhaustive enumeration of all possible program executions. Uses poutine.queue to systematically explore all branches at each sample site. Yields (trace, log_prob) pairs.
  • BestFirstSearch: Approximate inference by enumerating executions in order of decreasing probability. Uses a priority queue to explore the most probable execution paths first. Equivalent to Search if all executions are enumerated, but can be truncated after num_samples executions for efficiency.
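The histogram-building step described for HashingMarginal can be illustrated without Pyro. The following is a minimal sketch, not the file's implementation: it assumes plain hashable return values rather than tensors, and the names marginalize and weighted_returns are hypothetical. It merges the log-weights of executions that share a return value with logsumexp, then normalizes.

```python
import math
from collections import OrderedDict


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a non-empty list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def marginalize(weighted_returns):
    """Merge log-weights of equal return values, then normalize.

    weighted_returns: iterable of (return_value, log_weight) pairs, a
    hypothetical stand-in for the traces and log-weights of a TracePosterior.
    """
    values_map, logits = OrderedDict(), OrderedDict()
    for value, log_weight in weighted_returns:
        key = hash(value)  # assumes plain hashable values, not tensors
        if key in logits:
            # Value already seen: add probabilities in log space.
            logits[key] = logsumexp([logits[key], log_weight])
        else:
            logits[key] = log_weight
            values_map[key] = value
    log_z = logsumexp(list(logits.values()))
    return {values_map[k]: math.exp(v - log_z) for k, v in logits.items()}


histogram = marginalize([
    ("a", math.log(0.1)),
    ("b", math.log(0.3)),
    ("a", math.log(0.2)),
])
```

Here the two executions returning "a" are merged, so each of the two distinct values ends up with half of the normalized mass.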

The memoize helper wraps functools.lru_cache for caching marginal distributions during nested inference.
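A minimal sketch of such a helper, assuming it simply forwards keyword arguments to functools.lru_cache and can be used with or without parentheses. The function expensive_marginal and the calls list are hypothetical and exist only to show that the second call is served from the cache.

```python
import functools


def memoize(fn=None, **kwargs):
    """Cache results of fn; usable as @memoize or @memoize(maxsize=...)."""
    if fn is None:
        # Called with keyword arguments only: return a decorator.
        return lambda _fn: memoize(_fn, **kwargs)
    return functools.lru_cache(**kwargs)(fn)


calls = []  # records every real (uncached) invocation


@memoize(maxsize=None)
def expensive_marginal(utterance):
    # Hypothetical placeholder for running inference and building a marginal.
    calls.append(utterance)
    return {"utterance": utterance}


first = expensive_marginal("some")
second = expensive_marginal("some")  # served from the cache
```

Because lru_cache keys on the arguments, every argument passed to a memoized function must be hashable.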

Code Reference

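class HashingMarginal(dist.Distribution):
    def __init__(self, trace_dist, sites=None):
        self.trace_dist = trace_dist
        # Default to marginalizing over the model's return value
        self.sites = sites if sites is not None else "_RETURN"

    def _dist_and_values(self):
        # Map value hash -> accumulated log-weight and value hash -> value
        values_map, logits = OrderedDict(), OrderedDict()
        for tr, logit in zip(self.trace_dist.exec_traces,
                             self.trace_dist.log_weights):
            value = tr.nodes[self.sites]["value"]
            # Tensor values are hashed by their raw bytes (only this case is shown in the excerpt)
            value_hash = hash(value.cpu().contiguous().numpy().tobytes())
            if value_hash in logits:
                # Value seen before: add its probability mass in log space
                logits[value_hash] = logsumexp(torch.stack([logits[value_hash], logit]), dim=-1)
            else:
                logits[value_hash] = logit
                values_map[value_hash] = value
        # Construction of the logits tensor is elided in this excerpt
        return dist.Categorical(logits=...), values_map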

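class Search(TracePosterior):
    def _traces(self, *args, **kwargs):
        # FIFO queue of partial traces, seeded with an empty trace
        q = queue.Queue()
        q.put(poutine.Trace())
        p = poutine.trace(poutine.queue(self.model, queue=q, max_tries=self.max_tries))
        # Each run completes one execution; stop once no partial traces remain
        while not q.empty():
            tr = p.get_trace(*args, **kwargs)
            yield tr, tr.log_prob_sum()

class BestFirstSearch(TracePosterior):
    def _traces(self, *args, **kwargs):
        # Priority queue of (priority, partial trace) pairs, seeded with an empty trace
        q = queue.PriorityQueue()
        q.put((0.0, poutine.Trace()))
        # pqueue is the priority-queue counterpart of poutine.queue
        q_fn = pqueue(self.model, queue=q)
        # Emit at most num_samples executions
        for i in range(self.num_samples):
            tr = poutine.trace(q_fn).get_trace(*args, **kwargs)
            yield tr, tr.log_prob_sum()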
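The difference between the two queue-driven loops above can be illustrated with a small, Pyro-free sketch. It is a simplification under stated assumptions: the "program" is a fixed list of independent discrete sites (SITES is hypothetical), whereas the real classes replay arbitrary model code and extend partial traces. exhaustive mirrors the FIFO expansion of Search, and best_first mirrors the priority-queue expansion of BestFirstSearch, truncated at num_samples.

```python
import heapq
import math
from collections import deque

# Hypothetical toy program: two independent discrete sample sites.
SITES = [
    ("coin", {"H": 0.7, "T": 0.3}),
    ("die", {1: 0.5, 2: 0.3, 3: 0.2}),
]


def exhaustive(sites):
    """Expand every partial assignment in FIFO order; yield all executions."""
    pending = deque([((), 0.0)])
    while pending:
        partial, log_p = pending.popleft()
        if len(partial) == len(sites):
            yield partial, log_p
            continue
        _, support = sites[len(partial)]
        for value, prob in support.items():
            pending.append((partial + (value,), log_p + math.log(prob)))


def best_first(sites, num_samples):
    """Always expand the most probable partial assignment next."""
    counter = 0  # tie-breaker so the heap never compares assignments
    heap = [(0.0, counter, ())]  # entries: (negative log-prob, counter, partial)
    emitted = 0
    while heap and emitted < num_samples:
        neg_log_p, _, partial = heapq.heappop(heap)
        if len(partial) == len(sites):
            emitted += 1
            yield partial, -neg_log_p
            continue
        _, support = sites[len(partial)]
        for value, prob in support.items():
            counter += 1
            heapq.heappush(
                heap, (neg_log_p - math.log(prob), counter, partial + (value,))
            )


all_execs = list(exhaustive(SITES))
top3 = list(best_first(SITES, num_samples=3))
```

Extending an assignment can only lower its probability, so in this sketch complete executions leave the heap in order of non-increasing probability, and a truncated run keeps the highest-mass executions.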

I/O Contract

Class | Input | Output
Search | Model function, args, kwargs | Iterator of (Trace, log_prob) pairs (exhaustive)
BestFirstSearch | Model function, num_samples | Iterator of (Trace, log_prob) pairs (ordered by probability)
HashingMarginal | TracePosterior instance | Distribution with sample(), log_prob(), enumerate_support()

Usage Examples

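import torch
import pyro
import pyro.distributions as dist

from search_inference import HashingMarginal, Search, BestFirstSearch, memoize

def Marginal(fn):
    # Cache one exact marginal per argument tuple (arguments must be hashable)
    return memoize(lambda *args: HashingMarginal(Search(fn).run(*args)))

# Plain model function: uniform prior over {0, 1, 2}, penalized unless x == data
def model(data):
    x = pyro.sample("x", dist.Categorical(torch.ones(3) / 3))
    pyro.factor("obs", torch.tensor(0.0 if x == data else -999999.0))
    return x

# Use Search for exact inference
my_model = Marginal(model)
posterior = my_model(1)
posterior.sample()  # Draw a sample
posterior.log_prob(torch.tensor(1))  # Get log probability

# Use BestFirstSearch for approximate inference
# (pass the undecorated model, not the memoized marginal)
marginal = HashingMarginal(BestFirstSearch(model, num_samples=50).run(1))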
