

Implementation:Pyro ppl Pyro Trace Struct

From Leeroopedia
Revision as of 16:26, 16 February 2026 by Admin (auto-imported from implementations/Pyro_ppl_Pyro_Trace_Struct.md)


Attribute | Value
--- | ---
File | pyro/poutine/trace_struct.py
Module | pyro.poutine.trace_struct
Lines | 581
Purpose | Directed graph data structure for recording Pyro program execution traces
Architecture Role | Core data structure used by TraceMessenger and all inference algorithms
License | Apache-2.0 (Uber Technologies, Inc.)

Overview

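trace_struct.py defines the Trace class, a directed graph data structure that records every call to pyro.sample() and pyro.param() during a single execution of a Pyro program. When recorded with poutine.trace, a trace also holds the _INPUT and _RETURN sites, which store the program's arguments and return value. Traces are the primary data exchange format between models, guides, and inference algorithms.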

The Trace class is implemented as a lightweight directed graph with:

  • Nodes -- An OrderedDict mapping site names to Message dicts (containing site type, distribution, value, log probability, scale, mask, and conditional independence metadata).
  • Edges -- Adjacency lists (_succ and _pred) encoding conditional dependence relationships.
  • Graph types -- Either "flat" (no dependency tracking) or "dense" (full dependency edges).
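A minimal pure-Python sketch can make the node and edge layout concrete. The class name MiniTrace and the rule that only param sites may be recorded twice are illustrative assumptions, not Pyro's code. The sketch only shows how nodes keeps insertion (execution) order and how each edge is stored in both _succ and _pred.

```python
from collections import OrderedDict
from typing import Any, Dict, Set


class MiniTrace:
    """Hypothetical toy version of the node/edge layout described above."""

    def __init__(self) -> None:
        self.nodes: Dict[str, Dict[str, Any]] = OrderedDict()  # site name -> site dict
        self._succ: Dict[str, Set[str]] = OrderedDict()  # site -> its children
        self._pred: Dict[str, Set[str]] = OrderedDict()  # site -> its parents

    def add_node(self, name: str, **kwargs: Any) -> None:
        # Assumption for this sketch: only param sites may be recorded twice.
        if name in self.nodes and kwargs.get("type") != "param":
            raise RuntimeError(f"Multiple sites named {name!r}")
        self.nodes[name] = kwargs
        self._succ[name] = set()
        self._pred[name] = set()

    def add_edge(self, site1: str, site2: str) -> None:
        # Store the edge in both directions so parents and children are cheap lookups.
        self._succ[site1].add(site2)
        self._pred[site2].add(site1)

    def remove_node(self, name: str) -> None:
        # Detach the node from its neighbours, then drop its own entries.
        for parent in self._pred.pop(name):
            self._succ[parent].discard(name)
        for child in self._succ.pop(name):
            self._pred[child].discard(name)
        del self.nodes[name]


tr = MiniTrace()
tr.add_node("loc", type="sample", value=0.3, is_observed=False)
tr.add_node("x", type="sample", value=1.2, is_observed=True)
tr.add_edge("loc", "x")  # x depends on loc
```

Keeping both adjacency maps costs a little memory, but it lets the structure answer both "who depends on this site" and "what does this site depend on" without scanning every edge.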

Key capabilities include:

  • Computing log_prob_sum() -- the total log probability across all sample sites.
  • Computing per-site log_prob and score_parts for gradient estimation.
  • Topological sorting of nodes for dependency-aware processing.
  • Shape formatting via format_shapes() for debugging.
  • Tensor packing via pack_tensors() and symbolize_dims() for efficient enumeration.
  • Convenient properties for accessing observation_nodes, stochastic_nodes, param_nodes, and reparameterized_nodes.

Code Reference

Trace Class

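from collections import OrderedDict
from typing import TYPE_CHECKING, Literal, Set

if TYPE_CHECKING:
    from pyro.poutine.runtime import Message

class Trace:
    def __init__(self, graph_type: Literal["flat", "dense"] = "flat") -> None:
        self.graph_type = graph_type
        self.nodes: OrderedDict[str, "Message"] = OrderedDict()  # site name -> Message, in execution order
        self._succ: OrderedDict[str, Set[str]] = OrderedDict()  # successors of each site
        self._pred: OrderedDict[str, Set[str]] = OrderedDict()  # predecessors of each site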

Graph Operations

def add_node(self, site_name: str, **kwargs) -> None:
    """Adds a site to the trace. Raises RuntimeError on a duplicate sample site or a site-type conflict."""

def add_edge(self, site1: str, site2: str) -> None:
    """Adds a directed edge from site1 to site2."""

def remove_node(self, site_name: str) -> None:
    """Removes a site and its associated edges."""

def topological_sort(self, reverse: bool = False) -> List[str]:
    """Return nodes in topologically sorted order."""

def copy(self) -> "Trace":
    """Makes a shallow copy with nodes and edges preserved."""
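topological_sort() orders sites so that every site comes after the sites it depends on. One standard way to do this over the _succ/_pred adjacency sets is Kahn's algorithm, sketched below as a free function. The function and its tie-breaking (children visited in sorted order) are assumptions. Pyro's own implementation may traverse the graph differently, so independent sites can come out in a different order.

```python
from collections import deque
from typing import Dict, List, Set


def topological_sort(succ: Dict[str, Set[str]], pred: Dict[str, Set[str]],
                     reverse: bool = False) -> List[str]:
    # Number of not-yet-emitted parents for every site.
    in_degree = {name: len(parents) for name, parents in pred.items()}
    # Sites without parents are ready immediately, in insertion order.
    ready = deque(name for name, degree in in_degree.items() if degree == 0)
    order: List[str] = []
    while ready:
        name = ready.popleft()
        order.append(name)
        for child in sorted(succ[name]):  # sorted for a deterministic order
            in_degree[child] -= 1
            if in_degree[child] == 0:
                ready.append(child)
    if len(order) != len(pred):
        raise ValueError("trace graph contains a cycle")
    return order[::-1] if reverse else order


# a -> b, a -> c, b -> d, c -> d
succ = {"a": {"b", "c"}, "b": {"d"}, "c": {"d"}, "d": set()}
pred = {"a": set(), "b": {"a"}, "c": {"a"}, "d": {"b", "c"}}
order = topological_sort(succ, pred)
```

In a "flat" trace there are no edges, so in this sketch every site is ready at once and the result is simply insertion order.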

Log Probability Computation

def log_prob_sum(self, site_filter=allow_all_sites):
    """
    Compute total log probability across all sample sites.
    Each log_prob_sum is a scalar. Computation is memoized.
    """
    result = 0.0
    for name, site in self.nodes.items():
        # Only sample sites accepted by the filter contribute.
        if site["type"] == "sample" and site_filter(name, site):
            if "log_prob_sum" in site:
                # Reuse the value cached by an earlier call.
                log_p = site["log_prob_sum"]
            else:
                # Remaining log_prob arguments are elided in this excerpt.
                log_p = site["fn"].log_prob(site["value"], ...)
                # Apply the site's scale and mask, then reduce to a scalar.
                log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum()
                site["log_prob_sum"] = log_p
            result = result + log_p
    return result

def compute_log_prob(self, site_filter=allow_all_sites) -> None:
    """Compute per-site log_prob (batched). Memoized."""

def compute_score_parts(self) -> None:
    """Compute batched local score parts for gradient estimation."""
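The memoized summation and the companion compute_score_parts() can be mimicked with plain floats. This sketch makes several simplifications. Each site carries a hypothetical log_prob_fn callable that stands in for site["fn"].log_prob, and the mask is a single boolean. Score parts follow the field names of Pyro's ScoreParts (log_prob, score_function, entropy_term), with scale and mask handling left out. None of this is Pyro's code.

```python
import math
from collections import namedtuple
from typing import Any, Callable, Dict

Site = Dict[str, Any]

# Field names follow Pyro's ScoreParts; this namedtuple is a stand-in.
ScoreParts = namedtuple("ScoreParts", ["log_prob", "score_function", "entropy_term"])


def normal_log_prob(x: float, loc: float, scale: float) -> float:
    """Log density of a univariate Normal(loc, scale) at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def allow_all_sites(name: str, site: Site) -> bool:
    return True


def scaled_masked(log_p: float, site: Site) -> float:
    # A masked-out site contributes zero; otherwise multiply by its scale.
    return log_p * site.get("scale", 1.0) if site.get("mask", True) else 0.0


def log_prob_sum(nodes: Dict[str, Site],
                 site_filter: Callable[[str, Site], bool] = allow_all_sites) -> float:
    result = 0.0
    for name, site in nodes.items():
        if site["type"] == "sample" and site_filter(name, site):
            if "log_prob_sum" not in site:
                # Compute once and cache on the site dict (memoization).
                site["log_prob_sum"] = scaled_masked(site["log_prob_fn"](site["value"]), site)
            result += site["log_prob_sum"]
    return result


def compute_score_parts(nodes: Dict[str, Site]) -> None:
    for site in nodes.values():
        if site["type"] != "sample" or "score_parts" in site:
            continue
        log_p = site["log_prob_fn"](site["value"])
        if site["has_rsample"]:
            # Reparameterized: gradients flow through the sampled value itself.
            parts = ScoreParts(log_p, 0.0, log_p)
        else:
            # Not reparameterized: log_prob doubles as the score-function (REINFORCE) term.
            parts = ScoreParts(log_p, log_p, 0.0)
        site["score_parts"] = parts
        site["log_prob"] = parts.log_prob
        site["log_prob_sum"] = parts.log_prob  # already a scalar in this toy


nodes = {
    "_INPUT": {"type": "args"},
    "z": {"type": "sample", "value": 0.0, "is_observed": False, "has_rsample": True,
          "log_prob_fn": lambda v: normal_log_prob(v, 0.0, 1.0)},
    "x": {"type": "sample", "value": 1.0, "is_observed": True, "has_rsample": True,
          "scale": 2.0, "log_prob_fn": lambda v: normal_log_prob(v, 0.0, 1.0)},
}
total = log_prob_sum(nodes)
observed_only = log_prob_sum(nodes, lambda name, site: site["is_observed"])
```

Because the cache lives on the site dict, changing a site's value after the first call does not refresh its cached log_prob_sum, at least in this sketch.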

Node Properties

@property
def observation_nodes(self) -> List[str]:
    """List of names of observed sample sites."""

@property
def stochastic_nodes(self) -> List[str]:
    """List of names of unobserved sample sites."""

@property
def param_nodes(self) -> List[str]:
    """List of names of param sites."""

@property
def reparameterized_nodes(self) -> List[str]:
    """List of names of sample sites with reparameterizable distributions."""

@property
def nonreparam_stochastic_nodes(self) -> List[str]:
    """List of names of non-reparameterizable sample sites."""
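These properties are filters over nodes. The sketch below reproduces the filtering on plain dicts. Two choices are assumptions of the sketch: has_rsample is stored directly on each site rather than on the site's distribution, and observed sites are never counted as reparameterized.

```python
from typing import Any, Dict, List

Site = Dict[str, Any]


def _is_sample(site: Site) -> bool:
    return site["type"] == "sample"


def observation_nodes(nodes: Dict[str, Site]) -> List[str]:
    return [n for n, s in nodes.items() if _is_sample(s) and s["is_observed"]]


def stochastic_nodes(nodes: Dict[str, Site]) -> List[str]:
    return [n for n, s in nodes.items() if _is_sample(s) and not s["is_observed"]]


def param_nodes(nodes: Dict[str, Site]) -> List[str]:
    return [n for n, s in nodes.items() if s["type"] == "param"]


def reparameterized_nodes(nodes: Dict[str, Site]) -> List[str]:
    # Unobserved sample sites whose distribution supports rsample().
    return [n for n in stochastic_nodes(nodes) if nodes[n].get("has_rsample", False)]


def nonreparam_stochastic_nodes(nodes: Dict[str, Site]) -> List[str]:
    reparam = set(reparameterized_nodes(nodes))
    return [n for n in stochastic_nodes(nodes) if n not in reparam]


nodes = {
    "_INPUT": {"type": "args"},
    "s": {"type": "param"},
    "z": {"type": "sample", "is_observed": False, "has_rsample": True},   # e.g. Normal
    "k": {"type": "sample", "is_observed": False, "has_rsample": False},  # e.g. Categorical
    "x": {"type": "sample", "is_observed": True, "has_rsample": True},
    "_RETURN": {"type": "return"},
}
```

The _INPUT and _RETURN sites have types other than "sample" and "param", so none of the filters pick them up.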

Tensor Packing and Shape Formatting

def symbolize_dims(self, plate_to_symbol=None) -> None:
    """Assign unique symbols to all tensor dimensions."""

def pack_tensors(self, plate_to_symbol=None) -> None:
    """Compute packed representations for efficient enumeration."""

def format_shapes(self, title="Trace Shapes:", last_site=None) -> str:
    """Return a formatted table of shapes at all sites (for debugging)."""
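Packing can be pictured at the level of shapes. Size-1 dimensions that exist only for broadcasting are dropped, and each remaining dimension is tagged with a symbol, so that a later einsum-style contraction can line dimensions up by symbol rather than by position. The function pack_shape below is a hypothetical illustration of that idea on shape tuples only. It is not pack_tensors() or symbolize_dims(), which operate on tensors and plate metadata.

```python
from typing import Dict, List, Tuple


def pack_shape(shape: Tuple[int, ...],
               dim_to_symbol: Dict[int, str]) -> Tuple[Tuple[int, ...], str]:
    """Drop size-1 dims and label the rest with symbols (toy illustration)."""
    packed_shape: List[int] = []
    symbols: List[str] = []
    # Walk dims with negative indices, matching how batch dims are usually counted.
    for dim, size in zip(range(-len(shape), 0), shape):
        if size == 1:
            continue  # a broadcast dimension carries no information, so drop it
        packed_shape.append(size)
        symbols.append(dim_to_symbol[dim])
    return tuple(packed_shape), "".join(symbols)


packed, dims = pack_shape((3, 1, 2), {-3: "a", -2: "b", -1: "c"})
```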

I/O Contract

Method | Input | Output
--- | --- | ---
__init__(graph_type) | "flat" or "dense" | A new empty Trace instance
add_node(name, **kwargs) | Site name and Message fields as keyword arguments | None (mutates the trace in place)
log_prob_sum(site_filter) | Optional site filter function (name, site) -> bool | Scalar torch.Tensor with the total log probability (the float 0.0 if no site passes the filter)
compute_log_prob(site_filter) | Optional site filter function | None (populates the log_prob and log_prob_sum fields of each matching sample site)
compute_score_parts() | (none) | None (populates score_parts, log_prob, and log_prob_sum in each sample site)
topological_sort(reverse) | Optional bool for reverse order | List[str] of site names in topological order
format_shapes(title, last_site) | Optional title string and last site name | A formatted str table showing shapes at all sites
copy() | (none) | A shallow copy of the Trace
detach_() | (none) | None (detaches values in place at each sample site)
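Because copy() is shallow, a copied trace gets its own top-level containers but shares the per-site dicts with the original. Assuming the adjacency dicts are copied the same way, it also shares the edge sets. The plain-dict snippet below shows the practical consequence. It models the containers with OrderedDict.copy(), which is an assumption about how the copy is made. Adding keys on the copy stays isolated, while mutating a shared site dict or edge set is visible through both.

```python
from collections import OrderedDict

original_nodes = OrderedDict(z={"type": "sample", "value": 0.5})
original_succ = OrderedDict(z=set())

# Shallow copies: new outer dicts, same inner site dicts and edge sets.
copied_nodes = original_nodes.copy()
copied_succ = original_succ.copy()

copied_nodes["w"] = {"type": "param", "value": 1.0}  # only the copy gains "w"
copied_nodes["z"]["value"] = 9.9                    # shared site dict: both see 9.9
copied_succ["z"].add("w")                           # shared edge set: both see the edge
```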

Usage Examples

Creating and Inspecting a Trace

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

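# Run the model once with x=0.0, recording every site
trace = poutine.trace(model).get_trace(0.0)
logp = trace.log_prob_sum()  # total log probability of the sample sites
# Unconstrained tensors backing each param site
params = [trace.nodes[name]["value"].unconstrained()
          for name in trace.param_nodes]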

Inspecting Trace Nodes

# Node names
list(trace.nodes.keys())  # ["_INPUT", "s", "z", "_RETURN"]

# Node metadata
trace.nodes["z"]
# {'type': 'sample', 'name': 'z', 'is_observed': False,
#  'fn': Normal(), 'value': tensor(0.6480), ...}

# Categorized node lists
trace.stochastic_nodes    # ["z"]
trace.param_nodes         # ["s"]
trace.observation_nodes   # []

Debugging with format_shapes

print(trace.format_shapes())
# Trace Shapes:
#  Param Sites:
#            s
#  Sample Sites:
#         z dist |
#          value |
