Implementation:Pyro ppl Pyro Pyro Sample
| Knowledge Sources | |
|---|---|
| Domains | Probabilistic_Programming, Bayesian_Inference |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
pyro.sample is the primitive the Pyro library provides for declaring random variables in probabilistic models.
Description
The pyro.sample primitive is the core building block of Pyro probabilistic programs. It declares a named random variable drawn from a specified probability distribution. The Pyro runtime intercepts sample calls via its effect handler (messenger) stack, allowing inference algorithms to modify behavior — for example, recording values in traces, replaying fixed values, or conditioning on observations.
When called outside any handler context, pyro.sample simply draws a sample from the distribution. When called inside an inference context (e.g., during SVI or MCMC), the behavior depends on the active handlers: the value may be recorded, replayed from a guide, conditioned on observed data, or enumerated over.
Usage
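Use pyro.sample to declare every random variable in your Pyro model and guide. Site names must be unique within a single run of the model (or guide), and the guide must use the same names as the model for the latent sites it approximates. For observed data, pass the obs keyword argument. For latent variables to be inferred, omit obs. For discrete latent variables that should be summed out by enumeration, set infer={"enumerate": "parallel"} and use an enumeration-aware loss such as TraceEnum_ELBO. When enumerating in the model, leave such sites out of the guide.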
Code Reference
Source Location
- Repository: pyro
- File: pyro/primitives.py
- Lines: L125-193
Signature
```python
def sample(
    name: str,
    fn: TorchDistributionMixin,
    *args,
    obs: Optional[torch.Tensor] = None,
    obs_mask: Optional[torch.BoolTensor] = None,
    infer: Optional[InferDict] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Calls the stochastic function fn with additional side-effects depending
    on name and the enclosing context (e.g. an inference algorithm).

    Args:
        name: name of sample
        fn: distribution class or function
        obs: observed datum (optional; should only be used in context of inference)
        obs_mask: Optional boolean tensor mask broadcastable with fn.batch_shape.
            If provided, events with mask=True will be conditioned on obs and
            remaining events will be imputed by sampling.
        infer: Optional dictionary of inference parameters.

    Returns:
        sample tensor
    """
```
Import
```python
import pyro
# Used as: pyro.sample(name, dist, ...)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| name | str | Yes | Name of the sample site; must be unique within one execution of the model or guide |
| fn | TorchDistributionMixin | Yes | Distribution instance to sample from (e.g., dist.Normal(0., 1.), dist.Bernoulli(0.5)) |
| obs | Optional[torch.Tensor] | No | Observed value; makes this an observed site |
| obs_mask | Optional[torch.BoolTensor] | No | Boolean mask for partial observation |
| infer | Optional[InferDict] | No | Inference configuration dict (e.g., {"enumerate": "parallel"}) |
Outputs
| Name | Type | Description |
|---|---|---|
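| return | torch.Tensor | Value of the site: a fresh sample, obs if obs is provided, or (with obs_mask) observed entries where the mask is True and imputed samples elsewhere. Active handlers such as condition or replay may substitute the value. |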
Usage Examples
Basic Latent Variable
```python
import pyro
import pyro.distributions as dist

def model(data):
    # Latent variables: Normal prior on the mean, HalfNormal prior on the scale
    mu = pyro.sample("mu", dist.Normal(0., 10.))
    sigma = pyro.sample("sigma", dist.HalfNormal(5.))
    # Observed data, declared inside a plate over the data points
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)
```
Discrete Enumeration
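```python
import torch
import pyro
import pyro.distributions as dist

def mixture_model(data, K=3):
    # Mixture weights over the K components
    weights = pyro.sample("weights", dist.Dirichlet(torch.ones(K)))
    # One mean per component, declared in a plate over components
    with pyro.plate("components", K):
        locs = pyro.sample("locs", dist.Normal(0., 10.))
    with pyro.plate("data", len(data)):
        # Enumerate over discrete assignments (requires an ELBO such as TraceEnum_ELBO)
        assignment = pyro.sample(
            "assignment",
            dist.Categorical(weights),
            infer={"enumerate": "parallel"},
        )
        pyro.sample("obs", dist.Normal(locs[assignment], 1.0), obs=data)
```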