Implementation:Pyro ppl Pyro RSA SemanticParsing
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/rsa/semantic_parsing.py |
| Module | examples.rsa |
| Pyro Features | pyro.sample, pyro.factor, BestFirstSearch, HashingMarginal, memoize, nested inference |
| Reference | Adapted from http://dippl.org/examples/zSemanticPragmaticMashup.html |
Overview
This file implements a Rational Speech Act (RSA) model that combines pragmatic reasoning with CCG-based compositional semantics. It demonstrates how Pyro can be used for cognitive science models of language understanding, where speakers and listeners reason about each other recursively.
The implementation includes:
- Lexical semantics: word meanings are represented as classes (BlondMeaning, NiceMeaning, etc.) with sem() (semantic function) and syn() (syntactic category) methods, following CCG (Combinatory Categorial Grammar).
- Compositional semantics: words are combined using CCG rules, namely function application driven by syntactic types (left and right application).
- World model: a world consists of objects with probabilistic properties (blond, nice, tall), sampled via pyro.sample.
- Pragmatic reasoning: a literal listener interprets utterances, a speaker chooses utterances based on the literal listener, and an RSA listener reasons about the speaker.
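The lexical and compositional ideas can be made concrete with a small plain-Python sketch. It is not the code in semantic_parsing.py: the Entry type, combine(), parse(), the toy LEXICON and toy_meaning() are hypothetical. Each entry pairs a semantic value (sem) with a CCG category (syn); forward application (X/Y, Y => X) and backward application (Y, X\Y => X) are the two combination rules. This sketch always combines the leftmost combinable pair, whereas the Pyro example treats the combination choice as a random variable (the ix_{c} sites). The toy lexicon also omits words such as "of the" and the tall property.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union

# A category is a base type ("N", "NP", "S", "ADJ") or a functor
# (result, slash, argument): "/" takes its argument from the right, "\\" from the left.
Cat = Union[str, Tuple[Any, str, Any]]


@dataclass
class Entry:
    sem: Any  # semantic value: a predicate, a higher-order function, or world -> bool
    syn: Cat  # syntactic category


def combine(left: Entry, right: Entry) -> Optional[Entry]:
    """Forward application (X/Y, Y => X) or backward application (Y, X\\Y => X)."""
    if isinstance(left.syn, tuple) and left.syn[1] == "/" and left.syn[2] == right.syn:
        return Entry(left.sem(right.sem), left.syn[0])
    if isinstance(right.syn, tuple) and right.syn[1] == "\\" and right.syn[2] == left.syn:
        return Entry(right.sem(left.sem), right.syn[0])
    return None


def parse(entries: List[Entry]) -> Optional[Entry]:
    """Repeatedly combine the leftmost combinable adjacent pair (deterministic)."""
    entries = list(entries)
    while len(entries) > 1:
        for i in range(len(entries) - 1):
            merged = combine(entries[i], entries[i + 1])
            if merged is not None:
                entries[i:i + 2] = [merged]
                break
        else:
            return None  # no rule applies anywhere: no parse
    return entries[0]


# Hypothetical toy lexicon; objects are dicts such as {"blond": True, "nice": False}.
LEXICON = {
    "all": Entry(lambda noun: lambda pred: lambda world: all(pred(o) for o in world if noun(o)),
                 ("NP", "/", "N")),
    "some": Entry(lambda noun: lambda pred: lambda world: any(noun(o) and pred(o) for o in world),
                  ("NP", "/", "N")),
    "blond": Entry(lambda noun: lambda o: o["blond"] and noun(o), ("N", "/", "N")),
    "people": Entry(lambda o: True, "N"),
    "are": Entry(lambda adj: lambda np: np(adj), (("S", "\\", "NP"), "/", "ADJ")),
    "nice": Entry(lambda o: o["nice"], "ADJ"),
}


def toy_meaning(sentence: str):
    """Parse a sentence into a world -> bool function (raises if it does not parse to S)."""
    parsed = parse([LEXICON[w] for w in sentence.split()])
    if parsed is None or parsed.syn != "S":
        raise ValueError(f"no parse for {sentence!r}")
    return parsed.sem


world = [{"blond": True, "nice": True}, {"blond": False, "nice": False}]
print(toy_meaning("all blond people are nice")(world))   # True
print(toy_meaning("some blond people are nice")(world))  # True
```

Because the quantifier is typed NP/N and the verb phrase S\NP, the sentence meaning falls out of type-driven application alone: "all blond people" becomes a function waiting for a predicate, and "are nice" supplies it via backward application.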
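The inference uses BestFirstSearch, which enumerates program executions in order of probability, combined with HashingMarginal, which groups the enumerated executions by return value into a marginal distribution. Both are search inference utilities shared by the RSA examples.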
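The search-then-marginalize idea can be illustrated without Pyro. The stdlib-only sketch below is an assumption, not the actual search inference code: a program is modeled as a fixed list of discrete choices plus a hypothetical model function (instead of a traced Pyro function), and num_samples caps how many complete executions are found. Partial executions sit in a priority queue keyed by log-weight; since every choice and factor can only lower that weight, complete executions come out most probable first. The marginal then sums weights per return value and renormalizes over whatever the search reached.

```python
import heapq
import itertools
import math
from collections import defaultdict


def best_first_search(choices, model, num_samples):
    """Return up to num_samples complete executions as (return_value, log_weight).

    choices: list of (site_name, {value: prob}) made in a fixed order.
    model:   maps a complete {site_name: value} trace to (return_value, log_factor),
             where log_factor <= 0 plays the role of a pyro.factor penalty.
    Every step can only lower the log-weight, so complete executions are
    popped from the priority queue in order of decreasing weight.
    """
    tie = itertools.count()  # tie-breaker so heapq never compares dicts
    queue = [(0.0, next(tie), {}, False)]  # (negated log-weight, tie, trace, complete?)
    results = []
    while queue and len(results) < num_samples:
        neg_logw, _, trace, complete = heapq.heappop(queue)
        if complete:
            results.append((model(trace)[0], -neg_logw))
        elif len(trace) == len(choices):
            log_factor = model(trace)[1]
            if log_factor > -math.inf:  # drop executions ruled out entirely
                heapq.heappush(queue, (neg_logw - log_factor, next(tie), trace, True))
        else:
            name, probs = choices[len(trace)]
            for value, p in probs.items():
                if p > 0:
                    child = {**trace, name: value}
                    heapq.heappush(queue, (neg_logw - math.log(p), next(tie), child, False))
    return results


def hashing_marginal(results):
    """Group executions by (hashable) return value and normalize the summed weights."""
    totals = defaultdict(float)
    for value, log_w in results:
        totals[value] += math.exp(log_w)
    z = sum(totals.values())
    return {value: w / z for value, w in totals.items()}


# Toy program: two biased coin flips, returning how many came up True.
choices = [("a", {True: 0.7, False: 0.3}), ("b", {True: 0.7, False: 0.3})]


def model(trace):
    return trace["a"] + trace["b"], 0.0  # no factor in this toy example


print(hashing_marginal(best_first_search(choices, model, num_samples=4)))
# approx. {2: 0.49, 1: 0.42, 0: 0.09}
print(hashing_marginal(best_first_search(choices, model, num_samples=3)))
# approx. {2: 0.538, 1: 0.462}: the least likely execution was never reached
```

The second call shows the trade-off of a small search budget: mass is renormalized over the executions found, so unexplored low-probability outcomes simply vanish. This sketch also sums exponentiated weights directly; a log-space sum would be more robust for long programs.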
Code Reference
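```python
import pyro

# Marginal, meaning, world_prior, heuristic and utterance_prior are defined
# elsewhere in examples/rsa/semantic_parsing.py.

@Marginal(num_samples=100)
def literal_listener(utterance):
    # Literal listener: sample a world, then upweight (by a factor of 1000 in
    # log space) worlds in which the utterance's compositional meaning holds.
    m = meaning(utterance)
    world = world_prior(2, m)
    pyro.factor("world_constraint", heuristic(m(world)) * 1000)
    return world

@Marginal(num_samples=100)
def speaker(world):
    # Speaker: choose an utterance, scored by how likely the literal
    # listener is to recover the intended world from it.
    utterance = utterance_prior()
    L = literal_listener(utterance)
    pyro.sample("speaker_constraint", L, obs=world)
    return utterance

def rsa_listener(utterance, qud):
    # Pragmatic listener: infer the world by reasoning about which utterance
    # the speaker would have chosen, then answer the QUD.
    world = world_prior(2, meaning(utterance))
    S = speaker(world)
    pyro.sample("listener_constraint", S, obs=utterance)
    return qud(world)
```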
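The three-level recursion can be checked against an exact, Pyro-free computation. This sketch is illustrative only: the two-object world, the three hypothetical utterances, the uniform priors, and the hard truth-value filter (in place of the soft pyro.factor weighting) are all assumptions, and functools.lru_cache stands in for memoize. It reproduces the scalar implicature RSA is meant to capture: after hearing "some", the literal listener assigns 5/7 to "all blond people are nice", while the pragmatic listener lowers this to 45/77, because a speaker in such a world would often have said "all" instead.

```python
import itertools
from collections import defaultdict
from fractions import Fraction
from functools import lru_cache

# Assumed toy world model: two objects, each a (blond, nice) pair, with all
# 16 worlds equally likely (independent Bernoulli(1/2) properties; tall omitted).
WORLDS = list(itertools.product(itertools.product([True, False], repeat=2), repeat=2))

# Hypothetical utterance set with a uniform prior; each maps a world to a truth value.
MEANINGS = {
    "all blond people are nice": lambda w: all(nice for blond, nice in w if blond),
    "some blond people are nice": lambda w: any(blond and nice for blond, nice in w),
    "no blond people are nice": lambda w: not any(blond and nice for blond, nice in w),
}


def normalize(weights):
    total = sum(weights.values())
    return {k: v / total for k, v in weights.items() if v > 0}


def project(dist, qud):
    """Marginalize a distribution over worlds onto QUD answers."""
    answers = defaultdict(Fraction)
    for world, p in dist.items():
        answers[qud(world)] += p
    return dict(answers)


@lru_cache(maxsize=None)  # plays the role of memoize
def literal_listener(utterance):
    # L0(w | u) ∝ P(w) [u is true in w]; hard filtering instead of a soft factor
    return normalize({w: Fraction(1) for w in WORLDS if MEANINGS[utterance](w)})


def speaker(world):
    # S1(u | w) ∝ P(u) L0(w | u), with a uniform P(u)
    return normalize({u: literal_listener(u).get(world, Fraction(0)) for u in MEANINGS})


def rsa_listener(utterance, qud):
    # L1(w | u) ∝ P(w) S1(u | w), then answer the QUD
    posterior = normalize({w: speaker(w).get(utterance, Fraction(0)) for w in WORLDS})
    return project(posterior, qud)


all_qud = MEANINGS["all blond people are nice"]
some_u = "some blond people are nice"
print(project(literal_listener(some_u), all_qud)[True])  # 5/7
print(rsa_listener(some_u, all_qud)[True])               # 45/77
```

Exact fractions make the effect easy to verify by hand: in worlds where every blond object is nice, the speaker splits probability between "all" (7/16) and "some" (9/16), while in worlds with a non-nice blond object "some" is the only true option, so hearing "some" shifts the listener toward those worlds.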
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| utterance | str | Natural language sentence (e.g., "all blond people are nice") |
| qud | callable | Question Under Discussion function mapping a world to an answer |
| -n / --num-samples | int | Number of samples for BestFirstSearch (default: 10) |
Output: a HashingMarginal distribution over world states or QUD answers.
Key sample sites:
- {name}_blond, {name}_nice, {name}_tall: object properties (Bernoulli)
- ix_{c}: CCG combination choices (Categorical)
- utterance: speaker utterance choice (Categorical)
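A world-model trace is a mapping from site names like these to sampled values. The plain-Python sketch below assumes hypothetical object names (obj0, obj1) and independent Bernoulli(0.5) priors for every property; the helpers site_names, enumerate_traces, trace_to_world and qud_marginal are illustrative, not part of the example. It enumerates all traces, rebuilds each world from its addresses, and pushes the prior through a QUD, the same world-to-answer projection that yields a distribution over QUD answers, here without conditioning on any utterance.

```python
import itertools
import math
from collections import defaultdict

PROPS = ("blond", "nice", "tall")


def site_names(obj_names):
    """Sample-site addresses following the {name}_{property} pattern."""
    return [f"{name}_{prop}" for name in obj_names for prop in PROPS]


def enumerate_traces(obj_names, p=0.5):
    """Yield (trace, log_prob) for every joint setting of Bernoulli(p) properties."""
    names = site_names(obj_names)
    for values in itertools.product([True, False], repeat=len(names)):
        log_prob = sum(math.log(p if v else 1 - p) for v in values)
        yield dict(zip(names, values)), log_prob


def trace_to_world(trace, obj_names):
    """Rebuild a world (one dict of properties per object) from a trace."""
    return [{prop: trace[f"{name}_{prop}"] for prop in PROPS} for name in obj_names]


def qud_marginal(obj_names, qud, p=0.5):
    """Distribution over QUD answers under the (unconditioned) world prior."""
    answers = defaultdict(float)
    for trace, log_prob in enumerate_traces(obj_names, p):
        answers[qud(trace_to_world(trace, obj_names))] += math.exp(log_prob)
    return dict(answers)


def any_nice(world):
    # same question as is_any_qud: is at least one object nice?
    return any(obj["nice"] for obj in world)


objs = ["obj0", "obj1"]
print(site_names(objs))
# ['obj0_blond', 'obj0_nice', 'obj0_tall', 'obj1_blond', 'obj1_nice', 'obj1_tall']
print(qud_marginal(objs, any_nice))  # approx. {True: 0.75, False: 0.25}
```

Unique, structured addresses are what let each property of each object be identified (and, in Pyro, observed or enumerated) independently.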
Usage Examples
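```python
# Assumes the definitions in examples/rsa/semantic_parsing.py (Marginal,
# literal_listener_raw, rsa_listener, ...) are in scope.
# Compute literal listener interpretation
mll = Marginal(literal_listener_raw, num_samples=100)
def is_any_qud(world):
    return any(map(lambda obj: obj.nice, world))  # is at least one object nice?

result = mll("all blond people are nice", is_any_qud)
print(result())  # calling a HashingMarginal draws one sample from it

# Compute RSA pragmatic listener interpretation
def is_all_qud(world):
    # Illustrative definition (assumption): are all blond objects nice?
    return all(obj.nice for obj in world if obj.blond)

rsa = Marginal(rsa_listener, num_samples=100)
result = rsa("some of the blond people are nice", is_all_qud)
print(result())
```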
Related Pages
- Pyro_ppl_Pyro_SearchInference - Search inference utilities used by RSA models
- Pyro_ppl_Pyro_RSA_Generics - RSA model for generic statements
- Pyro_ppl_Pyro_RSA_Hyperbole - RSA model for hyperbole
- Pyro_ppl_Pyro_RSA_Schelling - Schelling coordination game