Implementation:Pyro ppl Pyro MiniPyro

Property      Value
Module        pyro.contrib.minipyro
Source        pyro/contrib/minipyro.py
Lines         405
Classes       Messenger, trace, replay, block, seed, PlateMessenger, Adam, SVI, JitTrace_ELBO
Functions     apply_stack, sample, param, plate, elbo, Trace_ELBO, get_param_store
Dependencies  torch, pyro.distributions

Overview

Mini Pyro is a minimal implementation of the Pyro probabilistic programming language in a single file. Its API (method signatures, etc.) matches the full Pyro implementation as closely as possible, and it is independent of the rest of the Pyro codebase except for pyro.distributions. It is designed as:

  • A pedagogical tool for understanding how Pyro works internally.
  • A reference implementation of the effect handler pattern used throughout Pyro.
  • A testing baseline for verifying that the full implementation matches expected behavior.

Mini Pyro tracks two kinds of global state:

  1. Effect handler stack (PYRO_STACK): Enables non-standard interpretations of sample() and param().
  2. Parameter store (PARAM_STORE): Maps parameter names to (unconstrained_value, constraint) tuples.
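
The parameter store layout described above can be illustrated with a minimal, stdlib-only sketch. It is not the minipyro source: the CONSTRAINTS registry, the string constraint names, and the use of plain floats instead of tensors are assumptions made here so the example runs without torch. It shows only the idea that the store keeps the unconstrained value while callers see the constrained one.

```python
import math

# Stand-in for PARAM_STORE: name -> (unconstrained_value, constraint).
PARAM_STORE = {}

# Hypothetical constraint registry for this sketch only:
# constraint name -> (to_unconstrained, to_constrained).
CONSTRAINTS = {
    "real": (lambda x: x, lambda u: u),
    "positive": (math.log, math.exp),
}


def param(name, init_value=None, constraint="real"):
    """Return the constrained value of a named parameter, creating it if needed."""
    if name not in PARAM_STORE:
        if init_value is None:
            raise ValueError(f"parameter {name!r} needs an init_value on first use")
        to_unconstrained, _ = CONSTRAINTS[constraint]
        # Store the unconstrained value together with its constraint.
        PARAM_STORE[name] = (to_unconstrained(init_value), constraint)
    unconstrained_value, stored_constraint = PARAM_STORE[name]
    _, to_constrained = CONSTRAINTS[stored_constraint]
    # Callers always receive the value mapped back to constrained space.
    return to_constrained(unconstrained_value)


scale = param("scale_q", 2.0, constraint="positive")
stored_value, stored_constraint = PARAM_STORE["scale_q"]
```

Keeping the unconstrained value in the store means an optimizer can take unrestricted gradient steps while the value seen by the model stays inside the constraint.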

Code Reference

Core Effect Handlers (Messengers)

  • Messenger: Base handler class. Pushes itself onto PYRO_STACK on enter, pops on exit. Provides process_message (pre-processing) and postprocess_message (post-processing) hooks.
  • trace: Records all sample/param site names, values, and distributions into an OrderedDict. Uses postprocess_message to record values after all effects have been applied.
  • replay(fn, guide_trace): Replaces sample values with those from a previous trace. This enables computing the model's log probability at guide samples.
  • block(fn, hide_fn): Sets msg["stop"] for sites where hide_fn(msg) returns True, so handlers further out on the stack never see those sites. Used to nest inference inside models.
  • seed(fn, rng_seed): Fixes the RNG state for reproducible execution.
  • PlateMessenger: Limited plate implementation supporting batch shape broadcasting along a specified dimension.
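
The handler classes listed above can be sketched in a few lines of plain Python. This is a simplified illustration under stated assumptions, not the minipyro source: HANDLER_STACK stands in for PYRO_STACK, the sample function here takes a zero-argument callable instead of a distribution, and the dispatch loop omits the stop flag. It shows how trace and replay compose so that the model is evaluated at the guide's sampled values.

```python
from collections import OrderedDict

HANDLER_STACK = []  # stand-in for PYRO_STACK


class Messenger:
    """Base handler: pushes itself onto the stack on enter, pops on exit."""

    def __init__(self, fn=None):
        self.fn = fn

    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *exc):
        assert HANDLER_STACK[-1] is self
        HANDLER_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass

    def __call__(self, *args, **kwargs):
        with self:
            return self.fn(*args, **kwargs)


class trace(Messenger):
    """Records each site after all other handlers have been applied."""

    def __enter__(self):
        super().__enter__()
        self.sites = OrderedDict()
        return self.sites

    def postprocess_message(self, msg):
        self.sites[msg["name"]] = msg.copy()

    def get_trace(self, *args, **kwargs):
        self(*args, **kwargs)
        return self.sites


class replay(Messenger):
    """Overrides site values with those recorded in an earlier trace."""

    def __init__(self, fn, guide_trace):
        super().__init__(fn)
        self.guide_trace = guide_trace

    def process_message(self, msg):
        if msg["name"] in self.guide_trace:
            msg["value"] = self.guide_trace[msg["name"]]["value"]


def sample(name, fn):
    """Simplified dispatch; fn is any zero-argument callable that draws a value."""
    msg = {"type": "sample", "name": name, "fn": fn, "value": None}
    for handler in reversed(HANDLER_STACK):
        handler.process_message(msg)
    if msg["value"] is None:
        msg["value"] = fn()
    for handler in HANDLER_STACK:
        handler.postprocess_message(msg)
    return msg["value"]


def guide():
    return sample("z", lambda: 1.5)


def model():
    return sample("z", lambda: -99.0)


guide_trace = trace(guide).get_trace()
# The model's own draw for "z" is never used: replay supplies the guide's value.
model_trace = trace(replay(model, guide_trace)).get_trace()
```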

Primitive Operations

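  • apply_stack(msg): Core dispatch mechanism. Walks the handler stack from the most recently entered handler outward, calling process_message on each and stopping early if a handler sets msg["stop"]. If the message still has no value (no handler supplied one and no obs was given), it calls msg["fn"] to draw one. It then calls postprocess_message on the handlers that processed the message.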
  • sample(name, fn, *args, obs=None): Effectful sampling. Without handlers, draws from distribution fn. With handlers, creates a message and passes it through apply_stack.
  • param(name, init_value, constraint, event_dim): Effectful parameter access. On first call, maps init_value to unconstrained space and stores the (unconstrained_value, constraint) pair in PARAM_STORE. On every call, it returns the stored value transformed back to constrained space.
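
The dispatch order and the effect of a stop flag can be made concrete with a small stdlib-only sketch. The Recorder and Block classes and the HANDLER_STACK name are hypothetical stand-ins written for this illustration. The apply_stack function below follows the behavior described above but is a simplified reconstruction, not a copy of the library code.

```python
HANDLER_STACK = []  # stand-in for PYRO_STACK; the last element is the innermost handler


class Recorder:
    """Toy handler that logs which hooks ran for which site."""

    def __init__(self, label, log):
        self.label, self.log = label, log

    def process_message(self, msg):
        self.log.append((self.label, "process", msg["name"]))

    def postprocess_message(self, msg):
        self.log.append((self.label, "postprocess", msg["name"]))


class Block:
    """Toy block: flags matching messages so outer handlers never see them."""

    def __init__(self, hide_fn):
        self.hide_fn = hide_fn

    def process_message(self, msg):
        if self.hide_fn(msg):
            msg["stop"] = True

    def postprocess_message(self, msg):
        pass


def apply_stack(msg):
    pointer = -1  # index of the last handler reached, counted from the innermost
    for pointer, handler in enumerate(reversed(HANDLER_STACK)):
        handler.process_message(msg)
        if msg.get("stop"):
            break
    if msg["value"] is None:
        msg["value"] = msg["fn"](*msg["args"])
    # Only handlers that processed the message get to postprocess it.
    for handler in HANDLER_STACK[len(HANDLER_STACK) - pointer - 1:]:
        handler.postprocess_message(msg)
    return msg


log = []
HANDLER_STACK.extend([
    Recorder("outer", log),
    Block(lambda msg: msg["name"] == "hidden"),
    Recorder("inner", log),
])
apply_stack({"name": "visible", "fn": lambda: 1.0, "args": (), "value": None})
apply_stack({"name": "hidden", "fn": lambda: 2.0, "args": (), "value": None})
HANDLER_STACK.clear()
```

In this run the "visible" site reaches both recorders, while the "hidden" site is processed and postprocessed by the inner recorder only, because the block between them sets the stop flag.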

Inference

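  • elbo(model, guide, *args): Computes a single-sample estimate of the Evidence Lower Bound by: (1) tracing the guide, (2) tracing the model replayed with guide samples, (3) summing the model's log-probability terms (latent and observed sites) and subtracting the guide's log q(z) terms. It returns the negative ELBO, so the result can be minimized as a loss.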
  • Trace_ELBO(**kwargs): Wrapper returning elbo for API compatibility with full Pyro.
  • SVI(model, guide, optim, loss): Stochastic Variational Inference. The step method traces parameters, computes loss, backpropagates, and updates.
  • Adam(optim_args): Dynamic optimizer that creates per-parameter torch.optim.Adam instances.
  • JitTrace_ELBO: JIT-compiled ELBO that defers tracing to first invocation and registers parameters for torch.jit.trace.
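
The three-step ELBO computation described above reduces to simple arithmetic over two traces. The sketch below works on hand-built traces with plain floats. The site layout (a "log_prob" callable per site), the function names, and the specific Normal parameters are assumptions made for illustration, not the library's data structures.

```python
import math


def normal_log_prob(value, loc, scale):
    """Log density of Normal(loc, scale) evaluated at value."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def negative_elbo(model_trace, guide_trace):
    """Single-sample loss: -(sum of model log-probs - sum of guide log-probs)."""
    elbo = 0.0
    for site in model_trace.values():
        if site["type"] == "sample":
            elbo += site["log_prob"](site["value"])
    for site in guide_trace.values():
        if site["type"] == "sample":
            elbo -= site["log_prob"](site["value"])
    return -elbo


z = 0.3  # value drawn in the guide and replayed into the model
x = 1.0  # observed data point

# Guide q(z) = Normal(0.5, 1); model p(z) = Normal(0, 1), p(x | z) = Normal(z, 1).
guide_trace = {
    "z": {"type": "sample", "value": z,
          "log_prob": lambda v: normal_log_prob(v, 0.5, 1.0)},
}
model_trace = {
    "z": {"type": "sample", "value": z,
          "log_prob": lambda v: normal_log_prob(v, 0.0, 1.0)},
    "x": {"type": "sample", "value": x,
          "log_prob": lambda v: normal_log_prob(v, z, 1.0)},
}

loss = negative_elbo(model_trace, guide_trace)
```

Because the model trace includes the observed site "x", the model sum covers both the prior term for z and the likelihood term for x.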

I/O Contract

Function                     Input                                Output
sample(name, fn, obs=None)   str, Distribution, optional Tensor   Tensor (sample or observation)
param(name, init_value,      str, Tensor, Constraint              Tensor (constrained parameter)
  constraint)
elbo(model, guide, *args)    Callables and model args             Tensor (scalar loss)
SVI.step(*args)              Model arguments                      float (loss value)
trace.get_trace(*args)       Model arguments                      OrderedDict of site dicts

Usage Examples

import torch
from pyro.contrib.minipyro import (
    sample, param, plate, SVI, Adam, Trace_ELBO,
    trace, seed, get_param_store,
)
import pyro.distributions as dist

# Define a simple model
def model(data):
    loc = sample("loc", dist.Normal(0.0, 1.0))
    with plate("data", data.shape[0], dim=-1):
        sample("obs", dist.Normal(loc, 1.0), obs=data)

# Define a guide
def guide(data):
    loc_q = param("loc_q", torch.tensor(0.0))
    scale_q = param("scale_q", torch.tensor(1.0),
                     constraint=torch.distributions.constraints.positive)
    sample("loc", dist.Normal(loc_q, scale_q))

# Run SVI
data = torch.randn(100) + 3.0
optimizer = Adam({"lr": 0.01})
svi = SVI(model, guide, optimizer, Trace_ELBO())

for step in range(1000):
    loss = svi.step(data)
    if step % 100 == 0:
        print(f"Step {step}: loss = {loss:.4f}")

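# Inspect parameters: the store holds unconstrained values, param() returns constrained ones
store = get_param_store()
print(f"stored parameters: {sorted(store)}")
print(f"loc_q: {param('loc_q').item():.2f}")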

# Trace execution with fixed seed
with seed(rng_seed=42):
    tr = trace(model).get_trace(data)
    for name, site in tr.items():
        if site["type"] == "sample":
            print(f"{name}: {site['value'].shape}")
