Implementation:Pyro ppl Pyro LazyJIT
| Property | Value |
|---|---|
| Module | pyro.ops.jit |
| Source | pyro/ops/jit.py |
| Lines | 164 |
| Classes | CompiledFunction |
| Functions | trace |
| Dependencies | torch, pyro, pyro.poutine, pyro.util |
Overview
This module provides a lazy replacement for torch.jit.trace that works with Pyro functions calling pyro.param. The standard torch.jit.trace cannot handle Pyro's dynamic parameter management, where parameters are looked up from a global parameter store. This module solves that by automatically capturing parameters during the first invocation and plumbing them through as explicit arguments to the traced function.
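The toy below illustrates the problem this plumbing is meant to solve. It is not how torch.jit.trace works internally. Traced and toy_trace are a hypothetical mini-tracer that records operations on its declared inputs and captures every other value it meets as a constant, which is the property that motivates this module. A parameter read from a global dict is frozen at trace time, while one passed in as an argument keeps following updates.

```python
class Traced:
    """Toy symbolic value recording how it depends on the traced inputs."""

    def __init__(self, fn):
        self.fn = fn  # maps the tuple of runtime inputs to a concrete value

    @staticmethod
    def lift(value):
        # Anything that is not already traced is captured as a constant
        # holding whatever value it has at trace time.
        if isinstance(value, Traced):
            return value
        return Traced(lambda inputs: value)

    def __mul__(self, other):
        other = Traced.lift(other)
        return Traced(lambda inputs: self.fn(inputs) * other.fn(inputs))

    __rmul__ = __mul__  # multiplication of numbers is commutative


def toy_trace(fn, num_inputs):
    """Run fn once on placeholders and return a replayable function of the inputs."""
    placeholders = [Traced(lambda inputs, i=i: inputs[i]) for i in range(num_inputs)]
    graph = fn(*placeholders)
    return lambda *inputs: graph.fn(inputs)


PARAM_STORE = {"scale": 2.0}  # hypothetical stand-in for Pyro's global param store


def implicit(x):
    return PARAM_STORE["scale"] * x  # parameter read from global state


def explicit(scale, x):
    return scale * x  # parameter plumbed through as an input


frozen = toy_trace(implicit, 1)
plumbed = toy_trace(explicit, 2)

PARAM_STORE["scale"] = 3.0  # e.g. an optimizer step updates the store
print(frozen(10.0))                         # 20.0: the old scale was baked in
print(plumbed(PARAM_STORE["scale"], 10.0))  # 30.0: the current scale is used
```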
The key class is CompiledFunction, which wraps a Pyro function. On the first call, it records every parameter the function uses, builds a JIT-compiled version that takes both the parameters and the regular arguments, and caches that compiled function keyed by a signature of the call's arguments. Subsequent calls with a matching signature reuse the cached compiled function.
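A minimal stdlib-only sketch of the lazy compile-and-cache pattern. LazyCompiled and counting_compile are hypothetical: the "compiler" is a stand-in for torch.jit.trace that returns the function unchanged, and the cache key used here (positional argument types plus keyword names) is an assumption, not necessarily the key Pyro computes.

```python
class LazyCompiled:
    """Hypothetical sketch: compile once per argument signature, then reuse."""

    def __init__(self, fn, compile_fn):
        self.fn = fn
        self.compile_fn = compile_fn  # stand-in for torch.jit.trace
        self.compiled = {}            # signature key -> compiled callable

    @staticmethod
    def key(args, kwargs):
        # Assumed key: positional argument types plus sorted keyword names.
        return tuple(type(a).__name__ for a in args), tuple(sorted(kwargs))

    def __call__(self, *args, **kwargs):
        k = self.key(args, kwargs)
        if k not in self.compiled:  # first call with this signature
            self.compiled[k] = self.compile_fn(self.fn)
        return self.compiled[k](*args, **kwargs)


compile_count = 0


def counting_compile(fn):
    """Fake compiler that only counts how often it is invoked."""
    global compile_count
    compile_count += 1
    return fn


square = LazyCompiled(lambda x: x * x, counting_compile)
square(3)
square(4)    # same signature (int): reuses the cached callable
square(2.5)  # new signature (float): triggers a second compilation
```

The cache only grows when a call arrives with a signature it has not seen, so repeated calls pay the compilation cost once.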
Code Reference
Function: trace
```python
# model is any Pyro model that calls pyro.param (see Usage Examples)
@pyro.ops.jit.trace
def model_log_prob_fn(x, y):
    cond_model = pyro.condition(model, data={"y": y})
    tr = pyro.poutine.trace(cond_model).get_trace(x)
    return tr.log_prob_sum()
```
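Can be used as a decorator or called directly as trace(fn). Accepts two optional parameters: ignore_warnings (bool, whether to suppress JIT warnings) and jit_options (dict of options forwarded to torch.jit.trace; the key "time_compilation" enables compilation timing).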
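A sketch of the common Python idiom for a decorator that can also take options. It mirrors the documented parameters (fn, ignore_warnings, jit_options), but CompiledStub is hypothetical and performs no compilation, and it is an assumption that pyro.ops.jit.trace dispatches on fn=None in exactly this way.

```python
class CompiledStub:
    """Hypothetical stand-in for CompiledFunction that only records its options."""

    def __init__(self, fn, ignore_warnings=False, jit_options=None):
        self.fn = fn
        self.ignore_warnings = ignore_warnings
        self.jit_options = {} if jit_options is None else dict(jit_options)

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)  # no real compilation in this sketch


def trace(fn=None, ignore_warnings=False, jit_options=None):
    """Usable as @trace, as @trace(...) with options, or as trace(fn, ...)."""
    if fn is None:
        # Called with options only: return a decorator that waits for fn.
        return lambda fn: trace(fn, ignore_warnings=ignore_warnings, jit_options=jit_options)
    return CompiledStub(fn, ignore_warnings, jit_options)


@trace
def add_one(x):
    return x + 1


@trace(ignore_warnings=True, jit_options={"check_trace": False})
def double(x):
    return 2 * x


minus_one = trace(lambda x: x - 1, jit_options={"check_trace": True})
```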
Class: CompiledFunction
Wraps the original function and manages the torch.jit.trace outputs compiled from it, passing Pyro parameters to them as explicit inputs.
Attributes:
- fn: The original function.
- compiled: Dict mapping argument-signature keys to compiled callables.
- compile_time: Compilation time, recorded only if jit_options["time_compilation"] is True.
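The compile_time attribute implies that compilation is timed only on request. The sketch below shows one way to do that with the standard library; Timer, compile_with_optional_timing and fake_compile are hypothetical (pyro.util's timed and optional helpers play a similar role in Pyro). The flag is removed before the remaining options are forwarded, since time_compilation is not an argument of torch.jit.trace.

```python
import contextlib
import time


class Timer:
    """Minimal timing context manager (hypothetical, similar in spirit to a timed() helper)."""

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        self.elapsed = time.perf_counter() - self.start
        return False  # do not swallow exceptions


def compile_with_optional_timing(compile_fn, fn, jit_options):
    """Return (compiled, compile_time); compile_time is None unless timing was requested."""
    options = dict(jit_options)  # copy so the caller's dict is left untouched
    # Strip the timing flag; everything else is forwarded to the compiler.
    time_compilation = options.pop("time_compilation", False)
    timer = Timer() if time_compilation else contextlib.nullcontext()
    with timer:
        compiled = compile_fn(fn, **options)
    compile_time = timer.elapsed if time_compilation else None
    return compiled, compile_time


def fake_compile(fn, check_trace=True):
    """Stand-in for torch.jit.trace that returns fn unchanged."""
    return fn


compiled, t = compile_with_optional_timing(
    fake_compile, abs, {"time_compilation": True, "check_trace": False}
)
```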
Behavior on first call:
- Uses poutine.trace(param_only=True) to capture all parameter names.
- Retrieves unconstrained parameter values from the Pyro param store.
- Prepends parameters to the argument list.
- Creates a wrapper function that replays parameters and calls the original function.
- Traces the wrapper with torch.jit.trace (with check_trace=False by default).
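The first-call steps above can be sketched with a toy param store. Everything here is hypothetical: PARAM_STORE, param, capture_param_names and make_wrapper only mimic the roles of Pyro's param store, a param-only trace and the replaying wrapper. The sketch ignores constraints (Pyro passes unconstrained values) and "replays" by temporarily overwriting the store, which is a simplification.

```python
PARAM_STORE = {}  # hypothetical stand-in for Pyro's global param store
_capture = None   # list of accessed names while capturing, else None


def param(name, init=None):
    """Toy analogue of pyro.param: create on first use, then look up."""
    if _capture is not None:
        _capture.append(name)
    return PARAM_STORE.setdefault(name, init)


def capture_param_names(fn, *args):
    """Run fn once and record which params it touches (cf. a param-only trace)."""
    global _capture
    _capture = []
    try:
        fn(*args)
        return list(dict.fromkeys(_capture))  # unique names, first-seen order
    finally:
        _capture = None


def make_wrapper(fn, param_names):
    """Build a function of (*params, *args) that substitutes the given param values."""
    n = len(param_names)

    def wrapper(*params_and_args):
        params, args = params_and_args[:n], params_and_args[n:]
        saved = {name: PARAM_STORE[name] for name in param_names}
        PARAM_STORE.update(zip(param_names, params))  # "replay" the supplied values
        try:
            return fn(*args)
        finally:
            PARAM_STORE.update(saved)  # restore the store afterwards

    return wrapper


def model(x):
    return param("scale", 0.5) * x + param("loc", 1.0)


names = capture_param_names(model, 2.0)  # ['scale', 'loc']
params_and_args = tuple(PARAM_STORE[n] for n in names) + (2.0,)
wrapper = make_wrapper(model, names)
# A real implementation would hand wrapper and params_and_args to a tracer
# such as torch.jit.trace; this sketch simply calls the wrapper.
print(wrapper(*params_and_args))  # 2.0
```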
Behavior on subsequent calls:
- Retrieves the cached compiled function for the argument signature.
- Prepends current unconstrained parameter values.
- Calls the compiled function.
- Verifies no new parameters were created.
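A sketch of the subsequent-call path, assuming the first call already recorded the parameter names and cached a compiled callable. The store, param, call_cached, good and bad are hypothetical; the cached callables here are plain Python so that the new-parameter check has something to catch, and raising NotImplementedError is this sketch's choice.

```python
store = {"scale": 0.5}  # hypothetical param store; "scale" was created on the first call
_touched = []           # names accessed through param() during the current call


def param(name, init=None):
    """Toy analogue of pyro.param that logs each access."""
    _touched.append(name)
    return store.setdefault(name, init)


def call_cached(compiled, param_names, *args):
    """Subsequent call: prepend current param values, run, then check for new params."""
    params = [store[name] for name in param_names]  # re-read: values may have changed
    _touched.clear()
    result = compiled(*params, *args)
    new = [name for name in _touched if name not in param_names]
    if new:
        raise NotImplementedError(f"params must all be created on the first call; found {new}")
    return result


def good(scale, x):  # stand-in for a cached compiled callable
    return scale * x


def bad(scale, x):  # creates a param the first call never saw
    return scale * x + param("shift", 1.0)


store["scale"] = 2.0                      # e.g. an optimizer step updated the store
print(call_cached(good, ["scale"], 3.0))  # 6.0: the current value was used
```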
I/O Contract
| Function/Method | Input | Output |
|---|---|---|
| trace(fn) | fn: callable (Pyro model/function) | CompiledFunction |
| CompiledFunction.__call__ | *args, **kwargs (same as original fn) | Same return type as original fn |
Usage Examples
```python
import torch
import pyro
import pyro.distributions as dist
import pyro.ops.jit
from torch.distributions import constraints


def model(x):
    scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive)
    return pyro.sample("y", dist.Normal(x, scale))


# Decorate a function that computes log probability
@pyro.ops.jit.trace
def model_log_prob_fn(x, y):
    cond_model = pyro.condition(model, data={"y": y})
    tr = pyro.poutine.trace(cond_model).get_trace(x)
    return tr.log_prob_sum()


# First call triggers compilation
x = torch.tensor(0.0)
y = torch.tensor(1.0)
log_prob = model_log_prob_fn(x, y)

# Subsequent calls use the compiled version
log_prob = model_log_prob_fn(x, y)

# The compilation artifacts live in the compiled dict (signature key -> compiled callable)
print(type(model_log_prob_fn.compiled))
```
Related Pages
- Pyro_ppl_Pyro_Util -- Provides ignore_jit_warnings, optional, timed
- Pyro_ppl_Pyro_MiniPyro -- Mini Pyro includes a simplified JitTrace_ELBO