

Principle:Pyro ppl Pyro Minimal PPL

From Leeroopedia


Domains: Probabilistic Programming, Language Design, Education
Last Updated: 2026-02-09 09:00 GMT

Overview

A minimal probabilistic programming language distills the core abstractions of a PPL (sampling, observation, effect handlers, and inference) into a small, self-contained implementation that serves as both educational material and a reference specification.

Description

Full probabilistic programming systems like Pyro involve thousands of lines of code handling edge cases, performance optimizations, and feature interactions. A minimal PPL strips away all of this complexity to reveal the essential design patterns.

The core insight is that a PPL needs only a few primitives:

sample(name, distribution): Draw a value from a probability distribution and record it or, when an observed value is supplied, use that value instead (conditioning). This is the fundamental stochastic primitive.

param(name, initial_value): Declare a learnable parameter. During inference, these parameters are optimized.

plate(name, size): Declare a conditional independence context for vectorized computation.

Effect handler stack: A global stack of context managers that intercept sample and param calls. Each handler can modify the behavior (e.g., condition on observed values, record a trace, substitute values).

Trace handler: Records all sample sites and their values, distributions, and log-probabilities.

SVI (Stochastic Variational Inference): A simple variational inference algorithm that maximizes the evidence lower bound (ELBO) using stochastic gradient estimates.
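The plate primitive listed above can be pictured as one more effect handler. The sketch below is a minimal pure-Python illustration that assumes scalar values and plain lists instead of tensors; `Normal`, `IID`, and `Plate` are hypothetical helpers, not Pyro's API. Pyro's own plate instead expands a distribution's batch shape along a tensor dimension, but the idea is the same: every sample site inside the context becomes `size` conditionally independent draws whose log-densities add up.

```python
import math
import random

HANDLER_STACK = []


class Normal:
    """Tiny scalar Normal (hypothetical stand-in for a distribution library)."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self):
        return random.gauss(self.loc, self.scale)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


class IID:
    """Hypothetical helper: `size` independent copies of a scalar distribution."""
    def __init__(self, base, size):
        self.base, self.size = base, size

    def sample(self):
        return [self.base.sample() for _ in range(self.size)]

    def log_prob(self, xs):
        # Conditional independence: the joint log-density is a sum.
        return sum(self.base.log_prob(x) for x in xs)


class Handler:
    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *exc):
        HANDLER_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass


class Plate(Handler):
    """Sketch of plate(name, size): batches every sample site inside it."""
    def __init__(self, name, size):
        self.name, self.size = name, size

    def process_message(self, msg):
        if msg["type"] == "sample":
            msg["fn"] = IID(msg["fn"], self.size)


class Trace(Handler):
    def __init__(self):
        self.trace = {}

    def postprocess_message(self, msg):
        msg["log_prob"] = msg["fn"].log_prob(msg["value"])
        self.trace[msg["name"]] = dict(msg)


def sample(name, fn, obs=None):
    msg = {"type": "sample", "name": name, "fn": fn, "value": obs}
    for handler in reversed(HANDLER_STACK):  # process: innermost first
        handler.process_message(msg)
    if msg["value"] is None:
        msg["value"] = msg["fn"].sample()
    for handler in HANDLER_STACK:  # postprocess: outermost first
        handler.postprocess_message(msg)
    return msg["value"]


# Three observations share one conditionally independent site.
with Trace() as tr:
    with Plate("data", 3):
        sample("x", Normal(0.0, 1.0), obs=[0.0, 0.0, 0.0])

# An unobserved site inside a plate of size 4 yields 4 draws.
with Plate("latent", 4):
    zs = sample("z", Normal(0.0, 1.0))
```

Because the plate acts in `process_message`, before any value is drawn, the outer `Trace` scores the already-batched distribution and records one summed log-density for the whole site.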

A minimal PPL demonstrates that the full power of probabilistic programming emerges from the interaction of these simple components. The effect handler pattern is the key architectural insight: by changing what happens when sample is called, the same model function can be used for prior sampling, posterior inference, conditioning, and prediction.
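As a minimal illustration of that reuse, the pure-Python sketch below runs one model function three ways: with no handlers (prior sampling), under a condition handler, and under a trace handler that scores pinned values. `Normal` is a hypothetical scalar stand-in for a distribution library, and the handlers are simplified versions of the Condition and Trace handlers described on this page.

```python
import math
import random

HANDLER_STACK = []


class Normal:
    """Hypothetical scalar Normal with sample() and log_prob()."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self):
        return random.gauss(self.loc, self.scale)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


class Handler:
    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *exc):
        HANDLER_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass


class Condition(Handler):
    def __init__(self, data):
        self.data = data

    def process_message(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]


class Trace(Handler):
    def __init__(self):
        self.trace = {}

    def postprocess_message(self, msg):
        msg["log_prob"] = msg["fn"].log_prob(msg["value"])
        self.trace[msg["name"]] = dict(msg)


def sample(name, fn, obs=None):
    msg = {"type": "sample", "name": name, "fn": fn, "value": obs}
    for handler in reversed(HANDLER_STACK):
        handler.process_message(msg)
    if msg["value"] is None:
        msg["value"] = msg["fn"].sample()
    for handler in HANDLER_STACK:
        handler.postprocess_message(msg)
    return msg["value"]


def model():
    """z ~ Normal(0, 1); x ~ Normal(z, 1)."""
    z = sample("z", Normal(0.0, 1.0))
    x = sample("x", Normal(z, 1.0))
    return z, x


random.seed(0)

# 1. Prior sampling: no handlers, so both sites are drawn at random.
prior_z, prior_x = model()

# 2. Conditioning: the same function, with x pinned to an observed value.
with Condition({"x": 1.5}):
    cond_z, cond_x = model()

# 3. Scoring: pin every site and trace it to get the log joint density.
with Trace() as tr:
    with Condition({"z": 0.0, "x": 1.5}):
        model()
log_joint = sum(site["log_prob"] for site in tr.trace.values())
```

The model itself never changes; only the handlers around the call decide whether a site is drawn, fixed, or scored.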

Usage

Use the minimal PPL when:

  • Learning how probabilistic programming languages work internally.
  • Teaching the core concepts of PPLs without the complexity of a full system.
  • Prototyping new inference algorithms or language features.
  • Understanding Pyro's architecture by studying its simplified core.
  • Building a custom PPL for a specialized domain.

Theoretical Basis

Core primitives:

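# The entire PPL is built from these primitives.
# Assumption: `dist` is a torch.distributions.Distribution (PyTorch).
import torch

# 1. sample: stochastic choice (an observed site if `obs` is given)
def sample(name, dist, obs=None):
    # Apply all active effect handlers
    value = apply_stack(msg={
        "type": "sample",
        "name": name,
        "fn": dist,
        "value": obs,
        "is_observed": obs is not None,
    })
    return value

# 2. param: learnable parameter, kept in a global store so it persists
#    across model/guide calls and the optimizer can update it
PARAM_STORE = {}

def param(name, initial_value):
    def fn():
        # Create a leaf tensor on first use; reuse it on later calls
        if name not in PARAM_STORE:
            PARAM_STORE[name] = torch.as_tensor(
                initial_value, dtype=torch.float32).detach().clone().requires_grad_()
        return PARAM_STORE[name]
    value = apply_stack(msg={
        "type": "param",
        "name": name,
        "fn": fn,
        "value": None,
    })
    return value

# 3. apply_stack: the effect handler mechanism
HANDLER_STACK = []

def apply_stack(msg):
    # Walk handlers from top (innermost) to bottom (process)
    pointer = 0
    for pointer, handler in enumerate(reversed(HANDLER_STACK)):
        handler.process_message(msg)
        # "stop" hides the site from handlers further down the stack
        if msg.get("stop"):
            break
    # Apply default behavior if no handler supplied a value
    if msg["value"] is None:
        if msg["type"] == "sample":
            d = msg["fn"]
            # rsample() keeps the draw differentiable when supported
            msg["value"] = d.rsample() if d.has_rsample else d.sample()
        else:
            msg["value"] = msg["fn"]()
    # Walk handlers from bottom to top (postprocess), skipping any
    # handlers that a "stop" prevented from processing the message
    for handler in HANDLER_STACK[-pointer - 1:]:
        handler.postprocess_message(msg)
    return msg["value"]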
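To see the two passes of apply_stack in action, the pure-Python sketch below re-implements it and logs every handler call. `Constant`, `Log`, and `Block` are hypothetical helpers for illustration only. In this sketch, a handler that sets `stop` hides the site from all outer handlers in both passes; that detail is an assumption modeled on Pyro's minipyro.

```python
HANDLER_STACK = []
EVENTS = []


class Constant:
    """Hypothetical stand-in distribution that always returns 42."""
    def sample(self):
        return 42


def apply_stack(msg):
    pointer = 0
    for pointer, handler in enumerate(reversed(HANDLER_STACK)):
        handler.process_message(msg)
        if msg.get("stop"):
            break
    if msg["value"] is None:
        msg["value"] = msg["fn"].sample()
    # Postprocess only the handlers that processed the message, bottom to top.
    for handler in HANDLER_STACK[-pointer - 1:]:
        handler.postprocess_message(msg)
    return msg["value"]


class Handler:
    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *exc):
        HANDLER_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass


class Log(Handler):
    """Records each time it sees a message, tagged with its label."""
    def __init__(self, label):
        self.label = label

    def process_message(self, msg):
        EVENTS.append(("process", self.label))

    def postprocess_message(self, msg):
        EVENTS.append(("postprocess", self.label))


class Block(Log):
    """Hypothetical: logs, then hides the site from all outer handlers."""
    def process_message(self, msg):
        super().process_message(msg)
        msg["stop"] = True


def site():
    return apply_stack({"type": "sample", "name": "s", "fn": Constant(), "value": None})


with Log("outer"):
    with Log("inner"):
        site()
nested_events = list(EVENTS)

EVENTS.clear()
with Log("outer"):
    with Block("blocker"):
        blocked_value = site()
blocked_events = list(EVENTS)
```

The innermost handler processes first and postprocesses last, so nested handlers wrap each other like function calls. With `Block` in place, the outer `Log` never sees the site at all.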

Effect handlers:

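# Base handler (context manager that pushes/pops from stack):
class Handler:
    def __enter__(self):
        HANDLER_STACK.append(self)
        return self
    def __exit__(self, *args):
        HANDLER_STACK.pop()
    def process_message(self, msg):
        pass
    def postprocess_message(self, msg):
        pass

# Trace handler: records all sites; sample sites also get a log-probability
class Trace(Handler):
    def __init__(self):
        self.trace = {}
    def postprocess_message(self, msg):
        if msg["type"] == "sample":
            # Compute log_prob before copying so the recorded site includes it
            msg["log_prob"] = msg["fn"].log_prob(msg["value"])
        self.trace[msg["name"]] = msg.copy()

# Condition handler: fix observed values
class Condition(Handler):
    def __init__(self, data):
        self.data = data
    def process_message(self, msg):
        if msg["type"] == "sample" and msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
            msg["is_observed"] = True

# Replay handler: reuse sample values recorded in another trace
# (the ELBO uses it to run the model on the guide's draws)
class Replay(Handler):
    def __init__(self, trace):
        self.trace = trace
    def process_message(self, msg):
        if msg["type"] == "sample" and msg["name"] in self.trace:
            msg["value"] = self.trace[msg["name"]]["value"]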
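The ELBO computation relies on a `Replay` handler, which substitutes the values recorded in one trace into another run. The pure-Python sketch below shows how Replay and Trace compose: the guide's draw for `z` is replayed into the model, and each trace scores that same value under its own distribution. `Normal` and the `Replay` class here are hypothetical simplified versions, and the guide parameters (0.5, 0.3) are arbitrary.

```python
import math
import random

HANDLER_STACK = []


class Normal:
    """Hypothetical scalar Normal with sample() and log_prob()."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self):
        return random.gauss(self.loc, self.scale)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


class Handler:
    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *exc):
        HANDLER_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass


class Trace(Handler):
    def __init__(self):
        self.trace = {}

    def postprocess_message(self, msg):
        msg["log_prob"] = msg["fn"].log_prob(msg["value"])
        self.trace[msg["name"]] = dict(msg)


class Replay(Handler):
    """Substitute values recorded in another trace (sites not in it are untouched)."""
    def __init__(self, trace):
        self.trace = trace

    def process_message(self, msg):
        if msg["name"] in self.trace:
            msg["value"] = self.trace[msg["name"]]["value"]


def sample(name, fn, obs=None):
    msg = {"type": "sample", "name": name, "fn": fn, "value": obs}
    for handler in reversed(HANDLER_STACK):
        handler.process_message(msg)
    if msg["value"] is None:
        msg["value"] = msg["fn"].sample()
    for handler in HANDLER_STACK:
        handler.postprocess_message(msg)
    return msg["value"]


def guide():
    sample("z", Normal(0.5, 0.3))


def model(x):
    z = sample("z", Normal(0.0, 1.0))
    sample("x", Normal(z, 1.0), obs=x)


random.seed(1)
with Trace() as guide_tr:
    guide()
with Trace() as model_tr:
    with Replay(guide_tr.trace):
        model(1.5)

z = guide_tr.trace["z"]["value"]
```

The observed site `x` is absent from the guide trace, so Replay leaves its observed value alone; only the latent `z` is shared between the two runs.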

Minimal SVI:

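# ELBO estimation (single-sample Monte Carlo estimate):
def elbo(model, guide, *args):
    # Trace the guide: draws latent values z ~ q(z)
    with Trace() as guide_trace:
        guide(*args)
    # Replay model with guide values, trace model
    with Trace() as model_trace:
        with Replay(guide_trace.trace):
            model(*args)
    # ELBO ~= log p(x, z) - log q(z); param sites carry no log_prob
    elbo_value = 0.0
    for site in model_trace.trace.values():
        if site["type"] == "sample":
            elbo_value = elbo_value + site["log_prob"].sum()
    for site in guide_trace.trace.values():
        if site["type"] == "sample":
            elbo_value = elbo_value - site["log_prob"].sum()
    return -elbo_value  # negate: the optimizer minimizes this loss

# SVI loop (params are created lazily, so run once before building
# the optimizer over the tensors in PARAM_STORE):
# elbo(model, guide, data)
# optimizer = torch.optim.Adam(PARAM_STORE.values(), lr=0.01)
# for step in range(num_steps):
#     optimizer.zero_grad()
#     loss = elbo(model, guide, data)
#     loss.backward()
#     optimizer.step()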
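One way to sanity-check an ELBO implementation is sketched below in pure Python, under the same simplifying assumptions (a hypothetical scalar `Normal`, dictionaries as traces). For the conjugate model z ~ Normal(0, 1), x ~ Normal(z, 1), the exact posterior is Normal(x/2, sqrt(1/2)) and the marginal is x ~ Normal(0, sqrt(2)). With the exact posterior as the guide, log p(x, z) - log q(z) equals log p(x) for every draw, so each single-sample estimate matches the log evidence; any other guide gives a lower ELBO on average. The hypothetical `elbo_estimate` returns the ELBO itself, not the negated loss.

```python
import math
import random

HANDLER_STACK = []


class Normal:
    """Hypothetical scalar Normal with sample() and log_prob()."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self):
        return random.gauss(self.loc, self.scale)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


class Handler:
    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *exc):
        HANDLER_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass


class Trace(Handler):
    def __init__(self):
        self.trace = {}

    def postprocess_message(self, msg):
        msg["log_prob"] = msg["fn"].log_prob(msg["value"])
        self.trace[msg["name"]] = dict(msg)


class Replay(Handler):
    def __init__(self, trace):
        self.trace = trace

    def process_message(self, msg):
        if msg["name"] in self.trace:
            msg["value"] = self.trace[msg["name"]]["value"]


def sample(name, fn, obs=None):
    msg = {"type": "sample", "name": name, "fn": fn, "value": obs}
    for handler in reversed(HANDLER_STACK):
        handler.process_message(msg)
    if msg["value"] is None:
        msg["value"] = msg["fn"].sample()
    for handler in HANDLER_STACK:
        handler.postprocess_message(msg)
    return msg["value"]


def elbo_estimate(model, guide, *args):
    """Single-sample estimate of log p(x, z) - log q(z) with z ~ q."""
    with Trace() as guide_tr:
        guide(*args)
    with Trace() as model_tr:
        with Replay(guide_tr.trace):
            model(*args)
    return (sum(s["log_prob"] for s in model_tr.trace.values())
            - sum(s["log_prob"] for s in guide_tr.trace.values()))


def model(x):
    z = sample("z", Normal(0.0, 1.0))
    sample("x", Normal(z, 1.0), obs=x)


def exact_guide(x):
    # Exact posterior for this conjugate model: z | x ~ Normal(x/2, sqrt(1/2))
    sample("z", Normal(x / 2, math.sqrt(0.5)))


def prior_guide(x):
    # A deliberately poor guide: ignores the observation
    sample("z", Normal(0.0, 1.0))


x = 1.5
log_evidence = Normal(0.0, math.sqrt(2.0)).log_prob(x)  # marginal x ~ Normal(0, sqrt(2))

random.seed(0)
exact = [elbo_estimate(model, exact_guide, x) for _ in range(5)]
rough = sum(elbo_estimate(model, prior_guide, x) for _ in range(2000)) / 2000
```

The gap between `log_evidence` and the averaged `rough` estimate is the KL divergence from the poor guide to the true posterior, which is exactly what SVI drives down by adjusting guide parameters.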
