Implementation:Pyro ppl Pyro MiniPyro Example
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/minipyro.py |
| Module | pyro.contrib.minipyro |
| Pyro Features | pyro.generic interface, pyro_backend, SVI, Trace_ELBO, JitTrace_ELBO, backend switching |
| Pattern | Demonstrates MiniPyro -- a minimal PPL for didactic purposes |
Overview
This file demonstrates pyro.contrib.minipyro, a minimal implementation of the Pyro Probabilistic Programming Language created for educational purposes. The example shows that MiniPyro's API is compatible with the full Pyro API, so the same model, guide, and training code work with both backends.
The example uses the pyro.generic interface, which provides a dynamic backend switching mechanism via pyro_backend(). This means you can write probabilistic programs that work with:
- "minipyro": The minimal backend (for learning/teaching)
- "pyro": The full Pyro backend (for production)
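To make the idea of a switchable backend concrete, here is a toy sketch of a context manager that selects one of two interchangeable implementations and restores the previous one on exit. It is not how pyro.generic is implemented; `use_backend`, `describe`, and the `_BACKENDS` registry are hypothetical names invented for illustration, and the toy arbitrarily starts with "pyro" active.

```python
from contextlib import contextmanager

# Hypothetical stand-ins for two interchangeable backends.
_BACKENDS = {
    "minipyro": "minimal backend",
    "pyro": "full backend",
}
_active = "pyro"  # arbitrary starting choice for this toy


@contextmanager
def use_backend(name):
    """Temporarily make `name` the active backend, then restore the old one."""
    global _active
    if name not in _BACKENDS:
        raise ValueError(f"unknown backend: {name!r}")
    previous = _active
    _active = name
    try:
        yield
    finally:
        _active = previous  # restored even if the body raises


def describe():
    """Return the description of whichever backend is currently active."""
    return _BACKENDS[_active]


with use_backend("minipyro"):
    inside = describe()
outside = describe()
print(inside, "|", outside)
```

Code written against `describe()` never names a backend directly, which is the property that lets one program run unchanged on either implementation.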
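The model is a simple conjugate Gaussian: a latent location loc ~ Normal(0, 1) is observed through Normally distributed data obs ~ Normal(loc, 1). The guide is a Normal distribution over loc with a learnable mean (guide_loc) and a learnable scale, stored as guide_scale_log and exponentiated so that it stays positive.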
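Because the model is conjugate, the exact posterior over loc is available in closed form, which gives a reference point for what the guide should learn. With a Normal(0, 1) prior and n unit-variance observations, the posterior is Normal with mean sum(x) / (n + 1) and standard deviation 1 / sqrt(n + 1). The sketch below evaluates these formulas in plain Python on synthetic data; the helper name `exact_posterior` and the use of the `random` module in place of torch are assumptions made for illustration.

```python
import math
import random


def exact_posterior(data):
    """Posterior over loc for prior Normal(0, 1) and likelihood Normal(loc, 1)."""
    n = len(data)
    precision = 1.0 + n            # prior precision (1) plus one unit per observation
    mean = sum(data) / precision   # the prior mean of 0 adds nothing to the numerator
    std = 1.0 / math.sqrt(precision)
    return mean, std


rng = random.Random(0)
# Stand-in for torch.randn(100) + 3.0 (hypothetical substitute, not the source's data).
data = [rng.gauss(3.0, 1.0) for _ in range(100)]
post_mean, post_std = exact_posterior(data)
print(f"posterior mean = {post_mean:.3f}, posterior std = {post_std:.3f}")
```

With 100 observations the posterior standard deviation is 1 / sqrt(101), roughly 0.0995, and the posterior mean is the sample mean shrunk slightly toward the prior mean of 0.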
Code Reference
import torch

from pyro.generic import distributions as dist
from pyro.generic import infer, ops, optim, pyro, pyro_backend


def main(args):
    # Model: a single Normal latent `loc` with a batch of Normal observations.
    def model(data):
        loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
        with pyro.plate("data", len(data), dim=-1):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    # Guide: a Normal over `loc`; the scale is stored as a log so it stays positive.
    def guide(data):
        guide_loc = pyro.param("guide_loc", torch.tensor(0.0))
        guide_scale = ops.exp(pyro.param("guide_scale_log", torch.tensor(0.0)))
        pyro.sample("loc", dist.Normal(guide_loc, guide_scale))

    # Synthetic data centered near 3.0.
    data = torch.randn(100) + 3.0

    # The same training code runs on whichever backend is selected.
    with pyro_backend(args.backend):
        Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
        elbo = Elbo()
        adam = optim.Adam({"lr": args.learning_rate})
        svi = infer.SVI(model, guide, adam, elbo)
        for step in range(args.num_steps):
            loss = svi.step(data)
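What svi.step optimizes can be illustrated without any PPL machinery. For this particular model the ELBO of a Normal(m, s) guide has a closed form, ELBO(m, s) = -0.5 (m^2 + s^2) - 0.5 sum_i ((x_i - m)^2 + s^2) + log s + const, so its gradients can be written by hand. The sketch below is a simplification under stated assumptions: it uses the exact gradient and plain gradient ascent, whereas the example uses a Trace_ELBO estimate and Adam. `fit_guide`, its step size, and its iteration count are hypothetical choices.

```python
import math
import random


def fit_guide(data, lr=0.005, num_steps=5000):
    """Gradient ascent on the closed-form ELBO of a Normal(m, exp(rho)) guide."""
    n = len(data)
    total = sum(data)
    m, rho = 0.0, 0.0  # same starting values as guide_loc and guide_scale_log
    for _ in range(num_steps):
        s = math.exp(rho)
        grad_m = total - (n + 1) * m       # d ELBO / d m
        grad_rho = 1.0 - (n + 1) * s * s   # d ELBO / d rho, using s = exp(rho)
        m += lr * grad_m
        rho += lr * grad_rho
    return m, math.exp(rho)


rng = random.Random(0)
# Hypothetical stand-in for torch.randn(100) + 3.0.
data = [rng.gauss(3.0, 1.0) for _ in range(100)]
m, s = fit_guide(data)
print(f"guide mean = {m:.3f}, guide scale = {s:.3f}")
```

Setting both gradients to zero gives m = sum(x) / (n + 1) and s = 1 / sqrt(n + 1), which is the exact posterior of this conjugate model and explains why guide_loc is expected to end up near 3.0.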
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| -b / --backend | str | Backend to use: "minipyro" or "pyro" (default: "minipyro") |
| -n / --num-steps | int | SVI optimization steps (default: 1001) |
| -lr / --learning-rate | float | Learning rate (default: 0.02) |
| --jit | flag | Use JIT-compiled ELBO |
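The flags in the table map naturally onto an argparse parser. The following is a sketch reconstructed from the table above rather than a copy of the source file; the function name `build_parser`, the description, and the help strings are assumptions.

```python
import argparse


def build_parser():
    """Build a parser whose flags and defaults mirror the I/O table (sketch)."""
    parser = argparse.ArgumentParser(description="MiniPyro example options (sketch)")
    parser.add_argument("-b", "--backend", default="minipyro",
                        help='backend to use: "minipyro" or "pyro"')
    parser.add_argument("-n", "--num-steps", default=1001, type=int,
                        help="number of SVI optimization steps")
    parser.add_argument("-lr", "--learning-rate", default=0.02, type=float,
                        help="learning rate")
    parser.add_argument("--jit", action="store_true",
                        help="use the JIT-compiled ELBO")
    return parser


# Parse an explicit argument list so the sketch runs without a command line.
args = build_parser().parse_args(["-b", "pyro", "-n", "50"])
print(args.backend, args.num_steps, args.learning_rate, args.jit)
# prints: pyro 50 0.02 False
```

argparse converts the dashes in `--num-steps` and `--learning-rate` to underscores, which is why the code reads `args.num_steps` and `args.learning_rate`.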
Output:
- Loss at every 100 steps
- Final learned variational parameters (guide_loc should be ~3.0)
- Assertion that guide_loc is within 0.1 of 3.0
Usage Examples
# Run with MiniPyro backend (default)
python minipyro.py -n 1001 -lr 0.02
# Run with full Pyro backend
python minipyro.py -b pyro -n 1001
# Run with JIT
python minipyro.py --jit -b pyro
Related Pages
- Pyro_ppl_Pyro_InclinedPlane - Another introductory inference example