Implementation:Pyro ppl Pyro Autoname Scoping
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/contrib/autoname/scoping_mixture.py |
| Module | pyro.contrib.autoname |
| Pyro Features | pyro.contrib.autoname.scope, config_enumerate, TraceEnum_ELBO, pyro.plate, discrete enumeration |
| Pattern | Gaussian Mixture Model using scope decorator for site name prefixing |
Overview
This file demonstrates how to use the `@scope` decorator from `pyro.contrib.autoname` to build a Gaussian Mixture Model from modular local model and guide functions. Applying `@scope(prefix="local")` prepends `local/` to the name of every `pyro.sample` site inside the decorated function, so `assignment` is recorded as `local/assignment`. This keeps global and local model components separate without manual name management.
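Unlike the `named.Object` approach in the companion mixture example, the scoping approach works with discrete enumeration through `config_enumerate` and `TraceEnum_ELBO`. Because `config_enumerate(guide)` enumerates the guide's `local/assignment` site in parallel, the ELBO's expectation over the discrete mixture assignments is computed exactly rather than estimated from samples.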
Code Reference
```python
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.contrib.autoname import scope

def model(K, data):
    # Global parameters: declared outside any scope, so their names are unprefixed
    weights = pyro.param("weights", torch.ones(K) / K, constraint=constraints.simplex)
    locs = pyro.param("locs", 10 * torch.randn(K))
    scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive)
    with pyro.plate("data"):
        return local_model(weights, locs, scale, data)

@scope(prefix="local")
def local_model(weights, locs, scale, data):
    # Recorded as "local/assignment" and "local/obs"
    assignment = pyro.sample("assignment",
                             dist.Categorical(weights).expand_by([len(data)]))
    return pyro.sample("obs", dist.Normal(locs[assignment], scale), obs=data)

def guide(K, data):
    assignment_probs = pyro.param("assignment_probs",
                                  torch.ones(len(data), K) / K,
                                  constraint=constraints.unit_interval)
    with pyro.plate("data"):
        return local_guide(assignment_probs)

@scope(prefix="local")
def local_guide(probs):
    # Same prefix as local_model, so this site pairs with the model's "local/assignment"
    return pyro.sample("assignment", dist.Categorical(probs))
```
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| K | int | Number of mixture components |
| data | torch.Tensor | 1D tensor of observed data points |
| -n / --num-epochs | int | Number of training epochs (default: 200) |
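The `-n / --num-epochs` option is a command-line flag. The following is a minimal sketch of how such a flag can be declared, assuming the script uses the standard library's `argparse`. The helper `build_parser` and its description string are hypothetical and not copied from the source file.

```python
import argparse


def build_parser():
    # Hypothetical helper mirroring the documented -n / --num-epochs option
    parser = argparse.ArgumentParser(description="GMM with pyro.contrib.autoname.scope")
    parser.add_argument("-n", "--num-epochs", default=200, type=int,
                        help="number of training epochs")
    return parser


args_default = build_parser().parse_args([])
args_short = build_parser().parse_args(["-n", "50"])
args_long = build_parser().parse_args(["--num-epochs", "10"])
print(args_default.num_epochs, args_short.num_epochs, args_long.num_epochs)  # 200 50 10
```

`argparse` derives the attribute name `num_epochs` from the long option, so both spellings set the same value.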
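Named sites (after scoping):
- `local/assignment`: Categorical component assignment (prefixed by the scope)
- `local/obs`: Normal observation (prefixed by the scope)
- `weights`, `locs`, `scale`: global model parameters (declared outside the scope, so unprefixed)
- `assignment_probs`: guide parameter holding per-point assignment probabilities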
Usage Examples
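```python
import torch
import pyro
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

# model and guide are the functions defined in the Code Reference section
pyro.set_rng_seed(0)
pyro.clear_param_store()

K = 2
data = torch.tensor([0.0, 1.0, 2.0, 20.0, 30.0, 40.0])

optim = Adam({"lr": 0.1})
# config_enumerate marks the guide's discrete "local/assignment" site for parallel enumeration
inference = SVI(model, config_enumerate(guide), optim,
                loss=TraceEnum_ELBO(max_plate_nesting=1))

for step in range(200):
    loss = inference.step(K, data)
    if step % 50 == 0:
        print(f"step {step}\tloss {loss:.4g}")
```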
Related Pages
- Pyro_ppl_Pyro_Autoname_Mixture - Mixture model using the `named.Object` approach
- Pyro_ppl_Pyro_Autoname_TreeData - Hierarchical model using named objects