Implementation:Pyro ppl Pyro LazyJIT

From Leeroopedia


Property Value
Module pyro.ops.jit
Source pyro/ops/jit.py
Lines 164
Classes CompiledFunction
Functions trace
Dependencies torch, pyro, pyro.poutine, pyro.util

Overview

This module provides a lazy replacement for torch.jit.trace that works with Pyro functions calling pyro.param. The standard torch.jit.trace cannot handle Pyro's dynamic parameter management, where parameters are looked up from a global parameter store. This module solves that by automatically capturing parameters during the first invocation and plumbing them through as explicit arguments to the traced function.

The key class is CompiledFunction, which wraps a Pyro function. On the first call for a given argument signature, it captures the parameters the function touches, builds a JIT-compiled version that takes both those parameters and the regular arguments, and caches it under that signature. Later calls with the same signature reuse the cached compiled function.

Code Reference

Function: trace

@pyro.ops.jit.trace
def model_log_prob_fn(x, y):
    cond_model = pyro.condition(model, data={"y": y})
    tr = pyro.poutine.trace(cond_model).get_trace(x)
    return tr.log_prob_sum()

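trace can be used as a decorator or called directly on a function. It accepts two optional parameters: ignore_warnings, which controls whether JIT warnings are ignored, and jit_options, a dict of options passed on to torch.jit.trace.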

Class: CompiledFunction

Wraps the output of torch.jit.trace with parameter management.

Attributes:

  • fn: The original function.
  • compiled: Dict mapping argument signature keys to compiled callables.
  • compile_time: Optional compilation timing (if jit_options["time_compilation"]=True).

Behavior on first call:

  1. Uses poutine.trace(param_only=True) to capture all parameter names.
  2. Retrieves unconstrained parameter values from the Pyro param store.
  3. Prepends parameters to the argument list.
  4. Creates a wrapper function that replays parameters and calls the original function.
  5. Traces the wrapper with torch.jit.trace (with check_trace=False by default).
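The steps above can be mimicked in plain Python to make the parameter plumbing visible. This is a toy sketch, not Pyro code: PARAM_STORE, param, LazyCompiled, and the recorder/replay mechanism are hypothetical stand-ins for the param store, pyro.param, CompiledFunction, and the poutine handlers, and the "compile" step just returns the wrapper where the real module would call torch.jit.trace.

```python
# Toy illustration only: none of these names are Pyro APIs.
PARAM_STORE = {}                             # hypothetical global store
_STATE = {"recorder": None, "replay": None}  # stand-ins for poutine handlers


def param(name, init=None):
    """Look up a parameter, creating it on first use."""
    if _STATE["recorder"] is not None and name not in _STATE["recorder"]:
        _STATE["recorder"].append(name)      # record that this param was used
    if _STATE["replay"] is not None and name in _STATE["replay"]:
        return _STATE["replay"][name]        # use the explicitly passed value
    return PARAM_STORE.setdefault(name, init)


class LazyCompiled:
    def __init__(self, fn):
        self.fn = fn
        self.compiled = {}      # signature key -> "compiled" callable
        self.param_names = []
        self.compile_count = 0

    def _compile(self, args):
        # Step 1: run once while recording which parameters are touched.
        _STATE["recorder"] = []
        try:
            self.fn(*args)
            self.param_names = list(_STATE["recorder"])
        finally:
            _STATE["recorder"] = None
        n = len(self.param_names)

        # Step 4: wrapper that takes parameters first, then regular arguments.
        def wrapper(*params_and_args):
            _STATE["replay"] = dict(zip(self.param_names, params_and_args[:n]))
            try:
                return self.fn(*params_and_args[n:])
            finally:
                _STATE["replay"] = None

        self.compile_count += 1
        return wrapper  # the real module would trace the wrapper here

    def __call__(self, *args):
        key = len(args)  # simplified signature key for this toy
        if key not in self.compiled:
            self.compiled[key] = self._compile(args)
        # Steps 2-3: read current values and prepend them to the arguments.
        params = tuple(PARAM_STORE[name] for name in self.param_names)
        return self.compiled[key](*(params + args))


@LazyCompiled
def scaled(x):
    return param("scale", 0.5) * x


first = scaled(4.0)         # triggers the one-time "compilation"
PARAM_STORE["scale"] = 2.0  # update the parameter in the store
second = scaled(4.0)        # reuses the wrapper with the new value
```

Because parameter values are passed in as arguments on every call rather than baked into the wrapper, the second call sees the updated value without recompiling.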

Behavior on subsequent calls:

  1. Retrieves the cached compiled function for the argument signature.
  2. Prepends current unconstrained parameter values.
  3. Calls the compiled function.
  4. Verifies that no new parameters were created; all parameters are expected to exist after the first invocation.

I/O Contract

Function/Method | Input | Output
trace(fn) | fn: callable (Pyro model/function) | CompiledFunction
CompiledFunction.__call__ | *args, **kwargs (same as original fn) | Same return type as original fn

Usage Examples

import torch
import pyro
import pyro.distributions as dist
import pyro.ops.jit  # imported explicitly so that pyro.ops.jit.trace resolves
from torch.distributions import constraints

def model(x):
    scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive)
    return pyro.sample("y", dist.Normal(x, scale))

# Decorate a function that computes log probability
@pyro.ops.jit.trace
def model_log_prob_fn(x, y):
    cond_model = pyro.condition(model, data={"y": y})
    tr = pyro.poutine.trace(cond_model).get_trace(x)
    return tr.log_prob_sum()

# First call triggers compilation
x = torch.tensor(0.0)
y = torch.tensor(1.0)
log_prob = model_log_prob_fn(x, y)

# Subsequent calls use compiled version
log_prob = model_log_prob_fn(x, y)

# The compiled callables are stored in a dict keyed by argument signature
print(type(model_log_prob_fn.compiled))
