Implementation:Pyro ppl Pyro SearchInference
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/rsa/search_inference.py |
| Module | examples.rsa |
| Pyro Features | TracePosterior, poutine.trace, poutine.queue, poutine.escape, poutine.replay, NonlocalExit |
| Reference | Adapted from http://dippl.org/chapters/03-enumeration.html |
| Classes | HashingMarginal, Search, BestFirstSearch |
Overview
This file provides exact and approximate search-based inference algorithms used by the RSA (Rational Speech Act) example models. It implements three key components:
- HashingMarginal: Converts a TracePosterior into a Distribution over return values by building a marginal histogram. It hashes return values (tensors, dicts, or hashable objects) to aggregate probabilities across execution traces. Supports sample(), log_prob(), and enumerate_support().
- Search: Exact inference by exhaustive enumeration of all possible program executions. Uses poutine.queue to systematically explore all branches at each sample site. Yields (trace, log_prob) pairs.
- BestFirstSearch: Approximate inference by enumerating executions in order of decreasing probability. Uses a priority queue to explore the most probable execution paths first. Gives the same result as Search when every execution is enumerated; a smaller num_samples truncates the enumeration for efficiency at the cost of exactness.
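The aggregation that HashingMarginal performs can be illustrated without Pyro. The sketch below is a pure-Python stand-in, not the Pyro class: it takes hypothetical (return_value, log_weight) pairs in place of execution traces, turns dicts into sorted tuples so they can serve as keys, merges repeated values with a numerically stable log-add-exp, and then normalizes. The names hashing_marginal, _canonical, and _logaddexp are illustrative only.

```python
import math
from collections import OrderedDict


def _canonical(value):
    # Dicts are unhashable, so map them to a sorted tuple of items;
    # any other value is assumed to be hashable already.
    if isinstance(value, dict):
        return tuple(sorted((k, _canonical(v)) for k, v in value.items()))
    return value


def _logaddexp(a, b):
    # Numerically stable log(exp(a) + exp(b)).
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def hashing_marginal(weighted_values):
    """Collapse (return_value, log_weight) pairs into a normalized histogram.

    Returns an OrderedDict: canonical key -> (original value, log probability).
    """
    logits, values = OrderedDict(), OrderedDict()
    for value, log_weight in weighted_values:
        key = _canonical(value)
        if key in logits:
            # Value seen before: accumulate its probability mass in log space.
            logits[key] = _logaddexp(logits[key], log_weight)
        else:
            logits[key] = log_weight
            values[key] = value
    # Normalize so the probabilities of the distinct values sum to one.
    log_z = -math.inf
    for lw in logits.values():
        log_z = _logaddexp(log_z, lw)
    return OrderedDict((k, (values[k], logits[k] - log_z)) for k in logits)
```

One design difference worth noting: the Pyro code keys its dictionaries on hash(...) of the value, whereas this sketch keys on the canonical value itself, so two distinct values can never be merged by a hash collision.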
The memoize helper wraps functools.lru_cache for caching marginal distributions during nested inference.
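A helper with the behavior described above fits in a few lines. This is a sketch assuming memoize simply forwards keyword arguments to functools.lru_cache and can be applied with or without them; expensive_marginal and evaluations are hypothetical stand-ins for a nested-inference marginal.

```python
import functools


def memoize(fn=None, **kwargs):
    # Works both as @memoize and as @memoize(maxsize=...);
    # keyword arguments are forwarded unchanged to functools.lru_cache.
    if fn is None:
        return lambda _fn: memoize(_fn, **kwargs)
    return functools.lru_cache(**kwargs)(fn)


evaluations = []  # records every real (uncached) evaluation


@memoize(maxsize=None)
def expensive_marginal(utterance):
    # Hypothetical stand-in for building a HashingMarginal via nested Search.
    evaluations.append(utterance)
    return {"utterance": utterance}


first = expensive_marginal("some")
second = expensive_marginal("some")  # served from the cache, no new evaluation
```

Because lru_cache keys on the call arguments, every argument to a memoized marginal must be hashable, and ideally hashable by value: torch tensors hash by object identity, so two equal tensors passed as separate objects would miss the cache.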
Code Reference
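```python
import queue
from collections import OrderedDict

import torch

import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.abstract_infer import TracePosterior


class HashingMarginal(dist.Distribution):
    def __init__(self, trace_dist, sites=None):
        self.sites = sites if sites is not None else "_RETURN"
        super().__init__()
        self.trace_dist = trace_dist

    def _dist_and_values(self):
        values_map, logits = OrderedDict(), OrderedDict()
        for tr, logit in zip(self.trace_dist.exec_traces,
                             self.trace_dist.log_weights):
            value = tr.nodes[self.sites]["value"]
            # Tensor case shown; dict and other hashable return values
            # are hashed separately in the full source.
            value_hash = hash(value.cpu().contiguous().numpy().tobytes())
            if value_hash in logits:
                # Value already seen: add its probability mass in log space.
                logits[value_hash] = torch.logsumexp(
                    torch.stack([logits[value_hash], logit]), dim=-1)
            else:
                logits[value_hash] = logit
                values_map[value_hash] = value
        # Construction of the Categorical's logits is elided in this excerpt.
        return dist.Categorical(logits=...), values_map


class Search(TracePosterior):
    # __init__ (storing self.model and self.max_tries) omitted.
    def _traces(self, *args, **kwargs):
        q = queue.Queue()
        q.put(poutine.Trace())  # start from an empty partial trace
        p = poutine.trace(poutine.queue(self.model, queue=q, max_tries=self.max_tries))
        while not q.empty():
            tr = p.get_trace(*args, **kwargs)
            yield tr, tr.log_prob_sum()


class BestFirstSearch(TracePosterior):
    # __init__ (storing self.model and self.num_samples) omitted.
    def _traces(self, *args, **kwargs):
        q = queue.PriorityQueue()
        q.put((0.0, poutine.Trace()))
        # pqueue is a helper defined in the same file (not shown); it pops a
        # partial trace and extends it using poutine.escape and poutine.replay.
        q_fn = pqueue(self.model, queue=q)
        for _ in range(self.num_samples):
            if q.empty():
                break  # every execution has already been enumerated
            tr = poutine.trace(q_fn).get_trace(*args, **kwargs)
            yield tr, tr.log_prob_sum()
```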
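The queue-driven mechanics behind Search and BestFirstSearch can be reproduced in a toy, Pyro-free form. In this sketch a hypothetical program receives a choose callback instead of calling pyro.sample; a partial execution is a tuple of recorded choices, and reaching an unrecorded choice raises an exception (standing in for poutine.escape and NonlocalExit) so the driver can push one extension per option. A FIFO frontier gives exhaustive, Search-like enumeration; a heap keyed on the negated log probability of the partial execution gives BestFirstSearch-like ordering. None of these names are Pyro APIs.

```python
import heapq
import itertools
import math
from collections import deque


class _NeedChoice(Exception):
    """Raised when the program reaches a choice the current prefix does not fix."""

    def __init__(self, options):
        super().__init__()
        self.options = options  # list of (value, probability) pairs, probability > 0


def _make_chooser(prefix):
    # Replays the recorded choices in `prefix`, then escapes at the first new one.
    position = 0

    def choose(options):
        nonlocal position
        if position < len(prefix):
            value = prefix[position][0]
            position += 1
            return value
        raise _NeedChoice(options)

    return choose


def enumerate_executions(program, best_first=False, max_executions=None):
    """Yield (return_value, log_prob) for each complete execution of `program`.

    best_first=False: FIFO frontier, exhaustive enumeration (Search-like).
    best_first=True: most probable partial execution first (BestFirstSearch-like);
    combine with max_executions to truncate.
    """
    tiebreak = itertools.count()  # keeps heapq from ever comparing prefixes
    if best_first:
        frontier = [(0.0, next(tiebreak), ())]

        def push(log_prob, prefix):
            heapq.heappush(frontier, (-log_prob, next(tiebreak), prefix))

        def pop():
            return heapq.heappop(frontier)[2]
    else:
        frontier = deque([()])

        def push(log_prob, prefix):
            frontier.append(prefix)

        pop = frontier.popleft

    yielded = 0
    while frontier and (max_executions is None or yielded < max_executions):
        prefix = pop()
        log_prob = sum(lp for _, lp in prefix)
        try:
            # Re-run the program from the start under this prefix.
            value = program(_make_chooser(prefix))
        except _NeedChoice as need:
            # Extend the prefix with every option of the newly reached choice.
            for option, p in need.options:
                lp = math.log(p)
                push(log_prob + lp, prefix + ((option, lp),))
            continue
        yield value, log_prob
        yielded += 1


def two_coins(choose):
    # Toy program: a tails-biased coin, then a heads-biased one; count heads.
    a = choose([(1, 0.3), (0, 0.7)])
    b = choose([(1, 0.6), (0, 0.4)])
    return a + b
```

Because each extension can only lower the log probability (every option has probability at most 1), popping the most probable prefix first yields complete executions in non-increasing probability order: for two_coins, best-first enumeration produces probabilities 0.42, 0.28, 0.18, 0.12, while the FIFO order is 0.18, 0.12, 0.42, 0.28. As with Pyro's replay-based approach, the program is re-run from the start for every popped prefix, which keeps the driver simple but repeats work.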
I/O Contract
| Class | Input | Output |
|---|---|---|
| Search | Model function (constructor); args, kwargs (passed to run) | Iterator of (Trace, log_prob) pairs (exhaustive) |
| BestFirstSearch | Model function, num_samples (constructor); args, kwargs (passed to run) | Iterator of (Trace, log_prob) pairs (ordered by probability) |
| HashingMarginal | TracePosterior instance | Distribution with sample(), log_prob(), enumerate_support() |
Usage Examples
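```python
import torch

import pyro
import pyro.distributions as dist
from search_inference import HashingMarginal, Search, BestFirstSearch, memoize


def Marginal(fn):
    # Memoized exact marginal: Search runs once per distinct argument tuple.
    return memoize(lambda *args: HashingMarginal(Search(fn).run(*args)))


def my_model(data):
    x = pyro.sample("x", dist.Categorical(torch.ones(3) / 3))
    # Near-hard conditioning: a large log-penalty unless x equals data.
    pyro.factor("obs", torch.tensor(0.0 if x == data else -999999.0))
    return x


# Use Search for exact inference
exact_model = Marginal(my_model)
posterior = exact_model(1)
posterior.sample()                   # draw a sample
posterior.log_prob(torch.tensor(1))  # log probability of x == 1

# Use BestFirstSearch for approximate inference; pass the undecorated model,
# since a Marginal-wrapped function returns a distribution, not a trace.
marginal = HashingMarginal(BestFirstSearch(my_model, num_samples=50).run(1))
```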
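As a worked check of what the exact posterior for my_model conditioned on data = 1 should contain, assuming the uniform prior and the factor values used in the example: Search enumerates three executions, x ∈ {0, 1, 2}, each with prior probability 1/3; the factor adds 0 to the log weight when x equals data and -999999 otherwise. After HashingMarginal normalizes the aggregated weights:

$$
P(x = 1 \mid \text{data} = 1)
= \frac{\tfrac{1}{3}\, e^{0}}{\tfrac{1}{3}\, e^{0} + 2 \cdot \tfrac{1}{3}\, e^{-999999}}
= \frac{1}{1 + 2\, e^{-999999}} \approx 1
$$

so posterior.sample() should return 1 in practice, and posterior.log_prob(torch.tensor(1)) should be approximately 0.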
Related Pages
- Pyro_ppl_Pyro_RSA_Schelling - Schelling game using Search
- Pyro_ppl_Pyro_RSA_Generics - Generics model using Search and memoize
- Pyro_ppl_Pyro_RSA_Hyperbole - Hyperbole model using Search
- Pyro_ppl_Pyro_RSA_SemanticParsing - Semantic parsing using BestFirstSearch