Principle:Pyro ppl Pyro Minimal PPL
| Knowledge Sources | |
|---|---|
| Domains | Probabilistic Programming, Language Design, Education |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
A minimal probabilistic programming language distills the core abstractions of a PPL -- sample, observe, effect handlers, and inference -- into a small, self-contained implementation that serves as both educational material and a reference specification.
Description
Full probabilistic programming systems like Pyro involve thousands of lines of code handling edge cases, performance optimizations, and feature interactions. A minimal PPL strips away all of this complexity to reveal the essential design patterns.
The core insight is that a PPL needs only a few primitives:
- sample(name, distribution): Draw a value from a probability distribution and record it. This is the fundamental stochastic primitive.
- param(name, initial_value): Declare a learnable parameter. During inference, these parameters are optimized.
- plate(name, size): Declare a conditional independence context for vectorized computation.
- Effect handler stack: A global stack of context managers that intercept sample and param calls. Each handler can modify the default behavior (e.g., condition on observed values, record a trace, substitute values).
- Trace handler: Records every sample site together with its value, distribution, and log-probability.
- SVI (Stochastic Variational Inference): The simplest inference algorithm, which maximizes the ELBO (evidence lower bound) using stochastic gradients.
A minimal PPL demonstrates that the full power of probabilistic programming emerges from the interaction of these simple components. The effect handler pattern is the key architectural insight: by changing what happens when sample is called, the same model function can be used for prior sampling, posterior inference, conditioning, and prediction.
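As a minimal sketch of this idea, the snippet below runs one model function twice: once with no handler active (prior sampling) and once under a conditioning handler. The `Normal` class, the module-level `HANDLER_STACK`, and the simplified `sample` are illustrative assumptions written for this example only; they are not Pyro's implementation.

```python
import random

HANDLER_STACK = []  # hypothetical global stack; the innermost handler is last


class Normal:
    """Toy distribution for this sketch (an assumption, not Pyro's class)."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self):
        return random.gauss(self.loc, self.scale)


class Condition:
    """Context manager that fixes the value of the named sample sites."""

    def __init__(self, data):
        self.data = data

    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *exc):
        HANDLER_STACK.pop()

    def process_message(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]


def sample(name, dist):
    msg = {"type": "sample", "name": name, "fn": dist, "value": None}
    for handler in reversed(HANDLER_STACK):
        handler.process_message(msg)
    if msg["value"] is None:  # no handler supplied a value, so draw one
        msg["value"] = msg["fn"].sample()
    return msg["value"]


def model():
    z = sample("z", Normal(0.0, 1.0))
    x = sample("x", Normal(z, 0.1))
    return z, x


prior_draw = model()              # both sites are drawn at random
with Condition({"z": 2.0}):
    z_fixed, x_given_z = model()  # "z" is fixed to 2.0, "x" is still random
```

The model body is unchanged between the two runs; only the set of active handlers differs.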
Usage
Use the minimal PPL when:
- Learning how probabilistic programming languages work internally.
- Teaching the core concepts of PPLs without the complexity of a full system.
- Prototyping new inference algorithms or language features.
- Understanding Pyro's architecture by studying its simplified core.
- Building a custom PPL for a specialized domain.
Theoretical Basis
Core primitives:
# The entire PPL is built from these primitives:

# 1. sample: stochastic choice
def sample(name, dist, obs=None):
    # Apply all active effect handlers
    value = apply_stack(msg={
        "type": "sample",
        "name": name,
        "fn": dist,
        "value": obs,
    })
    return value

# 2. param: learnable parameter
def param(name, initial_value):
    value = apply_stack(msg={
        "type": "param",
        "name": name,
        "value": initial_value,
    })
    return value

# 3. apply_stack: the effect handler mechanism
HANDLER_STACK = []

def apply_stack(msg):
    # Walk handlers from top to bottom (process)
    for handler in reversed(HANDLER_STACK):
        handler.process_message(msg)
        if msg.get("stop"):
            break
    # Apply default behavior
    if msg["value"] is None:
        msg["value"] = msg["fn"].sample()
    # Walk handlers from bottom to top (postprocess)
    for handler in HANDLER_STACK:
        handler.postprocess_message(msg)
    return msg["value"]
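To make the traversal order concrete, the following self-contained sketch pushes two logging handlers and records the order in which their hooks fire. The `Logger` and `Constant` classes and the `events` list are hypothetical helpers introduced only for this illustration; `apply_stack` mirrors the logic shown above.

```python
HANDLER_STACK = []
events = []  # (hook, label) pairs in the order they fire


class Constant:
    """Stand-in 'distribution' whose sample() always returns the same value."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value


class Logger:
    """Hypothetical handler that only logs when its hooks are called."""

    def __init__(self, label):
        self.label = label

    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *exc):
        HANDLER_STACK.pop()

    def process_message(self, msg):
        events.append(("process", self.label))

    def postprocess_message(self, msg):
        events.append(("postprocess", self.label))


def apply_stack(msg):
    for handler in reversed(HANDLER_STACK):  # innermost handler first
        handler.process_message(msg)
        if msg.get("stop"):
            break
    if msg["value"] is None:
        msg["value"] = msg["fn"].sample()
    for handler in HANDLER_STACK:            # outermost handler first
        handler.postprocess_message(msg)
    return msg["value"]


with Logger("outer"):
    with Logger("inner"):
        value = apply_stack(
            {"type": "sample", "name": "x", "fn": Constant(1.5), "value": None}
        )

# events is now:
# [("process", "inner"), ("process", "outer"),
#  ("postprocess", "outer"), ("postprocess", "inner")]
```

The most recently entered handler sees the message first on the way in and last on the way out, which is why an inner handler can set a value that an outer Trace then records.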
Effect handlers:
# Base handler (context manager that pushes/pops from stack):
class Handler:
    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *args):
        HANDLER_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass

# Trace handler: records all sites
class Trace(Handler):
    def __init__(self):
        self.trace = {}

    def postprocess_message(self, msg):
        # Compute log_prob before copying so the recorded site includes it.
        # Param sites carry no distribution ("fn"), so only sample sites are scored.
        if msg["type"] == "sample":
            msg["log_prob"] = msg["fn"].log_prob(msg["value"])
        self.trace[msg["name"]] = msg.copy()

# Condition handler: fix observed values
class Condition(Handler):
    def __init__(self, data):
        self.data = data

    def process_message(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
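The ELBO computation in the next block relies on a Replay handler that this section does not define. One plausible sketch, modeled on Condition, is shown here: it copies values from a previously recorded trace into the matching sample sites. The `Replay` body, the toy `Normal` class, and the example `model` and `guide` are assumptions for illustration; the stack, `Handler`, and `Trace` are repeated so the snippet runs on its own.

```python
import math
import random

HANDLER_STACK = []


class Normal:
    """Toy univariate normal standing in for a real distribution class."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self):
        return random.gauss(self.loc, self.scale)

    def log_prob(self, value):
        z = (value - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


class Handler:
    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *exc):
        HANDLER_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass


class Trace(Handler):
    def __init__(self):
        self.trace = {}

    def postprocess_message(self, msg):
        msg["log_prob"] = msg["fn"].log_prob(msg["value"])
        self.trace[msg["name"]] = msg.copy()


class Replay(Handler):
    """Hypothetical handler: reuse the values recorded in an earlier trace."""

    def __init__(self, trace):
        self.trace = trace

    def process_message(self, msg):
        if msg["name"] in self.trace:
            msg["value"] = self.trace[msg["name"]]["value"]


def sample(name, dist, obs=None):
    msg = {"type": "sample", "name": name, "fn": dist, "value": obs}
    for handler in reversed(HANDLER_STACK):
        handler.process_message(msg)
    if msg["value"] is None:
        msg["value"] = msg["fn"].sample()
    for handler in HANDLER_STACK:
        handler.postprocess_message(msg)
    return msg["value"]


def guide():
    sample("z", Normal(1.0, 0.5))


def model():
    z = sample("z", Normal(0.0, 1.0))
    sample("x", Normal(z, 0.1), obs=1.2)


with Trace() as guide_trace:
    guide()
# Trace is entered first (outer) and Replay second (inner), so Replay sets
# the value of "z" before Trace records it.
with Trace() as model_trace, Replay(guide_trace.trace):
    model()
```

After this runs, the model trace holds the guide's value for "z" scored under the model's prior, plus the observed site "x", which is what the ELBO estimate needs.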
Minimal SVI:
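# ELBO estimation (single-sample Monte Carlo estimate):
# Replay is not defined in this section; it is assumed to be a handler that
# sets each sample site's value from the trace it is given.
def elbo(model, guide, *args):
    # Trace the guide
    with Trace() as guide_trace:
        guide(*args)
    # Replay model with guide values, trace model
    with Trace() as model_trace:
        with Replay(guide_trace.trace):
            model(*args)
    # Compute ELBO: model log-probabilities minus guide log-probabilities.
    # Only sample sites have a log_prob; param sites are skipped.
    elbo_value = 0
    for name, site in model_trace.trace.items():
        if site["type"] == "sample":
            elbo_value = elbo_value + site["log_prob"]
    for name, site in guide_trace.trace.items():
        if site["type"] == "sample":
            elbo_value = elbo_value - site["log_prob"]
    return -elbo_value  # negate: minimizing the loss maximizes the ELBO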
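For reference, the quantity this function estimates can be written out as follows. The symbols are notation chosen here rather than taken from the code: $x$ for observed values, $z$ for latent sample sites, $\theta$ for model parameters, $\phi$ for guide parameters, and $v_s$ for the value at site $s$.

```latex
\[
\mathrm{ELBO}(\theta, \phi)
  = \mathbb{E}_{z \sim q_\phi(z)}
    \big[ \log p_\theta(x, z) - \log q_\phi(z) \big]
  \le \log p_\theta(x)
\]

\[
\widehat{\mathrm{ELBO}}
  = \sum_{s \,\in\, \text{model sites}} \log p_\theta(v_s \mid \cdot)
  \;-\; \sum_{s \,\in\, \text{guide sites}} \log q_\phi(z_s \mid \cdot),
  \qquad z \sim q_\phi,
  \qquad \text{loss} = -\widehat{\mathrm{ELBO}}
\]
```

The two sums correspond to the two loops over the model trace and the guide trace; the first sum includes both the replayed latent sites and the observed sites, since together they make up $\log p_\theta(x, z)$.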
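# SVI loop (sketch; assumes a PyTorch-style optimizer over the parameters):
# for step in range(num_steps):
#     optimizer.zero_grad()  # clear gradients left over from the previous step
#     loss = elbo(model, guide, data)
#     loss.backward()
#     optimizer.step()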