Implementation:Pyro ppl Pyro MiniPyro
| Property | Value |
|---|---|
| Module | pyro.contrib.minipyro |
| Source | pyro/contrib/minipyro.py |
| Lines | 405 |
| Classes | Messenger, trace, replay, block, seed, PlateMessenger, Adam, SVI, JitTrace_ELBO |
| Functions | apply_stack, sample, param, plate, elbo, Trace_ELBO, get_param_store |
| Dependencies | torch, pyro.distributions |
Overview
Mini Pyro is a minimal implementation of the Pyro probabilistic programming language contained in a single file. It replicates the core API of full Pyro (primitive names and call signatures) while remaining independent of the rest of the Pyro codebase, except for `pyro.distributions`. It is designed as:
- A pedagogical tool for understanding how Pyro works internally.
- A reference implementation of the effect handler pattern used throughout Pyro.
- A testing baseline for verifying that the full implementation matches expected behavior.
Mini Pyro tracks two kinds of global state:
- Effect handler stack (`PYRO_STACK`): enables non-standard interpretations of `sample()` and `param()`.
- Parameter store (`PARAM_STORE`): maps parameter names to `(unconstrained_value, constraint)` tuples.
Code Reference
Core Effect Handlers (Messengers)
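Messenger: Base handler class. Pushes itself onto `PYRO_STACK` on enter and pops itself on exit. Provides the `process_message` (pre-processing) and `postprocess_message` (post-processing) hooks. A handler can be used as a context manager or, when constructed with a function `fn`, called in place of that function.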
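The push/pop protocol and the two usage styles can be sketched as follows. This is a dependency-free toy written for illustration, not MiniPyro's own code; returning `self` from `__enter__` and the `stack_depth` helper are assumptions of the sketch, added only to observe the stack.

```python
# Toy re-creation of the Messenger protocol (not imported from pyro.contrib.minipyro).
PYRO_STACK = []  # global handler stack; the last element is the innermost handler


class Messenger:
    def __init__(self, fn=None):
        self.fn = fn

    def __enter__(self):
        PYRO_STACK.append(self)  # push on enter
        return self

    def __exit__(self, *exc_info):
        assert PYRO_STACK[-1] is self  # handlers must exit in LIFO order
        PYRO_STACK.pop()  # pop on exit

    def process_message(self, msg):
        pass  # pre-processing hook, overridden by subclasses

    def postprocess_message(self, msg):
        pass  # post-processing hook, overridden by subclasses

    def __call__(self, *args, **kwargs):
        # Wrapper style: run the wrapped function with this handler active.
        with self:
            return self.fn(*args, **kwargs)


def stack_depth():  # hypothetical helper used only for this demo
    return len(PYRO_STACK)


with Messenger():  # context-manager style
    depth_inside = Messenger(stack_depth)()  # wrapper style pushes one more handler
depth_after = len(PYRO_STACK)
print(depth_inside, depth_after)  # 2 0
```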
trace: Records all sample and param sites (name, value, distribution) into an `OrderedDict` keyed by site name. Uses `postprocess_message` so that each value is recorded only after all other effects have been applied.
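A toy sketch of tracing, assuming plain Python floats with `random.gauss` standing in for a `pyro.distributions` object, and a simplified `apply_stack` that omits the `stop` mechanism. Names mirror MiniPyro's, but this is an illustration rather than its source.

```python
import random
from collections import OrderedDict

PYRO_STACK = []


class Messenger:
    def __init__(self, fn=None):
        self.fn = fn

    def __enter__(self):
        PYRO_STACK.append(self)

    def __exit__(self, *exc_info):
        PYRO_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass

    def __call__(self, *args, **kwargs):
        with self:
            return self.fn(*args, **kwargs)


class trace(Messenger):
    def __enter__(self):
        super().__enter__()
        self.trace = OrderedDict()  # fresh record for every run
        return self.trace

    def postprocess_message(self, msg):
        # Runs after a value exists, so the recorded value is final.
        self.trace[msg["name"]] = msg.copy()

    def get_trace(self, *args, **kwargs):
        self(*args, **kwargs)
        return self.trace


def apply_stack(msg):  # simplified: no "stop" handling
    for handler in reversed(PYRO_STACK):  # innermost handler first
        handler.process_message(msg)
    if msg["value"] is None:  # no handler supplied a value: draw one
        msg["value"] = msg["fn"](*msg["args"])
    for handler in PYRO_STACK:
        handler.postprocess_message(msg)
    return msg


def sample(name, fn, *args, obs=None):
    if not PYRO_STACK:
        return fn(*args)
    msg = {"type": "sample", "name": name, "fn": fn, "args": args, "value": obs}
    return apply_stack(msg)["value"]


def model():
    x = sample("x", random.gauss, 0.0, 1.0)  # latent site
    sample("y", random.gauss, x, 1.0, obs=2.5)  # observed site


tr = trace(model).get_trace()
print(list(tr), tr["y"]["value"])  # ['x', 'y'] 2.5
```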
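replay(fn, guide_trace): For each site whose name appears in `guide_trace`, replaces the value with the one recorded there while keeping the model's own distribution. This makes it possible to evaluate the model's log probability at the guide's samples.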
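Composing `trace` with `replay` can be sketched with the same toy machinery as before (plain floats, `random.gauss` as the sampler, simplified `apply_stack`). The point to observe is that the replayed site takes the guide's value but keeps the model's own arguments `(0.0, 1.0)`, which is what lets the model's density be evaluated at the guide's sample.

```python
import random
from collections import OrderedDict

PYRO_STACK = []


class Messenger:
    def __init__(self, fn=None):
        self.fn = fn

    def __enter__(self):
        PYRO_STACK.append(self)

    def __exit__(self, *exc_info):
        PYRO_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass

    def __call__(self, *args, **kwargs):
        with self:
            return self.fn(*args, **kwargs)


class trace(Messenger):
    def __enter__(self):
        super().__enter__()
        self.trace = OrderedDict()
        return self.trace

    def postprocess_message(self, msg):
        self.trace[msg["name"]] = msg.copy()

    def get_trace(self, *args, **kwargs):
        self(*args, **kwargs)
        return self.trace


class replay(Messenger):
    def __init__(self, fn, guide_trace):
        self.guide_trace = guide_trace
        super().__init__(fn)

    def process_message(self, msg):
        # Substitute the guide's value; msg["fn"] (the model's distribution) is untouched.
        if msg["name"] in self.guide_trace:
            msg["value"] = self.guide_trace[msg["name"]]["value"]


def apply_stack(msg):
    for handler in reversed(PYRO_STACK):
        handler.process_message(msg)
    if msg["value"] is None:
        msg["value"] = msg["fn"](*msg["args"])
    for handler in PYRO_STACK:
        handler.postprocess_message(msg)
    return msg


def sample(name, fn, *args, obs=None):
    if not PYRO_STACK:
        return fn(*args)
    msg = {"type": "sample", "name": name, "fn": fn, "args": args, "value": obs}
    return apply_stack(msg)["value"]


def guide():
    sample("z", random.gauss, 1.0, 0.1)


def model():
    sample("z", random.gauss, 0.0, 1.0)


guide_trace = trace(guide).get_trace()
model_trace = trace(replay(model, guide_trace)).get_trace()
print(model_trace["z"]["value"] == guide_trace["z"]["value"])  # True
print(model_trace["z"]["args"])  # (0.0, 1.0)
```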
block(fn, hide_fn): Hides every site for which `hide_fn(msg)` returns True by setting `msg["stop"]`, so handlers entered outside the block never see it. Used to nest inference inside models.
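A toy sketch of `block`, assuming the same plain-Python setup as the earlier sketches but with a `stop`-aware `apply_stack`: once a handler sets `msg["stop"]`, handlers further out on the stack neither process nor postprocess the message. Site `"a"` is still sampled; it is only hidden from the outer `trace`.

```python
import random
from collections import OrderedDict

PYRO_STACK = []


class Messenger:
    def __init__(self, fn=None):
        self.fn = fn

    def __enter__(self):
        PYRO_STACK.append(self)

    def __exit__(self, *exc_info):
        PYRO_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass

    def __call__(self, *args, **kwargs):
        with self:
            return self.fn(*args, **kwargs)


class trace(Messenger):
    def __enter__(self):
        super().__enter__()
        self.trace = OrderedDict()
        return self.trace

    def postprocess_message(self, msg):
        self.trace[msg["name"]] = msg.copy()

    def get_trace(self, *args, **kwargs):
        self(*args, **kwargs)
        return self.trace


class block(Messenger):
    def __init__(self, fn=None, hide_fn=lambda msg: True):
        self.hide_fn = hide_fn
        super().__init__(fn)

    def process_message(self, msg):
        if self.hide_fn(msg):
            msg["stop"] = True  # outer handlers will not see this site


def apply_stack(msg):
    pointer = 0
    for pointer, handler in enumerate(reversed(PYRO_STACK)):
        handler.process_message(msg)
        if msg.get("stop"):
            break
    if msg["value"] is None:
        msg["value"] = msg["fn"](*msg["args"])
    for handler in PYRO_STACK[-pointer - 1:]:  # only handlers that processed msg
        handler.postprocess_message(msg)
    return msg


def sample(name, fn, *args, obs=None):
    if not PYRO_STACK:
        return fn(*args)
    msg = {"type": "sample", "name": name, "fn": fn, "args": args, "value": obs}
    return apply_stack(msg)["value"]


def model():
    a = sample("a", random.gauss, 0.0, 1.0)
    b = sample("b", random.gauss, 0.0, 1.0)
    return a + b


tr = trace(block(model, hide_fn=lambda msg: msg["name"] == "a")).get_trace()
print(list(tr))  # ['b']
```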
seed(fn, rng_seed): Fixes the RNG state for reproducible execution.
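A toy sketch of a seeding handler, assuming only Python's `random` module needs fixing (MiniPyro's `seed` also seeds `torch`). Saving and restoring the caller's RNG state on exit is a design choice of this sketch, so code outside the handler is unaffected.

```python
import random


class seed:
    """Toy seed handler: fixes Python's `random` state for the wrapped call."""

    def __init__(self, fn=None, rng_seed=None):
        self.fn = fn
        self.rng_seed = rng_seed

    def __enter__(self):
        self.old_state = random.getstate()  # remember the caller's RNG state
        random.seed(self.rng_seed)
        return self

    def __exit__(self, *exc_info):
        random.setstate(self.old_state)  # restore it on the way out

    def __call__(self, *args, **kwargs):
        with self:
            return self.fn(*args, **kwargs)


def draw():
    return [random.gauss(0.0, 1.0) for _ in range(3)]


before = random.getstate()
first = seed(draw, rng_seed=42)()
second = seed(draw, rng_seed=42)()
after = random.getstate()
print(first == second, before == after)  # True True
```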
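PlateMessenger: Limited plate implementation (returned by `plate(name, size, dim)`) that only handles broadcasting: it expands each sample site's distribution so that its batch shape has size `size` at dimension `dim`.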
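The broadcasting rule can be sketched on shapes alone, assuming batch shapes are tuples of ints. `plate_batch_shape` is a hypothetical helper. In MiniPyro, the handler expands the site's distribution to the resulting shape rather than returning it.

```python
def plate_batch_shape(batch_shape, size, dim):
    """Toy shape rule: broadcast batch_shape so that it has `size` at `dim`."""
    assert dim < 0, "plate dims count from the right"
    shape = list(batch_shape)
    if len(shape) < -dim or shape[dim] != size:
        # Left-pad with singleton dims until `dim` exists, then set it to `size`.
        shape = [1] * (-dim - len(shape)) + shape
        shape[dim] = size
    return tuple(shape)


print(plate_batch_shape((), 100, dim=-1))  # (100,)
inner = plate_batch_shape((), 3, dim=-1)  # nested plates: inner uses dim=-1
print(plate_batch_shape(inner, 5, dim=-2))  # (5, 3)
```

Giving nested plates distinct negative `dim`s keeps their batch dimensions independent, which is why `plate` requires the `dim` argument.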
Primitive Operations
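apply_stack(msg): Core dispatch mechanism. Visits the handlers on `PYRO_STACK` starting from the most recently entered one, calling `process_message` on each; a handler that sets `msg["stop"]` prevents the handlers entered before it from seeing the message. If no handler has set `msg["value"]`, calls `msg["fn"](*msg["args"])`, which draws a sample when `fn` is a distribution. Finally calls `postprocess_message` on the handlers that processed the message, from the outermost of them to the innermost.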
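The visiting order and the effect of `stop` can be made visible with a toy handler that logs its hooks. `Logger` and `run` are hypothetical names for this sketch; the `apply_stack` body follows the description above.

```python
PYRO_STACK = []
events = []


class Logger:
    """Toy handler that logs each hook call and can optionally stop the message."""

    def __init__(self, label, stop=False):
        self.label, self.stop = label, stop

    def __enter__(self):
        PYRO_STACK.append(self)

    def __exit__(self, *exc_info):
        PYRO_STACK.pop()

    def process_message(self, msg):
        events.append("process:" + self.label)
        if self.stop:
            msg["stop"] = True

    def postprocess_message(self, msg):
        events.append("postprocess:" + self.label)


def apply_stack(msg):
    pointer = 0
    for pointer, handler in enumerate(reversed(PYRO_STACK)):  # innermost first
        handler.process_message(msg)
        if msg.get("stop"):
            break
    if msg["value"] is None:
        msg["value"] = msg["fn"](*msg["args"])
    for handler in PYRO_STACK[-pointer - 1:]:  # visited handlers, outermost first
        handler.postprocess_message(msg)
    return msg


def run():
    msg = {"type": "sample", "name": "x", "fn": lambda: 7, "args": (), "value": None}
    return apply_stack(msg)["value"]


with Logger("outer"):
    with Logger("inner"):
        value = run()
normal_events = list(events)
print(value, normal_events)
# 7 ['process:inner', 'process:outer', 'postprocess:outer', 'postprocess:inner']

events.clear()
with Logger("outer"):
    with Logger("inner", stop=True):
        run()
print(events)  # ['process:inner', 'postprocess:inner']
```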
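sample(name, fn, *args, obs=None): Effectful sampling. With no active handlers, simply draws from the distribution `fn`. Otherwise, builds a message whose initial value is `obs` and passes it through `apply_stack`, so an observed site keeps its observed value instead of being resampled.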
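param(name, init_value, constraint, event_dim): Effectful parameter access. On the first call for a given name, maps `init_value` into unconstrained space using the inverse of `torch.distributions.transform_to(constraint)` and stores `(unconstrained_value, constraint)` in `PARAM_STORE`. Every call then transforms the stored unconstrained value back to constrained space and returns it.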
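The constrained/unconstrained round trip can be sketched with plain floats. The `CONSTRAINTS` table below is hypothetical and stands in for `torch.distributions.transform_to`; for a positive constraint it uses `exp`/`log`, a common choice that this sketch assumes rather than quotes.

```python
import math

PARAM_STORE = {}  # name -> (unconstrained_value, constraint)

# Hypothetical constraint table: name -> (to_constrained, to_unconstrained).
CONSTRAINTS = {
    "real": (lambda u: u, lambda c: c),
    "positive": (math.exp, math.log),
}


def param(name, init_value=None, constraint="real"):
    if name not in PARAM_STORE:
        # First call: store the value in unconstrained space.
        assert init_value is not None, "first call must supply init_value"
        to_unconstrained = CONSTRAINTS[constraint][1]
        PARAM_STORE[name] = (to_unconstrained(init_value), constraint)
    # Every call: map the stored value back to constrained space.
    unconstrained, stored_constraint = PARAM_STORE[name]
    return CONSTRAINTS[stored_constraint][0](unconstrained)


scale = param("scale_q", 2.0, constraint="positive")
print(PARAM_STORE["scale_q"])  # log(2.0) is stored, not 2.0

# An optimizer may move the unconstrained value anywhere; the constrained view stays > 0.
u, c = PARAM_STORE["scale_q"]
PARAM_STORE["scale_q"] = (u - 10.0, c)
print(param("scale_q") > 0)  # True
```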
Inference
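elbo(model, guide, *args): Computes a single-sample estimate of the Evidence Lower Bound and returns its negation as a scalar loss to minimize. The steps are: (1) trace the guide, (2) trace the model replayed with the guide's samples, (3) sum the log probabilities of all model sample sites (latent `log p(z)` terms and observed-data terms) and subtract the `log q(z)` terms of the guide's sample sites.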
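The three steps can be written out by hand for a conjugate toy model, assuming `z ~ Normal(0, 1)`, `x_i ~ Normal(z, 1)` and a guide `q(z) = Normal(loc_q, scale_q)`. `neg_elbo` and `normal_log_prob` are hypothetical helpers, not MiniPyro functions. With data `[3.0, 3.2, 2.8]`, the exact posterior has precision 1 + 3 = 4 and mean 9/4, i.e. `Normal(2.25, 0.5)`; when the guide equals it, `log p(x, z) - log q(z) = log p(x)` for every `z`, so each single-sample estimate is identical.

```python
import math
import random


def normal_log_prob(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def neg_elbo(data, loc_q, scale_q, rng):
    """Single-sample -ELBO for z ~ N(0, 1), x_i ~ N(z, 1), guide q(z) = N(loc_q, scale_q)."""
    # (1) Guide: draw z ~ q and score it.
    z = rng.gauss(loc_q, scale_q)
    log_q = normal_log_prob(z, loc_q, scale_q)
    # (2) Model replayed at the guide's z: prior term plus observed-data terms.
    log_p = normal_log_prob(z, 0.0, 1.0)
    log_p += sum(normal_log_prob(x, z, 1.0) for x in data)
    # (3) ELBO = log p - log q; negate it so it can be minimized as a loss.
    return -(log_p - log_q)


data = [3.0, 3.2, 2.8]
exact = [neg_elbo(data, 2.25, 0.5, random.Random(s)) for s in range(3)]
print(max(exact) - min(exact) < 1e-9)  # True: constant at the exact posterior

bad = sum(neg_elbo(data, -2.0, 0.5, random.Random(s)) for s in range(200)) / 200
print(bad > exact[0])  # True: a misplaced guide has a larger expected loss
```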
Trace_ELBO(**kwargs): Wrapper returning `elbo`, for API compatibility with full Pyro.
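SVI(model, guide, optim, loss): Stochastic Variational Inference. `step(*args)` evaluates `loss` inside a `trace` wrapped around a `block` that hides sample sites, so the trace captures only param sites. It then backpropagates the loss, passes the unconstrained parameters to `optim`, zeroes their gradients, and returns the loss as a float.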
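Adam(optim_args): Thin wrapper around `torch.optim.Adam` that lazily creates a separate optimizer for each parameter the first time it appears, so parameters created dynamically during training are handled.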
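Lazily creating one optimizer per parameter can be sketched without `torch`. `ScalarAdam` is a hypothetical scalar stand-in for `torch.optim.Adam`, and unlike MiniPyro this toy keys optimizers by parameter name and receives gradients explicitly instead of reading `.grad`.

```python
import math


class ScalarAdam:
    """Hypothetical Adam state for one scalar parameter."""

    def __init__(self, lr=0.001, betas=(0.9, 0.999), eps=1e-8):
        self.lr, self.eps = lr, eps
        self.b1, self.b2 = betas
        self.m = self.v = 0.0
        self.t = 0

    def step(self, value, grad):
        self.t += 1
        self.m = self.b1 * self.m + (1 - self.b1) * grad
        self.v = self.b2 * self.v + (1 - self.b2) * grad * grad
        m_hat = self.m / (1 - self.b1 ** self.t)  # bias-corrected moments
        v_hat = self.v / (1 - self.b2 ** self.t)
        return value - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


class Adam:
    def __init__(self, optim_args):
        self.optim_args = optim_args
        self.optim_objs = {}  # parameter name -> its own optimizer

    def __call__(self, params, grads):
        for name in params:
            if name not in self.optim_objs:  # first sighting: create an optimizer
                self.optim_objs[name] = ScalarAdam(**self.optim_args)
            params[name] = self.optim_objs[name].step(params[name], grads[name])


optim = Adam({"lr": 0.1})
params = {"x": 0.0}
for step in range(500):
    if step == 100:
        params["y"] = 0.0  # a parameter that first appears mid-training
    grads = {"x": 2 * (params["x"] - 3.0)}  # gradient of (x - 3)^2
    if "y" in params:
        grads["y"] = 2 * (params["y"] + 1.0)  # gradient of (y + 1)^2
    optim(params, grads)
print(optim.optim_objs["x"].t, optim.optim_objs["y"].t)  # 500 400
```

Keeping state per parameter means a newly created parameter starts with fresh moment estimates instead of inheriting another parameter's history.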
JitTrace_ELBO: JIT-compiled ELBO that defers tracing to the first invocation and registers parameters for `torch.jit.trace`.
I/O Contract
| Function | Input | Output |
|---|---|---|
| `sample(name, fn, obs=None)` | `str`, `Distribution`, optional `Tensor` | `Tensor` (sample or observation) |
| `param(name, init_value, constraint)` | `str`, `Tensor`, `Constraint` | `Tensor` (constrained parameter) |
| `elbo(model, guide, *args)` | Callables and model args | `Tensor` (scalar loss) |
| `SVI.step(*args)` | Model arguments | `float` (loss value) |
| `trace.get_trace(*args)` | Model arguments | `OrderedDict` of site dicts, keyed by site name |
Usage Examples
```python
import torch
from pyro.contrib.minipyro import (
    sample, param, plate, SVI, Adam, Trace_ELBO,
    trace, seed, get_param_store,
)
import pyro.distributions as dist

# Define a simple model: unknown mean "loc" with a standard normal prior
def model(data):
    loc = sample("loc", dist.Normal(0.0, 1.0))
    with plate("data", data.shape[0], dim=-1):
        sample("obs", dist.Normal(loc, 1.0), obs=data)

# Define a guide: a normal approximation to the posterior over "loc"
def guide(data):
    loc_q = param("loc_q", torch.tensor(0.0))
    scale_q = param("scale_q", torch.tensor(1.0),
                    constraint=torch.distributions.constraints.positive)
    sample("loc", dist.Normal(loc_q, scale_q))

# Run SVI
data = torch.randn(100) + 3.0
optimizer = Adam({"lr": 0.01})
svi = SVI(model, guide, optimizer, Trace_ELBO())
for step in range(1000):
    loss = svi.step(data)
    if step % 100 == 0:
        print(f"Step {step}: loss = {loss:.4f}")

# Inspect parameters (the store holds unconstrained values)
store = get_param_store()
print(f"stored params: {sorted(store)}")
print(f"loc_q: {param('loc_q').item():.2f}")

# Trace execution with a fixed seed
with seed(rng_seed=42):
    tr = trace(model).get_trace(data)
for name, site in tr.items():
    if site["type"] == "sample":
        print(f"{name}: {site['value'].shape}")
```
Related Pages
- Pyro_ppl_Pyro_Util -- Full Pyro utility functions
- Pyro_ppl_Pyro_LazyJIT -- Full JIT support for Pyro
- Pyro_ppl_Pyro_Settings -- Global settings management