

Implementation:Pyro ppl Pyro RSA SemanticParsing



Property | Value
Implementation Type | Pattern Doc
Source File | examples/rsa/semantic_parsing.py
Module | examples.rsa
Pyro Features | pyro.sample, pyro.factor, BestFirstSearch, HashingMarginal, memoize (the last three from the example's search_inference helper module), nested inference
Reference | Adapted from http://dippl.org/examples/zSemanticPragmaticMashup.html

Overview

This file implements a Rational Speech Act (RSA) model that combines pragmatic reasoning with CCG-based compositional semantics. It demonstrates how Pyro can be used for cognitive science models of language understanding, where speakers and listeners reason about each other recursively.

The implementation includes:

  • Lexical semantics: Word meanings are represented as classes (BlondMeaning, NiceMeaning, etc.) with sem() (semantic function) and syn() (syntactic category) methods following CCG (Combinatory Categorial Grammar).
  • Compositional semantics: Words are combined by CCG function application, directed by their syntactic types (left or right application). When several combinations are possible, the choice is sampled at an ix_{c} site.
  • World model: A world consists of objects with probabilistic properties (blond, nice, tall), sampled via pyro.sample.
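  • Pragmatic reasoning: A literal listener infers worlds consistent with an utterance's literal meaning; a speaker chooses utterances in proportion to how likely the literal listener is to recover the intended world; and the pragmatic RSA listener infers the world by reasoning about which utterance the speaker would have chosen.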
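The composition step can be illustrated with a small, self-contained sketch. Everything in it is hypothetical: the toy lexicon (all, some, blond, are, nice), the tuple encoding of categories, and the helper names are assumptions, not the example's Meaning classes. Each word carries a syntactic category and a world-passing semantic function; a functor category consumes the neighbour on its right ("R") or left ("L") when that neighbour's category matches its argument. Where the original samples which applicable combination to perform (the ix_{c} sites), this sketch deterministically takes the leftmost one.

```python
# A category is an atomic string ("N", "NP", "S", "ADJ") or a functor tuple
# (direction, argument, result): "R" = result/argument (argument on the right),
# "L" = result\argument (argument on the left).


class Word:
    def __init__(self, syn, sem):
        self.syn = syn  # syntactic category
        self.sem = sem  # world -> denotation ("world passing" style)


def fwd(arg, res):
    return ("R", arg, res)


def bwd(arg, res):
    return ("L", arg, res)


# Hypothetical toy lexicon; objects are dicts with "blond" and "nice" keys.
LEXICON = {
    "all": Word(fwd("N", "NP"), lambda w: lambda restr: lambda scope: all(scope(o) for o in w if restr(o))),
    "some": Word(fwd("N", "NP"), lambda w: lambda restr: lambda scope: any(scope(o) for o in w if restr(o))),
    "blond": Word("N", lambda w: lambda o: o["blond"]),
    # "are nice" takes the quantified NP as its argument (backward application).
    "are": Word(fwd("ADJ", bwd("NP", "S")), lambda w: lambda adj: lambda quant: quant(adj)),
    "nice": Word("ADJ", lambda w: lambda o: o["nice"]),
}


def apply_in_world(f, a):
    # Evaluate functor and argument in the same world, then apply one to the other.
    return lambda w: f(w)(a(w))


def combine_once(items):
    """Apply the leftmost functor whose argument category matches its neighbour."""
    for i, item in enumerate(items):
        if isinstance(item.syn, tuple):
            direction, arg, res = item.syn
            j = i + 1 if direction == "R" else i - 1
            if 0 <= j < len(items) and items[j].syn == arg:
                merged = Word(res, apply_in_world(item.sem, items[j].sem))
                lo, hi = min(i, j), max(i, j)
                return items[:lo] + [merged] + items[hi + 1:]
    raise ValueError("no CCG application applies")


def meaning(utterance):
    # Words missing from the toy lexicon ("of", "the", "people") are skipped.
    items = [LEXICON[w] for w in utterance.split() if w in LEXICON]
    while len(items) > 1:
        items = combine_once(items)
    if not items or items[0].syn != "S":
        raise ValueError("utterance did not reduce to a sentence")
    return items[0].sem


world = ({"blond": True, "nice": True}, {"blond": True, "nice": False}, {"blond": False, "nice": False})
print(meaning("all of the blond people are nice")(world))  # False: one blond object is not nice
print(meaning("some blond are nice")(world))  # True
```

Treating the quantified phrase as a function from predicates to truth values lets "are nice" combine with it by left application, so both directions of application are exercised in one sentence.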

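Inference uses BestFirstSearch, which enumerates program executions in order of probability; HashingMarginal then merges the weighted executions into a marginal distribution over their (hashable) return values. Inference is nested: the speaker runs inference over the literal listener, and the RSA listener runs inference over the speaker, with memoize ensuring that repeated calls with the same arguments reuse an already computed distribution.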
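The idea behind BestFirstSearch and HashingMarginal can be sketched in plain Python. This is not the search_inference code: it assumes a replay-based design in which a program receives hypothetical choose and factor callbacks, partial executions wait in a priority queue keyed by the probability of their choices so far, and finished executions are merged by return value. The max_executions parameter is a hypothetical stand-in for num_samples. Unlike a tuned search, this sketch ranks prefixes by choice probabilities only and applies factors once an execution finishes.

```python
import heapq
import math
from collections import defaultdict


class NeedChoice(Exception):
    """Raised when a replayed program reaches a choice not fixed by the prefix."""

    def __init__(self, options):
        super().__init__()
        self.options = options  # list of (value, probability) pairs


def run_with_prefix(program, prefix):
    """Replay `program`, taking choices from `prefix`; return (value, log factor)."""
    pos = 0
    log_factor = 0.0

    def choose(options):
        nonlocal pos
        if pos < len(prefix):
            value, _ = prefix[pos]
            pos += 1
            return value
        raise NeedChoice(options)

    def factor(log_weight):
        nonlocal log_factor
        log_factor += log_weight

    value = program(choose, factor)
    return value, log_factor


def best_first_search(program, max_executions=100):
    """Return complete executions as (value, log weight), most probable prefix first."""
    heap = [(0.0, 0, [])]  # (negative log prob of prefix, tie-breaker, prefix)
    counter = 1
    executions = []
    while heap and len(executions) < max_executions:
        neg_logp, _, prefix = heapq.heappop(heap)
        try:
            value, log_factor = run_with_prefix(program, prefix)
        except NeedChoice as need:
            # Extend the prefix with every possible value of the next choice.
            for v, p in need.options:
                if p > 0:
                    heapq.heappush(heap, (neg_logp - math.log(p), counter, prefix + [(v, p)]))
                    counter += 1
            continue
        executions.append((value, -neg_logp + log_factor))
    return executions


def hashing_marginal(executions):
    """Merge executions with equal (hashable) return values and normalize."""
    totals = defaultdict(float)
    for value, log_w in executions:
        totals[value] += math.exp(log_w)
    z = sum(totals.values())
    return {value: w / z for value, w in totals.items()}


def two_flips(choose, factor):
    a = choose([(True, 0.5), (False, 0.5)])
    b = choose([(True, 0.5), (False, 0.5)])
    if not (a or b):
        factor(float("-inf"))  # hard constraint: at least one heads
    return int(a) + int(b)


marginal = hashing_marginal(best_first_search(two_flips))
print(marginal)
```

Here the marginal puts probability 2/3 on one heads, 1/3 on two heads and 0 on none. Because results are keyed by return value, programs must return hashable values (an int here; a tuple of immutable records for worlds).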

Code Reference

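import pyro

# Marginal, meaning, world_prior, heuristic and utterance_prior are defined
# elsewhere in the example (not shown here).

# Literal listener: sample a world consistent with the utterance's literal meaning.
@Marginal(num_samples=100)
def literal_listener(utterance):
    m = meaning(utterance)        # CCG-composed function from world to truth value
    world = world_prior(2, m)     # two objects with Bernoulli properties
    # Scaled heuristic score: acts as a near-hard constraint that m(world) holds.
    pyro.factor("world_constraint", heuristic(m(world)) * 1000)
    return world

# Speaker: choose an utterance, weighted by how likely the literal listener
# is to recover the given world from it.
@Marginal(num_samples=100)
def speaker(world):
    utterance = utterance_prior()
    L = literal_listener(utterance)  # nested inference: a distribution over worlds
    pyro.sample("speaker_constraint", L, obs=world)
    return utterance

# RSA listener: infer the world by reasoning about the speaker's choice,
# then project it onto the Question Under Discussion.
def rsa_listener(utterance, qud):
    world = world_prior(2, meaning(utterance))
    S = speaker(world)               # nested inference: a distribution over utterances
    pyro.sample("listener_constraint", S, obs=utterance)
    return qud(world)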
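Because the world space here is tiny, the same three-level recursion can be checked by exact enumeration instead of search. The sketch below is a simplified, hypothetical reimplementation, not the Pyro code: it drops the tall property, assumes an utterance set of some/all/none with meanings written directly as Python predicates rather than built by CCG, uses uniform priors, and treats the world_constraint factor as an exact hard constraint (assuming the heuristic scaled by 1000 makes false meanings effectively impossible). Fractions keep probabilities exact, and lru_cache plays a role similar to memoize.

```python
from collections import defaultdict
from fractions import Fraction
from functools import lru_cache
from itertools import product

# Hypothetical simplification: each object is a (blond, nice) pair; "tall" is
# omitted because no utterance mentions it. Two objects give 4 * 4 = 16 worlds.
WORLDS = list(product(product([False, True], repeat=2), repeat=2))

# Assumed utterance set, with meanings written directly instead of via CCG.
MEANINGS = {
    "some": lambda w: any(nice for blond, nice in w if blond),
    "all": lambda w: all(nice for blond, nice in w if blond),
    "none": lambda w: not any(nice for blond, nice in w if blond),
}


def normalize(weights):
    z = sum(weights.values())
    return {k: v / z for k, v in weights.items() if v}


@lru_cache(maxsize=None)
def literal_listener(utterance):
    # L0(w | u) is proportional to P(w) * [[u]](w); the uniform P(w) cancels.
    return normalize({w: Fraction(int(MEANINGS[utterance](w))) for w in WORLDS})


@lru_cache(maxsize=None)
def speaker(world):
    # S1(u | w) is proportional to P(u) * L0(w | u); the uniform P(u) cancels.
    return normalize({u: literal_listener(u).get(world, Fraction(0)) for u in MEANINGS})


def rsa_listener(utterance, qud):
    # L1(answer | u): sum P(w) * S1(u | w) over worlds with qud(w) == answer.
    weights = defaultdict(Fraction)
    for w in WORLDS:
        weights[qud(w)] += speaker(w).get(utterance, Fraction(0))
    return normalize(weights)


def project(dist, qud):
    # Push a distribution over worlds through a QUD.
    out = defaultdict(Fraction)
    for w, p in dist.items():
        out[qud(w)] += p
    return dict(out)


def is_all_qud(world):
    return all(nice for blond, nice in world if blond)


print(project(literal_listener("some"), is_all_qud))
print(rsa_listener("some", is_all_qud))
```

In this toy setup, after hearing "some" the literal listener gives probability 5/7 to "all blond people are nice", while the pragmatic listener lowers it to 45/77: a speaker in such a world would often have said "all" instead. This is the scalar implicature the RSA listener is meant to capture.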

I/O Contract

Parameter | Type | Description
utterance | str | Natural language sentence (e.g., "all blond people are nice")
qud | callable | Question Under Discussion function mapping a world to an answer
-n / --num-samples | int | Command-line option: number of samples for BestFirstSearch (default: 10)

Output:

  • HashingMarginal distribution over world states or QUD answers

Key sample sites:

  • {name}_blond, {name}_nice, {name}_tall: Object properties (Bernoulli)
  • ix_{c}: CCG combination choices (Categorical)
  • utterance: Speaker utterance choice (Categorical)
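The naming scheme for the property sites can be illustrated without Pyro. In this hypothetical sketch, a plain dict records each draw under a {name}_{property} key in place of a Bernoulli sample site, the object names obj_0, obj_1 are assumed, and objects are namedtuples so that a whole world (a tuple of them) is hashable and can key a marginal distribution. The Categorical ix_{c} and utterance sites are not modelled here.

```python
import random
from collections import namedtuple

# Hypothetical object record; namedtuples are hashable, so worlds can be dict keys.
Obj = namedtuple("Obj", ["name", "blond", "nice", "tall"])


def sample_world(num_objs, rng):
    """Sample a world, recording each property draw under a {name}_{property} site name."""
    trace = {}

    def flip(site, p=0.5):
        # Stand-in for a Bernoulli sample site: draw, record, and return the value.
        trace[site] = rng.random() < p
        return trace[site]

    world = tuple(
        Obj(name, flip(f"{name}_blond"), flip(f"{name}_nice"), flip(f"{name}_tall"))
        for name in (f"obj_{i}" for i in range(num_objs))
    )
    return world, trace


world, trace = sample_world(2, random.Random(0))
print(sorted(trace))
# ['obj_0_blond', 'obj_0_nice', 'obj_0_tall', 'obj_1_blond', 'obj_1_nice', 'obj_1_tall']
```

With two objects there are six binary property sites, so the world prior ranges over 2^6 = 64 configurations.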

Usage Examples

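# Literal listener with a QUD. literal_listener_raw (defined in the example,
# not shown above) takes (utterance, qud) and returns qud(world).
mll = Marginal(literal_listener_raw, num_samples=100)

def is_any_qud(world):
    # True if at least one object in the world is nice
    return any(obj.nice for obj in world)

result = mll("all blond people are nice", is_any_qud)  # HashingMarginal over QUD answers
print(result())  # calling the distribution draws one sample

def is_all_qud(world):
    # True if every blond object is also nice
    return all(obj.nice for obj in world if obj.blond)

# RSA pragmatic listener interpretation
rsa = Marginal(rsa_listener, num_samples=100)
result = rsa("some of the blond people are nice", is_all_qud)
print(result())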
