

Implementation:Pyro ppl Pyro Pyro Sample

From Leeroopedia


Knowledge Sources
Domains Probabilistic_Programming, Bayesian_Inference
Last Updated 2026-02-09 00:00 GMT

Overview

pyro.sample is the Pyro library's primitive for declaring random variables in probabilistic models.

Description

The pyro.sample primitive is the core building block of Pyro probabilistic programs. It declares a named random variable drawn from a specified probability distribution. The Pyro runtime intercepts sample calls via its effect handler (messenger) stack, allowing inference algorithms to modify behavior — for example, recording values in traces, replaying fixed values, or conditioning on observations.

When called outside any handler context, pyro.sample simply draws a sample from the distribution (or, if obs is given, returns obs unchanged). When called inside an inference context (e.g., during SVI or MCMC), the behavior depends on the active handlers: the value may be recorded in a trace, replayed from a guide trace, conditioned on observed data, or enumerated over.

Usage

Use pyro.sample to declare every random variable in your Pyro model and guide. For observed data, pass the obs keyword argument. For latent variables to be inferred, omit obs and declare a site with the same name in the guide. For discrete variables that should be enumerated, set infer={"enumerate": "parallel"} and train with an enumeration-aware loss such as TraceEnum_ELBO.

Code Reference

Source Location

  • Repository: pyro
  • File: pyro/primitives.py
  • Lines: L125-193

Signature

def sample(
    name: str,
    fn: TorchDistributionMixin,
    *args,
    obs: Optional[torch.Tensor] = None,
    obs_mask: Optional[torch.BoolTensor] = None,
    infer: Optional[InferDict] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Calls the stochastic function fn with additional side-effects depending
    on name and the enclosing context (e.g. an inference algorithm).

    Args:
        name: name of sample
        fn: distribution class or function
        obs: observed datum (optional; should only be used in context of inference)
        obs_mask: Optional boolean tensor mask broadcastable with fn.batch_shape.
            If provided, events with mask=True will be conditioned on obs and
            remaining events will be imputed by sampling.
        infer: Optional dictionary of inference parameters.
    Returns:
        sample tensor
    """

Import

import pyro
# Used as: pyro.sample(name, dist, ...)

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| name | str | Yes | Unique identifier for the sample site |
| fn | TorchDistributionMixin | Yes | Distribution instance to sample from (e.g., dist.Normal(0., 1.), dist.Bernoulli(0.5)) |
| obs | Optional[torch.Tensor] | No | Observed value; makes this an observed site |
| obs_mask | Optional[torch.BoolTensor] | No | Boolean mask broadcastable with fn.batch_shape; True entries are conditioned on obs, False entries are imputed by sampling |
| infer | Optional[InferDict] | No | Inference configuration dict (e.g., {"enumerate": "parallel"}) |

Outputs

| Name | Type | Description |
|---|---|---|
| return | torch.Tensor | Sampled value (or the observed value if obs is provided) |

Usage Examples

Basic Latent Variable

import pyro
import pyro.distributions as dist

def model(data):
    # Latent variables: location with a Normal prior, scale with a HalfNormal prior
    mu = pyro.sample("mu", dist.Normal(0., 10.))
    sigma = pyro.sample("sigma", dist.HalfNormal(5.))

    # Observed data
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

Discrete Enumeration

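import torch
import pyro
import pyro.distributions as dist

def mixture_model(data, K=3):
    # Mixture weights and per-component locations
    weights = pyro.sample("weights", dist.Dirichlet(torch.ones(K)))
    with pyro.plate("components", K):
        locs = pyro.sample("locs", dist.Normal(0., 10.))

    with pyro.plate("data", len(data)):
        # Enumerate over discrete assignments (requires an enumeration-aware
        # loss such as TraceEnum_ELBO and a guide that omits "assignment")
        assignment = pyro.sample("assignment",
                                 dist.Categorical(weights),
                                 infer={"enumerate": "parallel"})
        pyro.sample("obs", dist.Normal(locs[assignment], 1.0), obs=data)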

Related Pages

Implements Principle
