Implementation:Pyro ppl Pyro Pyro Deterministic
| Knowledge Sources | |
|---|---|
| Domains | Probabilistic_Programming, Bayesian_Inference |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Concrete tool for recording deterministic computations in Pyro model execution traces.
Description
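pyro.deterministic records a named deterministic value in the execution trace. Internally it creates an observed sample site whose distribution is a Delta centered on the value and masked out with .mask(False). As a result, the site contributes nothing to the model's log density and does not affect the ELBO. Because the value is stored under its own site name, derived quantities become accessible to the Predictive class and to MCMC diagnostics.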
Usage
Use pyro.deterministic when you have a computed quantity (a deterministic function of latent variables) that you want to include in posterior samples. It is common in MCMC workflows for tracking predicted means, transformed parameters, or summary statistics.
Code Reference
Source Location
- Repository: pyro
- File: pyro/primitives.py
- Lines: L221-246
Signature
def deterministic(
    name: str,
    value: torch.Tensor,
    event_dim: Optional[int] = None,
) -> torch.Tensor:
    """
    Record a deterministic value in the trace.

    Args:
        name: name of the deterministic site
        value: the deterministic value to record
        event_dim: optional number of rightmost event dimensions
            (defaults to value.ndim)

    Returns:
        The input value unchanged
    """
Import
import pyro
# Used as: pyro.deterministic(name, value)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| name | str | Yes | Unique name for the deterministic site |
| value | torch.Tensor | Yes | Computed value to record |
| event_dim | Optional[int] | No | Number of rightmost dimensions of value treated as event dimensions; defaults to value.ndim |
Outputs
| Name | Type | Description |
|---|---|---|
| return | torch.Tensor | The same input value (pass-through) |
Usage Examples
Tracking Predicted Mean
import pyro
import pyro.distributions as dist

def model(X, y=None):
    # Regression weights (one per feature column of X) and an intercept
    weight = pyro.sample("weight", dist.Normal(0., 1.).expand([X.shape[1]]).to_event(1))
    bias = pyro.sample("bias", dist.Normal(0., 10.))
    mean = X @ weight + bias
    # Record the derived mean so it is available in posterior/predictive samples
    pyro.deterministic("predicted_mean", mean)
    sigma = pyro.sample("sigma", dist.HalfNormal(1.))
    with pyro.plate("data", X.shape[0]):
        # Likelihood; y=None lets the model generate predictions
        pyro.sample("obs", dist.Normal(mean, sigma), obs=y)