
Implementation:Pyro ppl Pyro Model Inspect

From Leeroopedia


Overview

The inspect module (pyro.infer.inspect) provides tools for analyzing the dependency structure of Pyro probabilistic models. It can infer prior and posterior dependencies among random variables, extract model relations (sample-to-sample, sample-to-distribution, plate-to-sample), and render models as plate diagrams using Graphviz.

The two primary functions are:

  • get_dependencies -- Infers both prior and posterior dependency structure by tracing the model with provenance tracking. Returns nested dictionaries describing which upstream variables each downstream variable depends on, and which plates induce full dependencies.
  • get_model_relations -- Extracts structural relations including sample-sample dependencies, sample-distribution mappings, plate-sample memberships, parameter constraints, and observed variables.

The module also provides render_model, a high-level function that combines get_model_relations, generate_graph_specification, and render_graph to produce a Graphviz visualization of a Pyro model's plate diagram.

Internally, the module uses ProvenanceTensor tracking via the TrackProvenance messenger to determine which upstream sample sites influence each downstream site's log-probability.
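
To illustrate the idea behind provenance tracking, here is a toy sketch in plain Python. It is not Pyro's ProvenanceTensor: the Tracked class and the sample helper are hypothetical names invented for this illustration. Each value carries the set of site names that influenced it, and every arithmetic operation takes the union of its operands' sets.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Tracked:
    """A number tagged with the names of the sample sites that influenced it."""

    value: float
    provenance: frozenset = frozenset()

    def _combine(self, other, op):
        # Plain Python numbers carry no provenance of their own.
        if not isinstance(other, Tracked):
            other = Tracked(float(other))
        return Tracked(op(self.value, other.value),
                       self.provenance | other.provenance)

    def __add__(self, other):
        return self._combine(other, lambda x, y: x + y)

    __radd__ = __add__

    def __mul__(self, other):
        return self._combine(other, lambda x, y: x * y)

    __rmul__ = __mul__


def sample(name, value):
    """Hypothetical stand-in for a sample site: tag the value with its own name."""
    return Tracked(value, frozenset({name}))


a = sample("a", 0.5)
b = sample("b", 2.0)
loc = a * 3.0 + 1.0   # influenced by "a" only
scale = a + b         # influenced by both "a" and "b"

print(sorted(loc.provenance))    # ['a']
print(sorted(scale.provenance))  # ['a', 'b']
```

Reading the provenance of a value that feeds a downstream site's distribution then answers the question "which upstream sites does this site depend on?" without inspecting the model's source code.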

Code Reference

File: pyro/infer/inspect.py

Key Functions

Function | Description
get_dependencies | Infers prior and posterior dependency structure. Returns a dict with prior_dependencies and posterior_dependencies.
get_model_relations | Infers relations of random variables and plates. Returns a dict with sample_sample, sample_param, sample_dist, param_constraint, plate_sample, and observed.
render_model | Renders a model as a Graphviz diagram. Saves it to a file if a filename is provided.
generate_graph_specification | Converts model relations into a graph specification dict with plate groups, node data, and edge lists.
render_graph | Creates a graphviz.Digraph object from a graph specification.
is_sample_site | Helper that determines whether a trace message represents a valid sample site.

Key Classes

Class | Description
TrackProvenance | A Messenger that annotates sample and param values with provenance information using ProvenanceTensor. Used internally by get_dependencies and get_model_relations.

I/O Contract

get_dependencies

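Inputs: The model callable, plus optional model_args and model_kwargs that are passed through to the model when it is traced.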

Output: A dictionary with two keys:

  • prior_dependencies -- Dict mapping each downstream variable to a dict of upstream latent variables it depends on, each mapped to a set of plates inducing full dependencies.
  • posterior_dependencies -- Dict mapping each latent variable to a dict of latent/observed variables it depends on in the posterior, with associated plate sets.
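
Because the result is made of plain nested dicts, it can be post-processed with ordinary Python. The sketch below inverts a prior-dependency dict to list, for each variable, the variables that depend on it. The helper name invert_dependencies is hypothetical (it is not part of Pyro), and the dict literal copies the prior dependencies printed in this page's usage example.

```python
def invert_dependencies(prior_dependencies):
    """Map each upstream variable to the sorted list of variables that depend on it."""
    children = {name: [] for name in prior_dependencies}
    for downstream, upstreams in prior_dependencies.items():
        for upstream in upstreams:
            if upstream != downstream:  # skip the self-dependency entry
                children.setdefault(upstream, []).append(downstream)
    return {name: sorted(deps) for name, deps in children.items()}


# Same structure as deps["prior_dependencies"] in the usage example.
prior = {
    "a": {"a": set()},
    "b": {"b": set()},
    "c": {"a": set(), "b": set(), "c": set()},
    "d": {"c": set(), "d": set()},
}

children = invert_dependencies(prior)
print(children)  # {'a': ['c'], 'b': ['c'], 'c': ['d'], 'd': []}
```

The plate sets (the inner set() values) are ignored here; a fuller version would carry them along to record which plates induce full dependencies.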

get_model_relations

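Inputs: The model callable, plus optional model_args (a tuple of positional arguments, as in the usage example) and model_kwargs that are passed through to the model.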

Output: A dictionary with keys:

  • sample_sample -- Dict mapping each sample site to a list of upstream sample sites.
  • sample_param -- Dict mapping each sample site to a list of upstream param sites.
  • sample_dist -- Dict mapping each sample site to its distribution name (string).
  • param_constraint -- Dict mapping each param name to its constraint string.
  • plate_sample -- Dict mapping each plate name to a list of contained sample sites.
  • observed -- List of observed sample site names.
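
To show how such a relations dict can drive a plate diagram, the following sketch builds Graphviz DOT text by hand. It is a simplified illustration under stated assumptions, not the logic of generate_graph_specification or render_graph: the helper relations_to_dot is hypothetical, and the input dict mirrors the usage example on this page, with the observed entry assumed from the model definition.

```python
def relations_to_dot(relations):
    """Build Graphviz DOT text from a relations dict (hypothetical helper)."""
    observed = set(relations.get("observed", []))

    def node(site):
        # Shade observed sites, as is conventional in plate diagrams.
        style = " [style=filled, fillcolor=gray]" if site in observed else ""
        return f'"{site}"{style};'

    lines = ["digraph model {"]
    plated = set()
    # One cluster per plate, holding the sites declared inside it.
    for plate, sites in relations.get("plate_sample", {}).items():
        lines.append(f'  subgraph "cluster_{plate}" {{')
        lines.append(f'    label="{plate}";')
        for site in sites:
            lines.append("    " + node(site))
            plated.add(site)
        lines.append("  }")
    # Sites outside every plate are declared at the top level.
    for site in relations["sample_sample"]:
        if site not in plated:
            lines.append("  " + node(site))
    # One edge per upstream -> downstream dependency.
    for downstream, upstreams in relations["sample_sample"].items():
        for upstream in upstreams:
            lines.append(f'  "{upstream}" -> "{downstream}";')
    lines.append("}")
    return "\n".join(lines)


relations = {
    "sample_sample": {"m": [], "sd": ["m"], "obs": ["m", "sd"]},
    "plate_sample": {"N": ["obs"]},
    "observed": ["obs"],  # assumed: "obs" is the only site with obs=...
}
dot = relations_to_dot(relations)
print(dot)
```

This sketch does not handle nested plates or a site that belongs to several plates; it only demonstrates how the sample_sample, plate_sample, and observed entries map onto nodes, clusters, and edges.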

render_model

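Inputs: The model callable, plus optional model_args and model_kwargs for the model, a filename for saving the diagram, and a render_distributions flag (see the usage example).

Output: A graphviz.Digraph object. The diagram is also saved to a file if a filename is provided.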

Usage Examples

Inferring Model Dependencies

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.inspect import get_dependencies

def model():
    a = pyro.sample("a", dist.Normal(0, 1))
    b = pyro.sample("b", dist.LogNormal(0, 1))
    c = pyro.sample("c", dist.Normal(a, b))
    pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0))

deps = get_dependencies(model)
print(deps["prior_dependencies"])
# {'a': {'a': set()}, 'b': {'b': set()},
#  'c': {'a': set(), 'b': set(), 'c': set()},
#  'd': {'c': set(), 'd': set()}}

print(deps["posterior_dependencies"])
# {'a': {'a': set(), 'b': set(), 'c': set()},
#  'b': {'b': set(), 'c': set()},
#  'c': {'c': set(), 'd': set()}}

Extracting Model Relations

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.inspect import get_model_relations

def model(data):
    m = pyro.sample("m", dist.Normal(0, 1))
    sd = pyro.sample("sd", dist.LogNormal(m, 1))
    with pyro.plate("N", len(data)):
        pyro.sample("obs", dist.Normal(m, sd), obs=data)

relations = get_model_relations(model, model_args=(torch.randn(10),))
print(relations["sample_sample"])
# {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']}
print(relations["plate_sample"])
# {'N': ['obs']}

Rendering a Model

import torch
from pyro.infer.inspect import render_model

# Reuses model(data) from the previous example; requires Graphviz.

graph = render_model(model, model_args=(torch.randn(10),),
                     filename="model.png",
                     render_distributions=True)
