Implementation:Pyro ppl Pyro Autoname TreeData
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/contrib/autoname/tree_data.py |
| Module | pyro.contrib.autoname |
| Pyro Features | pyro.contrib.autoname.named, named.Object, named.List, named.Dict, SVI, Trace_ELBO, recursive model/guide |
| Pattern | Linear mixed-effects model over arbitrary JSON-like tree-structured data |
Overview
This example fits a linear mixed-effects model to arbitrary JSON-like hierarchical data using `pyro.contrib.autoname`. Each node of the data is either a tensor (a leaf), a list of nodes, or a dict whose values are nodes, and the model walks this tree recursively.
The key pattern is to use `named.Object`, `named.List`, and `named.Dict` to generate a unique name for each latent variable from its position in the tree. Because the model and guide build the same named structure, their sample sites match without manual name bookkeeping, so both can be written as small recursive functions that mirror the shape of the data.
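To make the naming scheme concrete, here is a pure-Python sketch that lists the sample-site names the recursive model would create. The helper `latent_site_names` is hypothetical, and it assumes that `named.List` children are named `<parent>.list[i]` and `named.Dict` children `<parent>.dict[<repr(key)>]`. Plain numbers stand in for tensor leaves so the sketch runs without Pyro.

```python
from numbers import Number


def latent_site_names(data, prefix="latent"):
    """List the sample-site names a tree-shaped model would create (hypothetical helper).

    ASSUMPTION: list children are named "<prefix>.list[i]" and dict children
    "<prefix>.dict[<repr(key)>]", mimicking named.List / named.Dict.
    """
    names = [prefix + ".z"]  # every node carries its own latent z
    if isinstance(data, Number):
        names.append(prefix + ".x")  # leaves add an observed site
    elif isinstance(data, list):
        for i, item in enumerate(data):
            names += latent_site_names(item, "{}.list[{}]".format(prefix, i))
    elif isinstance(data, dict):
        for key, value in data.items():
            names += latent_site_names(value, "{}.dict[{!r}]".format(prefix, key))
    else:
        raise TypeError("Unsupported type {}".format(type(data)))
    return names


for name in latent_site_names({"foo": 1.0, "bar": [0.0, 1.0]}):
    print(name)
```

Because every name is derived from the path to its node, two different nodes can never collide, and a guide that walks the same tree in the same way reproduces exactly the same names.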
At each node in the tree:
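- A latent variable `z` is sampled from `Normal(parent_z, prior_scale)`, where `prior_scale` is a learned parameter of the parent node; the root's `z` is sampled from `Normal(0, 1)`.
- At leaf nodes, the observation is sampled from `Normal(z, 1.0)`.
- The guide learns mean-field Normal approximations with `post_loc` and `post_scale` parameters.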
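The generative process in the list above can be sketched as a forward simulation. This version uses Python's `random` module instead of Pyro, lets plain numbers stand in for tensor leaves, and uses one shared `prior_scale`, whereas the model learns a separate `prior_scale` at every branching node. The function name `simulate` is hypothetical.

```python
import random
from numbers import Number


def simulate(template, z=None, prior_scale=1.0, rng=random):
    """Return a tree shaped like `template` with simulated observations at the leaves.

    ASSUMPTION: a single shared prior_scale; the real model learns one per branching node.
    """
    if z is None:
        z = rng.gauss(0.0, 1.0)  # root latent: z ~ Normal(0, 1)
    if isinstance(template, Number):
        return rng.gauss(z, 1.0)  # leaf observation: x ~ Normal(z, 1)
    if isinstance(template, list):
        # each child draws its own z ~ Normal(parent_z, prior_scale), then recurses
        return [simulate(t, rng.gauss(z, prior_scale), prior_scale, rng) for t in template]
    if isinstance(template, dict):
        return {k: simulate(v, rng.gauss(z, prior_scale), prior_scale, rng)
                for k, v in template.items()}
    raise TypeError("Unsupported type {}".format(type(template)))


print(simulate({"foo": 1.0, "bar": [0.0, 1.0, 2.0]}, rng=random.Random(0)))
```

Only the shape of `template` matters here; its leaf values are ignored and replaced by draws, which is what makes siblings under the same parent correlated through their shared `parent_z`.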
Code Reference
```python
import torch
from torch.distributions import constraints

import pyro.distributions as dist
from pyro.contrib.autoname import named


def model(data):
    latent = named.Object("latent")
    # Root latent: z ~ Normal(0, 1)
    latent.z.sample_(dist.Normal(0.0, 1.0))
    model_recurse(data, latent)


def model_recurse(data, latent):
    if torch.is_tensor(data):
        # Leaf: observe the data around this node's latent z
        latent.x.sample_(dist.Normal(latent.z, 1.0), obs=data)
    elif isinstance(data, list):
        # Branch: a learnable prior scale shared by this node's children
        latent.prior_scale.param_(torch.tensor(1.0), constraint=constraints.positive)
        latent.list = named.List()
        for data_i in data:
            latent_i = latent.list.add()
            latent_i.z.sample_(dist.Normal(latent.z, latent.prior_scale))
            model_recurse(data_i, latent_i)
    elif isinstance(data, dict):
        latent.prior_scale.param_(torch.tensor(1.0), constraint=constraints.positive)
        latent.dict = named.Dict()
        for key, value in data.items():
            latent.dict[key].z.sample_(dist.Normal(latent.z, latent.prior_scale))
            model_recurse(value, latent.dict[key])
    else:
        raise TypeError("Unsupported type {}".format(type(data)))
```
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| `data` | nested dict/list/tensor | JSON-like hierarchical data with tensors at leaves |
| `-n` / `--num-epochs` | int | Number of training epochs (default: 100) |
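The command-line flag in the table can be declared with the standard-library `argparse` module roughly as follows. The `make_parser` function and the help text are assumptions for illustration; only the flag names, type, and default come from the table.

```python
import argparse


def make_parser():
    # Hypothetical parser mirroring the documented flag: -n / --num-epochs, int, default 100.
    parser = argparse.ArgumentParser(description="Linear mixed-effects model over tree data")
    parser.add_argument("-n", "--num-epochs", default=100, type=int,
                        help="number of training epochs")
    return parser


args = make_parser().parse_args(["-n", "200"])
print(args.num_epochs)  # argparse stores --num-epochs as the attribute num_epochs
```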
Example data structure:
```python
import torch

one = torch.tensor(1.0)
data = {
    "foo": one,                          # scalar tensor
    "bar": [0 * one, 1 * one, 2 * one],  # list of tensors
    "baz": {                             # nested dict
        "noun": {"concrete": 4 * one, "abstract": 6 * one},
        "verb": 2 * one,
    },
}
```
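The shape of the data tree fixes the size of the model. A hypothetical helper `count_nodes` can count latent `z` sites (one per node), observed leaves, and branching nodes (one `prior_scale` each). Plain floats stand in for the tensors so the sketch runs without torch.

```python
from numbers import Number


def count_nodes(data):
    """Return (num_latent_z, num_observed_leaves, num_branching_nodes) for a data tree."""
    if isinstance(data, Number):
        return 1, 1, 0  # a leaf has its own z and one observation
    if isinstance(data, list):
        children = data
    elif isinstance(data, dict):
        children = list(data.values())
    else:
        raise TypeError("Unsupported type {}".format(type(data)))
    z, x, b = 1, 0, 1  # this node's own z; it is also a branching node
    for child in children:
        cz, cx, cb = count_nodes(child)
        z, x, b = z + cz, x + cx, b + cb
    return z, x, b


one = 1.0  # float stand-in for torch.tensor(1.0)
data = {
    "foo": one,
    "bar": [0 * one, 1 * one, 2 * one],
    "baz": {
        "noun": {"concrete": 4 * one, "abstract": 6 * one},
        "verb": 2 * one,
    },
}
print(count_nodes(data))
```

For the example data this gives 11 latent `z` sites, 7 observed leaves, and 4 branching nodes, so the model creates 4 `prior_scale` parameters and a mean-field guide with `post_loc` and `post_scale` per `z` creates 22 more.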
Output:
- Learned posterior locations (`post_loc`) and scales (`post_scale`) for each latent `z` in the hierarchy
- Learned prior scale parameters (`prior_scale`) at each branching (list or dict) node
Usage Examples
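```python
import torch

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Assumes `model` and `guide` from tree_data.py are already defined in this scope.
pyro.set_rng_seed(0)
optim = Adam({"lr": 0.1})
inference = SVI(model, guide, optim, loss=Trace_ELBO())

one = torch.tensor(1.0)
data = {"foo": one, "bar": [0 * one, 1 * one, 2 * one]}
for step in range(100):
    loss = inference.step(data)  # one gradient update; returns the loss estimate
```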
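With its default of one particle, `Trace_ELBO` estimates the loss from a single sample per step, so individual values returned by `inference.step` are noisy. A hypothetical helper `windowed_means` averages blocks of consecutive losses, which makes the training trend easier to read. The numbers in the demo are illustrative inputs, not results from this model.

```python
def windowed_means(losses, window=10):
    """Average consecutive blocks of `window` losses; a trailing partial block is averaged too."""
    if window < 1:
        raise ValueError("window must be positive")
    return [
        sum(losses[i:i + window]) / len(losses[i:i + window])
        for i in range(0, len(losses), window)
    ]


# In the training loop above, collect losses with losses.append(inference.step(data)).
print(windowed_means([5.0, 3.0, 4.0, 2.0, 1.0], window=2))
```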
Related Pages
- Pyro_ppl_Pyro_Autoname_Mixture - Mixture model using named objects
- Pyro_ppl_Pyro_Autoname_Scoping - Mixture model using scope decorator