

Implementation:Pyro ppl Pyro Autoname TreeData

From Leeroopedia


Property | Value
Implementation Type | Pattern Doc
Source File | examples/contrib/autoname/tree_data.py
Module | pyro.contrib.autoname
Pyro Features | pyro.contrib.autoname.named, named.Object, named.List, named.Dict, SVI, Trace_ELBO, recursive model/guide
Pattern | Linear mixed-effects model over arbitrary JSON-like tree-structured data

Overview

This file demonstrates a linear mixed-effects model over arbitrary JSON-like hierarchical data using pyro.contrib.autoname. The data is defined recursively: it is either a number (a scalar tensor), a list of data, or a dictionary whose values are data. The model walks this tree recursively, handling each of the three cases separately.
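The recursive definition of the data, and the three-way dispatch the model performs on it, can be sketched in plain Python without Pyro. This is a minimal sketch, not code from tree_data.py: it assumes plain floats stand in for the scalar tensors at the leaves, and the helper `count_nodes` is hypothetical. Like a strict model, it rejects any value that is not a number, list, or dict.

```python
def count_nodes(data, counts=None):
    """Walk a JSON-like tree and tally its leaves, lists, and dicts.

    Floats stand in for the scalar tensors the Pyro model expects at leaves.
    """
    if counts is None:
        counts = {"leaves": 0, "lists": 0, "dicts": 0}
    if isinstance(data, (int, float)):
        counts["leaves"] += 1  # leaf: the model observes a value here
    elif isinstance(data, list):
        counts["lists"] += 1  # branching node: one child per element
        for item in data:
            count_nodes(item, counts)
    elif isinstance(data, dict):
        counts["dicts"] += 1  # branching node: one child per key
        for value in data.values():
            count_nodes(value, counts)
    else:
        raise TypeError("Unsupported type {}".format(type(data)))
    return counts


example = {
    "foo": 1.0,
    "bar": [0.0, 1.0, 2.0],
    "baz": {"noun": {"concrete": 4.0, "abstract": 6.0}, "verb": 2.0},
}
print(count_nodes(example))  # {'leaves': 7, 'lists': 1, 'dicts': 3}
```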

The key pattern is the use of named.Object, named.List, and named.Dict to generate a unique name for every latent variable at each level of the tree, so that sample and param sites never collide. This lets the model and guide be written as small recursive functions whose structure mirrors the tree structure of the data.
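Because every node owns its own z, each one needs a distinct site name. The sketch below derives such names by walking the tree in the same order as model_recurse, which is roughly the bookkeeping that named.Object, named.List, and named.Dict automate. It is an illustration under stated assumptions: the helper `latent_names` is hypothetical, and the path format it produces (for example `latent.dict['bar'].list[2].z`) is chosen for readability and may not match the exact strings Pyro generates.

```python
def latent_names(data, prefix="latent"):
    """Return one hypothetical site name per latent z, in model_recurse order."""
    names = [prefix + ".z"]  # every node, including the root, owns a z
    if isinstance(data, list):
        for i, item in enumerate(data):
            # list children are distinguished by position
            names += latent_names(item, "{}.list[{}]".format(prefix, i))
    elif isinstance(data, dict):
        for key, value in data.items():
            # dict children are distinguished by key
            names += latent_names(value, "{}.dict[{!r}]".format(prefix, key))
    # scalars (leaves) add no children
    return names


for name in latent_names({"foo": 1.0, "bar": [0.0, 1.0, 2.0]}):
    print(name)
```

Deriving each name from the path to its node is what keeps the names unique without any global counter: two nodes can only share a name if they share a path.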

At each node in the tree:

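  • A latent variable z is sampled from Normal(parent_z, prior_scale), where prior_scale is a positive parameter learned at the parent node; the root z is instead sampled from Normal(0.0, 1.0)
  • At leaf nodes, the observed value x is modeled as Normal(z, 1.0)
  • The guide approximates every z with an independent (mean-field) Normal(post_loc, post_scale), learning a separate post_loc and post_scale for each node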
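Taken together, the first two bullets describe a generative process that can be simulated forward. The following plain-Python sketch draws one synthetic dataset with the same shape as a template tree. It assumes floats in place of tensors, Python's random.gauss in place of dist.Normal, and a fixed prior_scale of 1.0 (the value the model initializes its learned prior_scale parameters to); the helper `sample_tree` is hypothetical.

```python
import random


def sample_tree(template, z, prior_scale=1.0, rng=random):
    """Simulate data shaped like `template` under the model's prior.

    `z` is the latent value of the current node. prior_scale is held fixed here,
    whereas the Pyro model learns a separate positive prior_scale per branch node.
    """
    if isinstance(template, (int, float)):
        # Leaf: the observation is drawn around this node's z with scale 1.0
        return rng.gauss(z, 1.0)
    if isinstance(template, list):
        # List node: each element gets a child z ~ Normal(z, prior_scale)
        return [sample_tree(t, rng.gauss(z, prior_scale), prior_scale, rng)
                for t in template]
    if isinstance(template, dict):
        # Dict node: each key gets a child z ~ Normal(z, prior_scale)
        return {key: sample_tree(value, rng.gauss(z, prior_scale), prior_scale, rng)
                for key, value in template.items()}
    raise TypeError("Unsupported type {}".format(type(template)))


rng = random.Random(0)
root_z = rng.gauss(0.0, 1.0)  # root latent: z ~ Normal(0, 1)
template = {"foo": 1.0, "bar": [0.0, 1.0, 2.0]}
simulated = sample_tree(template, root_z, rng=rng)
print(simulated)
```

Only the shape of the template matters here; its leaf values are ignored, since the point is to show which latent each simulated observation depends on.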

Code Reference

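import torch
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.contrib.autoname import named

def model(data):
    # Root latent z ~ Normal(0, 1); named.Object derives site names from "latent"
    latent = named.Object("latent")
    latent.z.sample_(dist.Normal(0.0, 1.0))
    model_recurse(data, latent)

def model_recurse(data, latent):
    if torch.is_tensor(data):
        # Leaf: observe the data value around this node's z
        latent.x.sample_(dist.Normal(latent.z, 1.0), obs=data)
    elif isinstance(data, list):
        # List node: one learned positive prior scale shared by all children
        latent.prior_scale.param_(torch.tensor(1.0), constraint=constraints.positive)
        latent.list = named.List()
        for data_i in data:
            latent_i = latent.list.add()  # fresh, uniquely named child object
            latent_i.z.sample_(dist.Normal(latent.z, latent.prior_scale))
            model_recurse(data_i, latent_i)
    elif isinstance(data, dict):
        # Dict node: same as a list node, but children are keyed by name
        latent.prior_scale.param_(torch.tensor(1.0), constraint=constraints.positive)
        latent.dict = named.Dict()
        for key, value in data.items():
            latent.dict[key].z.sample_(dist.Normal(latent.z, latent.prior_scale))
            model_recurse(value, latent.dict[key])
    else:
        raise TypeError("Unsupported type {}".format(type(data)))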
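Written out, the model code above defines the joint density below, shown together with the guide's mean-field family for comparison. The notation is introduced here for illustration: pa(v) is the parent of node v, σ_pa(v) is that parent's prior_scale parameter, L is the set of leaf nodes, and μ_v and s_v stand for node v's post_loc and post_scale. As in dist.Normal, each Normal is parameterized by location and scale (standard deviation).

```latex
\begin{aligned}
p(z, x) &= \mathcal{N}\left(z_{\mathrm{root}} \mid 0, 1\right)
  \prod_{v \neq \mathrm{root}} \mathcal{N}\left(z_v \mid z_{\mathrm{pa}(v)}, \sigma_{\mathrm{pa}(v)}\right)
  \prod_{\ell \in L} \mathcal{N}\left(x_\ell \mid z_\ell, 1\right) \\
q(z) &= \prod_{v} \mathcal{N}\left(z_v \mid \mu_v, s_v\right)
\end{aligned}
```

SVI with Trace_ELBO then maximizes the evidence lower bound of p with respect to the μ_v, s_v, and σ parameters.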

I/O Contract

Parameter | Type | Description
data | nested dict/list/tensor | JSON-like hierarchical data with tensors at leaves
-n / --num-epochs | int | Number of training epochs (default: 100)
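The documented command-line flag can be reproduced with the standard library's argparse. This is a minimal sketch: the `build_parser` helper and the description and help strings are assumptions, not copied from tree_data.py; only the flag names, type, and default come from the table above.

```python
import argparse


def build_parser():
    # Flag names, type, and default follow the I/O contract; the text is illustrative.
    parser = argparse.ArgumentParser(
        description="Linear mixed-effects model over tree-structured data")
    parser.add_argument("-n", "--num-epochs", default=100, type=int,
                        help="number of training epochs")
    return parser


args = build_parser().parse_args([])
print(args.num_epochs)  # 100
```

argparse converts the dashed long option into the attribute name num_epochs, which is how the training loop would read it.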

Example data structure:

import torch

one = torch.tensor(1.0)
data = {
    "foo": one,                      # scalar tensor
    "bar": [0*one, 1*one, 2*one],    # list of tensors
    "baz": {                         # nested dict
        "noun": {"concrete": 4*one, "abstract": 6*one},
        "verb": 2*one,
    },
}

Output:

  • Learned posterior locations and scales (post_loc, post_scale) for each latent z in the hierarchy
  • Learned prior_scale parameters, one at each branching (list or dict) node
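The size of this output follows from the tree shape alone: the guide learns one post_loc and one post_scale per latent z (every node, root included, has a z), and the model learns one prior_scale per list or dict node. The sketch below, with a hypothetical helper `param_counts` and floats standing in for tensors, tallies these counts for the example data. It only does the bookkeeping and does not read Pyro's param store.

```python
def param_counts(data):
    """Count latent z sites and prior_scale parameters implied by a JSON-like tree."""
    if isinstance(data, (int, float)):
        return {"z": 1, "prior_scale": 0}  # leaf: its own z, no prior_scale
    if isinstance(data, list):
        children = data
    elif isinstance(data, dict):
        children = list(data.values())
    else:
        raise TypeError("Unsupported type {}".format(type(data)))
    counts = {"z": 1, "prior_scale": 1}  # branching node: its z plus one prior_scale
    for child in children:
        sub = param_counts(child)
        counts["z"] += sub["z"]
        counts["prior_scale"] += sub["prior_scale"]
    return counts


example = {
    "foo": 1.0,
    "bar": [0.0, 1.0, 2.0],
    "baz": {"noun": {"concrete": 4.0, "abstract": 6.0}, "verb": 2.0},
}
c = param_counts(example)
guide_params = 2 * c["z"]  # one post_loc and one post_scale per z
total = guide_params + c["prior_scale"]
print(c, total)  # {'z': 11, 'prior_scale': 4} 26
```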

Usage Examples

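import torch
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Assumes model and guide from examples/contrib/autoname/tree_data.py are in scope.
pyro.set_rng_seed(0)
pyro.clear_param_store()
optim = Adam({"lr": 0.1})
inference = SVI(model, guide, optim, loss=Trace_ELBO())

one = torch.tensor(1.0)
data = {"foo": one, "bar": [0*one, 1*one, 2*one]}

for step in range(100):
    loss = inference.step(data)  # one gradient step; returns the loss estimate
    if step % 10 == 0:
        print(step, loss)

# Inspect the learned posterior and prior-scale parameters
for name, value in sorted(pyro.get_param_store().items()):
    print(name, value.detach().numpy())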
