Implementation:Pyro ppl Pyro Model Inspect
Overview
The inspect module (pyro.infer.inspect) provides tools for analyzing the dependency structure of Pyro probabilistic models. It can infer prior and posterior dependencies among random variables, extract model relations (sample-to-sample, sample-to-distribution, plate-to-sample), and render models as graphical plate diagrams using Graphviz.
The two primary functions are:
- get_dependencies -- Infers both prior and posterior dependency structure by tracing the model with provenance tracking. Returns nested dictionaries describing which upstream variables each downstream variable depends on, and which plates induce full dependencies.
- get_model_relations -- Extracts structural relations including sample-sample dependencies, sample-distribution mappings, plate-sample memberships, parameter constraints, and observed variables.
The module also provides render_model, a high-level function that combines Template:Code, Template:Code, and Template:Code to produce a Graphviz visualization of a Pyro model's plate diagram.
Internally, the module uses ProvenanceTensor tracking via the TrackProvenance messenger to determine which upstream sample sites influence each downstream site's log-probability.
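As a rough mental model of provenance tracking, here is a toy pure-Python analogy, not Pyro's ProvenanceTensor. Every value carries the frozenset of sample-site names it was computed from, each arithmetic operation takes the union of its operands' sets, and so a site's log-probability ends up tagged with exactly the upstream sites it touched. The Tracked class and all helper functions are hypothetical and handle only scalar floats.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Tracked:
    """A float tagged with the names of the sample sites it was computed from.

    Toy analogy only; this is not Pyro's ProvenanceTensor.
    """

    value: float
    provenance: frozenset = frozenset()


def _lift(x):
    # Plain numbers (constants) carry no provenance.
    return x if isinstance(x, Tracked) else Tracked(float(x))


def _binary(op):
    def apply(x, y):
        x, y = _lift(x), _lift(y)
        # A result depends on everything either operand depended on.
        return Tracked(op(x.value, y.value), x.provenance | y.provenance)

    return apply


add = _binary(lambda u, v: u + v)
sub = _binary(lambda u, v: u - v)
mul = _binary(lambda u, v: u * v)
div = _binary(lambda u, v: u / v)


def log(x):
    x = _lift(x)
    return Tracked(math.log(x.value), x.provenance)


def normal_log_prob(x, loc, scale):
    # log N(x | loc, scale) = -z**2 / 2 - log(scale) - log(2*pi) / 2
    z = div(sub(x, loc), scale)
    return sub(sub(mul(-0.5, mul(z, z)), log(scale)), 0.5 * math.log(2 * math.pi))


def sample(name, value):
    # Tag a sampled (or observed) value with its own site name.
    return Tracked(value, frozenset({name}))


# Same shape as the a, b -> c -> d model in Usage Examples; values are arbitrary.
a = sample("a", 0.3)
b = sample("b", 1.5)
c = sample("c", -0.2)
d = sample("d", 0.0)

print(sorted(normal_log_prob(c, a, b).provenance))    # ['a', 'b', 'c']
print(sorted(normal_log_prob(d, c, 1.0).provenance))  # ['c', 'd']
```

In this analogy, reading off the tag of c's log-probability recovers the same parents (a and b, plus c itself) that get_dependencies reports for the first usage example.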
Code Reference
File: Template:Code
Key Functions
| Function | Description |
|---|---|
| get_dependencies | Infers prior and posterior dependency structure. Returns a dict with prior_dependencies and posterior_dependencies. |
| get_model_relations | Infers relations among random variables and plates. Returns a dict of six relation maps, described under I/O Contract. |
| render_model | Renders a model as a Graphviz diagram; saves it to a file if a filename is provided. |
| Template:Code | Converts model relations into a graph specification dict with plate groups, node data, and edge lists. |
| Template:Code | Creates a Template:Code object from a graph specification. |
| Template:Code | Helper that determines whether a trace message represents a valid sample site. |
Key Classes
| Class | Description |
|---|---|
| TrackProvenance | A Messenger that annotates sample and param values with provenance information using ProvenanceTensor. Used internally by get_dependencies and get_model_relations. |
I/O Contract
get_dependencies
Inputs:
- Template:Code -- A Pyro model function.
- model_args -- Positional arguments for the model (default Template:Code).
- Template:Code -- Keyword arguments for the model (default Template:Code).
- Template:Code -- Whether to include deterministic sites (default Template:Code).
Output: A dictionary with two keys:
- prior_dependencies -- Dict mapping each downstream variable to a dict of the upstream latent variables it depends on (plus an entry for itself), each mapped to the set of plates inducing full dependencies.
- posterior_dependencies -- Dict mapping each latent variable to a dict of the latent or observed variables it depends on in the posterior, each mapped to its associated plate set.
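The nested dictionaries can be consumed with ordinary dict iteration. A minimal sketch, using a hypothetical describe_dependencies helper, that drops each variable's trivial self-entry and prints the remaining dependencies, appending plate names when a plate set is non-empty. The input is the prior_dependencies output shown in Usage Examples, copied by hand.

```python
def describe_dependencies(dependencies):
    """Flatten a get_dependencies-style nested dict into readable lines."""
    lines = []
    for downstream, upstreams in dependencies.items():
        for upstream, plates in upstreams.items():
            if upstream == downstream and not plates:
                continue  # trivial self-entry: every variable lists itself
            line = f"{downstream} depends on {upstream}"
            if plates:
                line += f" (full dependency across {', '.join(sorted(plates))})"
            lines.append(line)
    return lines


# prior_dependencies for the a, b -> c -> d model in Usage Examples.
prior = {
    "a": {"a": set()},
    "b": {"b": set()},
    "c": {"a": set(), "b": set(), "c": set()},
    "d": {"c": set(), "d": set()},
}
for line in describe_dependencies(prior):
    print(line)
# c depends on a
# c depends on b
# d depends on c
```

A self-entry is kept when it carries plates, since a non-empty set there signals that the variable's elements are coupled across those plates.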
get_model_relations
Inputs:
- Same as get_dependencies.
Output: A dictionary with keys:
- sample_sample -- Dict mapping each sample site to a list of upstream sample sites.
- Template:Code -- Dict mapping each sample site to a list of upstream param sites.
- Template:Code -- Dict mapping each sample site to its distribution name (string).
- Template:Code -- Dict mapping each param name to its constraint string.
- plate_sample -- Dict mapping each plate name to a list of the sample sites it contains.
- Template:Code -- List of observed sample site names.
render_model
Inputs:
- Template:Code -- A Pyro model.
- model_args -- Positional arguments for the model: a tuple, or a list of tuples (e.g. for semisupervised models).
- Template:Code -- Keyword arguments for the model: a dict, or a list of dicts.
- filename -- Output file path. If provided, the rendered image is saved to this path.
- render_distributions -- Whether to annotate sample sites with their distributions (default Template:Code).
- Template:Code -- Whether to show parameters (default Template:Code).
- Template:Code -- Whether to include deterministic sites (default Template:Code).
Output:
- Template:Code -- A Graphviz directed graph object.
Usage Examples
Inferring Model Dependencies
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.inspect import get_dependencies

# Collider structure: a and b are both parents of c; d is observed.
def model():
    a = pyro.sample("a", dist.Normal(0, 1))
    b = pyro.sample("b", dist.LogNormal(0, 1))
    c = pyro.sample("c", dist.Normal(a, b))
    pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0))

deps = get_dependencies(model)
print(deps["prior_dependencies"])
# {'a': {'a': set()}, 'b': {'b': set()},
#  'c': {'a': set(), 'b': set(), 'c': set()},
#  'd': {'c': set(), 'd': set()}}
print(deps["posterior_dependencies"])
# {'a': {'a': set(), 'b': set(), 'c': set()},
#  'b': {'b': set(), 'c': set()},
#  'c': {'c': set(), 'd': set()}}
```
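Comparing the two dictionaries shows a collider effect: a and b are independent in the prior, but in the posterior a depends on b because both are parents of c. A pure-Python sketch (coupled_pairs is a hypothetical helper; the inputs are the outputs printed above, copied by hand) that collects the unordered pairs of distinct variables linked in each structure and takes the difference:

```python
def coupled_pairs(dependencies):
    """Unordered pairs of distinct variables linked by a dependency entry."""
    return {
        frozenset((downstream, upstream))
        for downstream, upstreams in dependencies.items()
        for upstream in upstreams
        if upstream != downstream
    }


prior = {
    "a": {"a": set()},
    "b": {"b": set()},
    "c": {"a": set(), "b": set(), "c": set()},
    "d": {"c": set(), "d": set()},
}
posterior = {
    "a": {"a": set(), "b": set(), "c": set()},
    "b": {"b": set(), "c": set()},
    "c": {"c": set(), "d": set()},
}

# Pairs that are linked only after conditioning on the observation.
new = coupled_pairs(posterior) - coupled_pairs(prior)
print(sorted(tuple(sorted(pair)) for pair in new))  # [('a', 'b')]
```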
Extracting Model Relations
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.inspect import get_model_relations

def model(data):
    m = pyro.sample("m", dist.Normal(0, 1))
    sd = pyro.sample("sd", dist.LogNormal(m, 1))
    with pyro.plate("N", len(data)):
        pyro.sample("obs", dist.Normal(m, sd), obs=data)

relations = get_model_relations(model, model_args=(torch.randn(10),))
print(relations["sample_sample"])
# {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']}
print(relations["plate_sample"])
# {'N': ['obs']}
```
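Because sample_sample maps each site to its upstream sites, it is already in the predecessor format that Python's graphlib.TopologicalSorter (Python 3.9+) accepts. A small sketch, working on the two relation maps printed above (copied by hand), that derives a valid sampling order, a parent-to-child edge list, and a site-to-plates lookup:

```python
from graphlib import TopologicalSorter

# The two relation maps printed above for the m / sd / obs model.
relations = {
    "sample_sample": {"m": [], "sd": ["m"], "obs": ["m", "sd"]},
    "plate_sample": {"N": ["obs"]},
}

# sample_sample maps each site to its predecessors, which is exactly the
# input format TopologicalSorter expects.
order = list(TopologicalSorter(relations["sample_sample"]).static_order())

# Parent -> child edges, in the order sites appear in sample_sample.
edges = [
    (upstream, site)
    for site, upstreams in relations["sample_sample"].items()
    for upstream in upstreams
]

# Invert plate -> sites into site -> plates.
site_plates = {}
for plate, sites in relations["plate_sample"].items():
    for site in sites:
        site_plates.setdefault(site, []).append(plate)

print(order)        # ['m', 'sd', 'obs']
print(edges)        # [('m', 'sd'), ('m', 'obs'), ('sd', 'obs')]
print(site_plates)  # {'obs': ['N']}
```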
Rendering a Model
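```python
import torch
from pyro.infer.inspect import render_model

# Reuses `model` from the previous example. Requires the graphviz Python
# package; writing a file also needs the Graphviz `dot` executable.
graph = render_model(model, model_args=(torch.randn(10),),
                     filename="model.png",
                     render_distributions=True)
```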
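render_model needs Graphviz installed. To see roughly what a plate diagram looks like as Graphviz input without it, here is a hand-rolled sketch that writes DOT source from the two relation maps shown above. It is not how render_model is implemented: the helper name relations_to_dot is hypothetical, and it assumes plates are not nested, so each site belongs to at most one plate cluster.

```python
def relations_to_dot(relations):
    """Emit Graphviz DOT source for sample_sample / plate_sample maps.

    Simplifying assumption: plates are not nested.
    """
    in_plate = {s for sites in relations["plate_sample"].values() for s in sites}
    lines = ["digraph {"]
    for plate, sites in relations["plate_sample"].items():
        # Graphviz draws subgraphs whose names start with "cluster" as boxes.
        lines.append(f'  subgraph "cluster_{plate}" {{')
        lines.append(f'    label="{plate}";')
        lines.extend(f'    "{s}";' for s in sites)
        lines.append("  }")
    # Sites outside every plate are declared at the top level.
    lines.extend(f'  "{s}";' for s in relations["sample_sample"] if s not in in_plate)
    for site, upstreams in relations["sample_sample"].items():
        lines.extend(f'  "{u}" -> "{site}";' for u in upstreams)
    lines.append("}")
    return "\n".join(lines)


relations = {
    "sample_sample": {"m": [], "sd": ["m"], "obs": ["m", "sd"]},
    "plate_sample": {"N": ["obs"]},
}
print(relations_to_dot(relations))
```

The printed text can be passed to any Graphviz `dot` renderer, which draws obs inside a box labelled N with edges from m and sd.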
Related Pages
- Pyro_ppl_Pyro_Abstract_Infer -- Base classes for inference that these inspection tools analyze
- Pyro_ppl_Pyro_Infer_Utilities -- Utility functions used throughout the inference subsystem
- Pyro_ppl_Pyro_TraceGraph_ELBO -- Uses dependency information for variance reduction