Implementation:Pyro ppl Pyro Autoname Mixture

From Leeroopedia


| Property | Value |
| --- | --- |
| Implementation Type | Pattern Doc |
| Source File | examples/contrib/autoname/mixture.py |
| Module | pyro.contrib.autoname |
| Pyro Features | pyro.contrib.autoname.named, named.Object, named.List, SVI, Trace_ELBO / JitTrace_ELBO |
| Pattern | Gaussian Mixture Model with named objects for modular model/guide design |

Overview

This file demonstrates how to use pyro.contrib.autoname.named to build a Gaussian Mixture Model from small, modular pieces. The named.Object API derives the name of each Pyro sample and param site from the attribute path used to reach it (for example, latent.probs), so latent-variable containers can be passed between model components without writing site names by hand.

The key pattern is passing named.Object instances from a global model to a local model implemented as a helper function. This allows:

  • Global parameters (mixture weights, locations, scales) to be defined at the top level
  • Local latent variables (component assignments, observations) to be defined in a reusable helper
  • Automatic naming that avoids site name collisions

Code Reference

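import torch
from torch.distributions import constraints

import pyro.distributions as dist
from pyro.contrib.autoname import named


def model(data, k):
    latent = named.Object("latent")

    # Global mixture parameters, registered as "latent.probs", "latent.locs", "latent.scales".
    latent.probs.param_(torch.ones(k) / k, constraint=constraints.simplex)
    latent.locs.param_(torch.zeros(k))
    latent.scales.param_(torch.ones(k), constraint=constraints.positive)

    # One local named.Object per data point, handed to the helper.
    latent.local = named.List()
    for x in data:
        local_model(latent.local.add(), latent.probs, latent.locs, latent.scales, obs=x)


def local_model(latent, ps, locs, scales, obs=None):
    # Sample the component index, then observe x under that component.
    i = latent.id.sample_(dist.Categorical(ps))
    return latent.x.sample_(dist.Normal(locs[i], scales[i]), obs=obs)


def guide(data, k):
    latent = named.Object("latent")
    latent.local = named.List()
    for x in data:
        local_guide(latent.local.add(), k)


def local_guide(latent, k):
    # Per-data-point weights; Categorical normalizes positive probs internally.
    latent.probs.param_(torch.ones(k) / k, constraint=constraints.positive)
    latent.id.sample_(dist.Categorical(latent.probs))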

I/O Contract

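| Parameter | Type | Description |
| --- | --- | --- |
| data | torch.Tensor | 1D tensor of observed data points (argument to model and guide) |
| k | int | Number of mixture components (argument to model and guide) |
| -n / --num-epochs | int | Command-line option: number of training epochs (default: 200) |
| --jit | flag | Command-line option: use the JIT-compiled ELBO (JitTrace_ELBO) |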
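The two command-line options could be wired up roughly as follows. This is a sketch under stated assumptions rather than the script's actual code: the flag names and the default of 200 come from the table, while `parse_args` and `make_elbo` are hypothetical helper names introduced here.

```python
import argparse


def parse_args(argv=None):
    # Flag names and the default follow the I/O contract table.
    parser = argparse.ArgumentParser(description="Gaussian mixture with autoname")
    parser.add_argument("-n", "--num-epochs", default=200, type=int)
    parser.add_argument("--jit", action="store_true")
    return parser.parse_args(argv)


def make_elbo(use_jit):
    # Hypothetical helper: imported lazily so that parsing flags alone
    # does not require Pyro to be importable.
    from pyro.infer import JitTrace_ELBO, Trace_ELBO

    return JitTrace_ELBO() if use_jit else Trace_ELBO()


args = parse_args(["-n", "50", "--jit"])
print(args.num_epochs, args.jit)  # 50 True
```

The object returned by `make_elbo(args.jit)` would then be passed as the `loss` argument of `SVI`, and `args.num_epochs` would bound the training loop.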

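Named sites (auto-generated):

  • latent.probs: Mixture weights (param, simplex constraint)
  • latent.locs: Component means (param)
  • latent.scales: Component standard deviations (param, positive constraint)
  • latent.local[i].id: Component assignment for data point i (sample site)
  • latent.local[i].x: Observation for data point i (observed sample site)
  • latent.local[i].probs: Per-data-point assignment weights in the guide (param, positive constraint)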

Usage Examples

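import torch

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# model and guide are the functions defined in the Code Reference section.
pyro.set_rng_seed(0)
pyro.clear_param_store()  # start from fresh parameters
optim = Adam({"lr": 0.1})
inference = SVI(model, guide, optim, loss=Trace_ELBO())
data = torch.tensor([0.0, 1.0, 2.0, 20.0, 30.0, 40.0])

for step in range(200):
    # Each step returns the ELBO loss estimate for that iteration.
    loss = inference.step(data, k=2)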
