Implementation:Pyro ppl Pyro Trace Struct
| Attribute | Value |
|---|---|
| File | pyro/poutine/trace_struct.py |
| Module | pyro.poutine.trace_struct |
| Lines | 581 |
| Purpose | Directed graph data structure for recording Pyro program execution traces |
| Architecture Role | Core data structure used by TraceMessenger and all inference algorithms |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
trace_struct.py defines the Trace class, a directed graph data structure that records every call to pyro.sample() and pyro.param() during a single execution of a Pyro program. Traces are the primary data exchange format between models, guides, and inference algorithms.
The Trace class is implemented as a lightweight directed graph with:
- Nodes -- An OrderedDict mapping site names to Message dicts (containing site type, distribution, value, log probability, scale, mask, and conditional independence metadata).
- Edges -- Adjacency lists (_succ and _pred) encoding conditional dependence relationships.
- Graph types -- Either "flat" (no dependency tracking) or "dense" (full dependency edges).
Key capabilities include:
- Computing log_prob_sum() -- the total log probability across all sample sites.
- Computing per-site log_prob and score_parts for gradient estimation.
- Topological sorting of nodes for dependency-aware processing.
- Shape formatting via format_shapes() for debugging.
- Tensor packing via pack_tensors() and symbolize_dims() for efficient enumeration.
- Convenient properties for accessing observation_nodes, stochastic_nodes, param_nodes, and reparameterized_nodes.
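To make the topological sorting capability concrete, here is a minimal sketch that orders sites from a predecessor map shaped like _pred. It uses Kahn's algorithm, which is swapped in purely for illustration; this page does not say which algorithm trace_struct.py uses. The standalone function and the site names are hypothetical.

```python
from collections import OrderedDict, deque


def topological_sort(pred, reverse=False):
    """Kahn's algorithm over a {site: set_of_predecessor_sites} mapping."""
    # Number of not-yet-emitted predecessors for each site.
    remaining = {name: len(parents) for name, parents in pred.items()}
    # Invert the predecessor map to get successors.
    succ = {name: [] for name in pred}
    for child, parents in pred.items():
        for parent in parents:
            succ[parent].append(child)

    ready = deque(name for name, count in remaining.items() if count == 0)
    order = []
    while ready:
        name = ready.popleft()
        order.append(name)
        for child in succ[name]:
            remaining[child] -= 1
            if remaining[child] == 0:
                ready.append(child)

    if len(order) != len(pred):
        raise ValueError("graph contains a cycle")
    return order[::-1] if reverse else order


# Hypothetical sites: "x" depends on both "loc" and "scale".
pred = OrderedDict([("x", {"loc", "scale"}), ("loc", set()), ("scale", set())])
order = topological_sort(pred)
reverse_order = topological_sort(pred, reverse=True)
```

In this sketch every site appears after all of its predecessors, and reverse=True simply flips the list so that dependents come first.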
Code Reference
Trace Class
class Trace:
    def __init__(self, graph_type: Literal["flat", "dense"] = "flat") -> None:
        self.graph_type = graph_type
        self.nodes: OrderedDict[str, "Message"] = OrderedDict()
        self._succ: OrderedDict[str, Set[str]] = OrderedDict()
        self._pred: OrderedDict[str, Set[str]] = OrderedDict()
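The three containers initialized above have to stay in sync whenever sites or edges change. The following is a minimal sketch of how that bookkeeping could look, assuming a toy class named MiniTrace (hypothetical, not Pyro's code). The exception types and the choice to reject edges between unknown sites are assumptions of the sketch, not statements about Trace.

```python
from collections import OrderedDict


class MiniTrace:
    """Toy directed graph keyed by site name (illustrative only)."""

    def __init__(self, graph_type="flat"):
        if graph_type not in ("flat", "dense"):
            raise ValueError(f"unknown graph_type: {graph_type!r}")
        self.graph_type = graph_type
        self.nodes = OrderedDict()  # site name -> dict of message fields
        self._succ = OrderedDict()  # site name -> set of successor names
        self._pred = OrderedDict()  # site name -> set of predecessor names

    def add_node(self, site_name, **kwargs):
        # Site names must be unique within one trace.
        if site_name in self.nodes:
            raise RuntimeError(f"duplicate site: {site_name!r}")
        self.nodes[site_name] = kwargs
        self._succ[site_name] = set()
        self._pred[site_name] = set()

    def add_edge(self, site1, site2):
        # Directed edge site1 -> site2, recorded in both adjacency maps.
        for name in (site1, site2):
            if name not in self.nodes:
                raise KeyError(name)
        self._succ[site1].add(site2)
        self._pred[site2].add(site1)

    def remove_node(self, site_name):
        # Drop the site and scrub it from its neighbours' adjacency sets.
        self.nodes.pop(site_name)
        for parent in self._pred.pop(site_name):
            self._succ[parent].discard(site_name)
        for child in self._succ.pop(site_name):
            self._pred[child].discard(site_name)


tr = MiniTrace(graph_type="dense")
tr.add_node("loc", type="sample", is_observed=False)
tr.add_node("x", type="sample", is_observed=True)
tr.add_edge("loc", "x")
```

Storing both successor and predecessor sets duplicates each edge, but it makes lookups cheap in either direction, which is convenient when walking dependencies forwards or backwards.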
Graph Operations
def add_node(self, site_name: str, **kwargs) -> None:
    """Adds a site to the trace. Raises error on duplicates."""

def add_edge(self, site1: str, site2: str) -> None:
    """Adds a directed edge between two sites."""

def remove_node(self, site_name: str) -> None:
    """Removes a site and its associated edges."""

def topological_sort(self, reverse: bool = False) -> List[str]:
    """Return nodes in topologically sorted order."""

def copy(self) -> "Trace":
    """Makes a shallow copy with nodes and edges preserved."""
Log Probability Computation
def log_prob_sum(self, site_filter=allow_all_sites):
    """
    Compute total log probability across all sample sites.
    Each log_prob_sum is a scalar. Computation is memoized.
    """
    result = 0.0
    for name, site in self.nodes.items():
        if site["type"] == "sample" and site_filter(name, site):
            if "log_prob_sum" in site:
                log_p = site["log_prob_sum"]
            else:
                log_p = site["fn"].log_prob(site["value"], ...)
                log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum()
                site["log_prob_sum"] = log_p
            result = result + log_p
    return result

def compute_log_prob(self, site_filter=allow_all_sites) -> None:
    """Compute per-site log_prob (batched). Memoized."""

def compute_score_parts(self) -> None:
    """Compute batched local score parts for gradient estimation."""
Node Properties
@property
def observation_nodes(self) -> List[str]:
    """List of names of observed sample sites."""

@property
def stochastic_nodes(self) -> List[str]:
    """List of names of unobserved sample sites."""

@property
def param_nodes(self) -> List[str]:
    """List of names of param sites."""

@property
def reparameterized_nodes(self) -> List[str]:
    """List of names of sample sites with reparameterizable distributions."""

@property
def nonreparam_stochastic_nodes(self) -> List[str]:
    """List of names of non-reparameterizable sample sites."""
Tensor Packing and Shape Formatting
def symbolize_dims(self, plate_to_symbol=None) -> None:
    """Assign unique symbols to all tensor dimensions."""

def pack_tensors(self, plate_to_symbol=None) -> None:
    """Compute packed representations for efficient enumeration."""

def format_shapes(self, title="Trace Shapes:", last_site=None) -> str:
    """Return a formatted table of shapes at all sites (for debugging)."""
I/O Contract
| Method | Input | Output |
|---|---|---|
| __init__(graph_type) | "flat" or "dense" | A new empty Trace instance |
| add_node(name, **kwargs) | Site name and Message fields as keyword arguments | None (mutates trace in place) |
| log_prob_sum(site_filter) | Optional site filter function (name, site) -> bool | torch.Tensor scalar (total log probability) |
| compute_log_prob(site_filter) | Optional site filter function | None (populates log_prob and log_prob_sum fields in each node) |
| compute_score_parts() | (none) | None (populates score_parts, log_prob, log_prob_sum in each node) |
| topological_sort(reverse) | Optional bool for reverse order | List[str] of site names in topological order |
| format_shapes(title, last_site) | Optional title string and last site name | A formatted str table showing shapes at all sites |
| copy() | (none) | A shallow copy of the Trace |
| detach_() | (none) | None (detaches values in place at each sample site) |
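The log_prob_sum contract (a (name, site) -> bool filter plus results cached on each node) can be illustrated without tensors. The sketch below is a pure-Python stand-in under stated assumptions: every site is a standard normal, normal_log_prob replaces the distribution's log_prob call, and masking is omitted. All names other than log_prob_sum, site_filter, and allow_all_sites are hypothetical.

```python
import math


def normal_log_prob(value, loc=0.0, scale=1.0):
    """Log-density of a univariate normal (stand-in for a distribution object)."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def allow_all_sites(name, site):
    return True


def log_prob_sum(nodes, site_filter=allow_all_sites):
    """Sum log-probabilities over filtered sample sites, caching per site."""
    total = 0.0
    for name, site in nodes.items():
        if site["type"] == "sample" and site_filter(name, site):
            if "log_prob_sum" not in site:
                # First visit: compute once and store the result on the node.
                site["log_prob_sum"] = site["scale"] * normal_log_prob(site["value"])
            total += site["log_prob_sum"]
    return total


nodes = {
    "s": {"type": "param", "value": 0.5},
    "z": {"type": "sample", "is_observed": False, "value": 0.0, "scale": 1.0},
    "x": {"type": "sample", "is_observed": True, "value": 1.0, "scale": 1.0},
}

# Only observed sites contribute, and only they are cached by this call.
observed_only = log_prob_sum(nodes, lambda name, site: site["is_observed"])
# The second call reuses the cached value for "x" and computes "z" fresh.
everything = log_prob_sum(nodes)
```

Because the cache lives on the node itself, a filtered call only populates the sites it actually visits; later calls with a wider filter fill in the rest.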
Usage Examples
Creating and Inspecting a Trace
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

# Run the model once and record every sample and param site
trace = poutine.trace(model).get_trace(0.0)
logp = trace.log_prob_sum()
# Unconstrained tensors behind each param site
params = [trace.nodes[name]["value"].unconstrained()
          for name in trace.param_nodes]
Inspecting Trace Nodes
# Node names
list(trace.nodes.keys()) # ["_INPUT", "s", "z", "_RETURN"]
# Node metadata
trace.nodes["z"]
# {'type': 'sample', 'name': 'z', 'is_observed': False,
# 'fn': Normal(), 'value': tensor(0.6480), ...}
# Categorized node lists
trace.stochastic_nodes # ["z"]
trace.param_nodes # ["s"]
trace.observation_nodes # []
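The categorized lists above can be read as simple filters over the node dictionary. Here is a minimal sketch on plain dicts, assuming each site carries "type" and "is_observed" fields and that reparameterizability is signalled by a has_rsample attribute on the distribution (the flag torch distributions expose). FakeDist, the site names, and the choice to exclude observed sites from the reparameterized list are assumptions of this sketch.

```python
class FakeDist:
    """Hypothetical stand-in for a distribution object."""

    def __init__(self, has_rsample):
        self.has_rsample = has_rsample


nodes = {
    "s": {"type": "param"},
    "z": {"type": "sample", "is_observed": False, "fn": FakeDist(True)},
    "k": {"type": "sample", "is_observed": False, "fn": FakeDist(False)},
    "x": {"type": "sample", "is_observed": True, "fn": FakeDist(True)},
}


def is_latent(node):
    # Unobserved sample site.
    return node["type"] == "sample" and not node["is_observed"]


def observation_nodes(nodes):
    return [n for n, node in nodes.items()
            if node["type"] == "sample" and node["is_observed"]]


def stochastic_nodes(nodes):
    return [n for n, node in nodes.items() if is_latent(node)]


def param_nodes(nodes):
    return [n for n, node in nodes.items() if node["type"] == "param"]


def reparameterized_nodes(nodes):
    return [n for n, node in nodes.items()
            if is_latent(node) and getattr(node["fn"], "has_rsample", False)]


def nonreparam_stochastic_nodes(nodes):
    return [n for n, node in nodes.items()
            if is_latent(node) and not getattr(node["fn"], "has_rsample", False)]


latent = stochastic_nodes(nodes)
reparam = reparameterized_nodes(nodes)
nonreparam = nonreparam_stochastic_nodes(nodes)
```

Under these assumptions the reparameterized and non-reparameterized lists partition the latent sites, which is the split a gradient estimator would care about.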
Debugging with format_shapes
print(trace.format_shapes())
# Trace Shapes:
# Param Sites:
# s
# Sample Sites:
# z dist |
# value |
Related Pages
- Pyro_ppl_Pyro_Poutine_Runtime -- Defines the Message type stored in trace nodes
- Pyro_ppl_Pyro_Poutine_Handlers -- The poutine.trace() handler that populates Trace objects
- Pyro_ppl_Pyro_Messenger_Base -- Base Messenger class whose _postprocess_message populates traces
- Pyro_ppl_Pyro_IndepMessenger -- Provides CondIndepStackFrame data stored in traces
- Pyro_ppl_Pyro_GuideMessenger -- Uses Trace to extract model/guide trace pairs