
Implementation:Pyro ppl Pyro Unit

From Leeroopedia


Attribute | Value
Sources | pyro/distributions/unit.py
Domains | Probabilistic Programming, Factor Statements, Non-Normalized Distributions
Last Updated | 2026-02-09

Overview

Description

The Unit distribution is a trivial, non-normalized distribution representing the unit type in type theory. The unit type has a single value with no data, meaning value.numel() == 0. This is achieved by setting event_shape = (0,).

The primary purpose of Unit is to support pyro.factor statements, which add an arbitrary log-probability term to a model's joint density without introducing an observed or latent variable. When pyro.factor("name", log_factor) is called, it creates a Unit distribution with the given log_factor, draws its empty value, and records a sample site that observes that value, so the site contributes exactly log_factor to the joint log density.

Key characteristics:

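  • The log_prob method ignores the contents of its input and returns log_factor, expanded to broadcast_shape(batch_shape, value.shape[:-1])
  • The sample and rsample methods return empty tensors of shape sample_shape + batch_shape + (0,), so numel() == 0
  • The has_rsample flag can optionally be set at construction time to control whether the distribution advertises reparameterized sampling support; if it is omitted, the class default (False) applies
  • The distribution is non-normalized: because the unit type has exactly one value, a normalized distribution over it would assign that value log probability 0, whereas Unit assigns it log_factor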

Usage

Unit is primarily an internal implementation detail of pyro.factor and is rarely constructed directly by users. Its behavior matters when writing custom inference algorithms or effect handlers that must recognize factor sites, for example to include their log_factor in a log-joint computation or to decide, via has_rsample, whether a factor carries reparameterized gradients.

Code Reference

Source Location

Property | Value
File | pyro/distributions/unit.py
Module | pyro.distributions.unit
Repository | pyro-ppl/pyro

Signature

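import torch
from torch.distributions import constraints
from pyro.distributions.torch_distribution import TorchDistribution

class Unit(TorchDistribution):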
    arg_constraints = {"log_factor": constraints.real}
    support = constraints.real

    def __init__(self, log_factor, *, has_rsample=None, validate_args=None):
        ...

    def expand(self, batch_shape, _instance=None):
        ...

    def sample(self, sample_shape=torch.Size()):
        ...

    def rsample(self, sample_shape=torch.Size()):
        ...

    def log_prob(self, value):
        ...
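To make the signature concrete, here is a minimal re-implementation on top of plain torch.distributions.Distribution. The class name MiniUnit is hypothetical and this is not Pyro's source; it follows the behavior described on this page (event_shape (0,), empty samples, a log_prob that ignores the value's contents) and omits expand for brevity.

```python
import torch
from torch.distributions import Distribution, constraints


class MiniUnit(Distribution):
    """Hypothetical sketch of a Unit-like distribution (not Pyro's code)."""

    arg_constraints = {"log_factor": constraints.real}
    support = constraints.real

    def __init__(self, log_factor, *, has_rsample=None, validate_args=None):
        self.log_factor = torch.as_tensor(log_factor)
        if has_rsample is not None:
            # Instance attribute shadows the class default (False).
            self.has_rsample = has_rsample
        # batch_shape comes from log_factor; event_shape (0,) makes numel() == 0.
        super().__init__(self.log_factor.shape, torch.Size((0,)),
                         validate_args=validate_args)

    def sample(self, sample_shape=torch.Size()):
        # Empty tensor of shape sample_shape + batch_shape + (0,).
        shape = torch.Size(sample_shape) + self.batch_shape + self.event_shape
        return self.log_factor.new_empty(shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.sample(sample_shape)

    def log_prob(self, value):
        # Ignore the (empty) value; broadcast only its leading dims.
        shape = torch.broadcast_shapes(self.batch_shape, value.shape[:-1])
        return self.log_factor.expand(shape)


u = MiniUnit(torch.tensor([-1.0, -2.0]))
s = u.sample(torch.Size([5]))
print(s.shape, u.log_prob(s).shape)  # torch.Size([5, 2, 0]) torch.Size([5, 2])
```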

Import

from pyro.distributions import Unit

# Or from the module directly:
from pyro.distributions.unit import Unit

I/O Contract

Constructor Parameters

Parameter | Type | Constraint | Description
log_factor | torch.Tensor or numeric | constraints.real | The log-probability factor added to the model's joint density. Its shape determines batch_shape.
has_rsample | bool or None | -- | Keyword-only. If provided, overrides whether the distribution reports reparameterized sampling support. Default: None.
validate_args | bool or None | -- | Keyword-only. Whether to validate constructor arguments. Default: None.

Distribution Properties

Property | Value | Description
batch_shape | log_factor.shape | Determined by the shape of the log_factor tensor
event_shape | torch.Size((0,)) | Fixed to (0,), so every value satisfies value.numel() == 0
support | constraints.real | Formally real-valued, though every event is an empty tensor
has_rsample | configurable | Set by the has_rsample constructor argument; otherwise the class default (False) applies

Methods

Method | Return Type | Description
sample(sample_shape) | torch.Tensor | Returns an empty tensor of shape sample_shape + batch_shape + (0,)
rsample(sample_shape) | torch.Tensor | Returns an empty tensor of shape sample_shape + batch_shape + (0,)
log_prob(value) | torch.Tensor | Returns log_factor expanded to broadcast_shape(batch_shape, value.shape[:-1])
expand(batch_shape) | Unit | Returns a new Unit instance with log_factor expanded to the given batch shape

Usage Examples

Using pyro.factor (the Primary Use Case)

import torch
import pyro
import pyro.distributions as dist

def model(data):
    z = pyro.sample("z", dist.Normal(0.0, 1.0))

    # Add a custom log-probability factor (internally uses Unit)
    pyro.factor("custom_factor", -0.5 * z ** 2)

    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(z, 1.0), obs=data)

Direct Construction of Unit

import torch
import pyro.distributions as dist

# Create a Unit distribution with a scalar log factor
unit = dist.Unit(log_factor=torch.tensor(-2.0))
print(unit.batch_shape)  # torch.Size([])
print(unit.event_shape)  # torch.Size([0])

# Samples are empty tensors
sample = unit.sample()
print(sample.shape)   # torch.Size([0])
print(sample.numel())  # 0

# log_prob returns the log_factor
log_p = unit.log_prob(sample)
print(log_p)  # tensor(-2.)

Batched Unit Distribution

import torch
import pyro.distributions as dist

# Batched log factors
log_factors = torch.tensor([-1.0, -2.0, -3.0])
unit = dist.Unit(log_factor=log_factors)
print(unit.batch_shape)  # torch.Size([3])
print(unit.event_shape)  # torch.Size([0])

# Sample returns empty tensors with correct batch dimensions
samples = unit.sample(torch.Size([10]))
print(samples.shape)  # torch.Size([10, 3, 0])

# log_prob returns the batched factors
log_p = unit.log_prob(samples)
print(log_p.shape)  # torch.Size([10, 3])
print(log_p[0])     # tensor([-1., -2., -3.])

Controlling has_rsample

import torch
import pyro.distributions as dist

# Default: has_rsample is not set, so the class default applies
unit_default = dist.Unit(log_factor=torch.tensor(0.0))
print(unit_default.has_rsample)  # False

# Explicitly enable reparameterized sampling
unit_reparam = dist.Unit(log_factor=torch.tensor(0.0), has_rsample=True)
print(unit_reparam.has_rsample)  # True

# Explicitly disable reparameterized sampling
unit_no_reparam = dist.Unit(log_factor=torch.tensor(0.0), has_rsample=False)
print(unit_no_reparam.has_rsample)  # False

Expanding a Unit Distribution

import torch
import pyro.distributions as dist

unit = dist.Unit(log_factor=torch.tensor(-1.5))
expanded = unit.expand(torch.Size([4, 3]))
print(expanded.batch_shape)    # torch.Size([4, 3])
print(expanded.log_factor)     # tensor of shape (4, 3) filled with -1.5

log_p = expanded.log_prob(expanded.sample())
print(log_p.shape)  # torch.Size([4, 3])

Related Pages

  • Delta -- Another specialized distribution for deterministic values
  • Distributions_Init -- Central registry of all Pyro distributions including Unit
  • SoftLaplace -- A proper normalized distribution, in contrast to Unit's non-normalized nature
  • TransformModule -- Another infrastructure component in Pyro's distribution system
