Implementation:Pyro ppl Pyro Unit
| Attribute | Value |
|---|---|
| Sources | pyro/distributions/unit.py |
| Domains | Probabilistic Programming, Factor Statements, Non-Normalized Distributions |
| Last Updated | 2026-02-09 |
Overview
Description
The Unit distribution is a trivial, non-normalized distribution representing the unit type from type theory. The unit type has exactly one value, and that value carries no data: value.numel() == 0. Unit enforces this by fixing event_shape = (0,), so every event is an empty tensor.
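Plain PyTorch shows why event_shape = (0,) gives numel() == 0 for any batch shape. A trailing dimension of size 0 leaves the tensor with no elements, whatever the leading dimensions are. This snippet uses only standard torch calls and does not touch Pyro.

```python
import torch

# A trailing dimension of size 0 empties the tensor regardless of the leading
# (batch) dimensions: one data-free "unit value" per batch element.
for batch_shape in [(), (3,), (4, 3)]:
    value = torch.empty(*batch_shape, 0)
    print(tuple(value.shape), value.numel())
# (0,) 0
# (3, 0) 0
# (4, 3, 0) 0
```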
The primary purpose of Unit is to support pyro.factor statements, which add arbitrary log-probability terms to a model's joint density without corresponding to any observed or latent variable. When pyro.factor("name", log_factor) is called, it constructs a Unit distribution from log_factor, draws its (empty) value, and records that value as an observed sample site named "name". Because Unit.log_prob returns log_factor, the site contributes exactly log_factor to the joint log density.
Key characteristics:
- The log_prob method returns the log_factor value regardless of the input, broadcast to match the batch shapes.
- The sample and rsample methods return empty tensors with numel() == 0.
- The has_rsample flag can optionally be set at construction time to control whether the distribution advertises reparameterized sampling support.
- The distribution is non-normalized: exp(log_prob) is not required to sum to 1 over the support. Because the support holds a single (empty) value, the total mass of each batch element is exp(log_factor).
Usage
Unit is primarily an internal implementation detail used by pyro.factor. It is rarely constructed directly by users. However, understanding its behavior is important when working with custom inference algorithms or effect handlers that need to process factor statements.
Code Reference
Source Location
| Property | Value |
|---|---|
| File | pyro/distributions/unit.py |
| Module | pyro.distributions.unit |
| Repository | pyro-ppl/pyro |
Signature
```python
import torch
from torch.distributions import constraints
from pyro.distributions.torch_distribution import TorchDistribution

class Unit(TorchDistribution):
    arg_constraints = {"log_factor": constraints.real}
    support = constraints.real

    def __init__(self, log_factor, *, has_rsample=None, validate_args=None):
        ...

    def expand(self, batch_shape, _instance=None):
        ...

    def sample(self, sample_shape=torch.Size()):
        ...

    def rsample(self, sample_shape=torch.Size()):
        ...

    def log_prob(self, value):
        ...
```
Import
```python
from pyro.distributions import Unit
# Or from the module directly:
from pyro.distributions.unit import Unit
```
I/O Contract
Constructor Parameters
| Parameter | Type | Constraint | Description |
|---|---|---|---|
| log_factor | torch.Tensor or numeric | constraints.real | The log-probability factor to contribute to the model's joint density. Its shape determines the batch_shape. |
| has_rsample | bool or None | -- | If provided, overrides whether the distribution reports reparameterized sampling support. Keyword-only argument. Default: None. |
| validate_args | bool or None | -- | Whether to validate input arguments. Keyword-only argument. Default: None. |
Distribution Properties
| Property | Value | Description |
|---|---|---|
| batch_shape | log_factor.shape | Determined by the shape of the log_factor tensor |
| event_shape | torch.Size((0,)) | Fixed to (0,), satisfying value.numel() == 0 |
| support | constraints.real | Formally real-valued, though events are empty tensors |
| has_rsample | configurable | Set via the constructor parameter; otherwise the class default applies |
Methods
| Method | Return Type | Description |
|---|---|---|
| sample(sample_shape) | torch.Tensor | Returns an empty tensor of shape sample_shape + batch_shape + (0,) |
| rsample(sample_shape) | torch.Tensor | Returns an empty tensor of shape sample_shape + batch_shape + (0,) |
| log_prob(value) | torch.Tensor | Returns log_factor expanded to broadcast_shape(batch_shape, value.shape[:-1]); the contents of value are ignored |
| expand(batch_shape) | Unit | Returns a new Unit instance with log_factor expanded to the given batch shape |
Usage Examples
Using pyro.factor (the Primary Use Case)
```python
import torch
import pyro
import pyro.distributions as dist

def model(data):
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    # Add a custom log-probability factor (internally uses Unit)
    pyro.factor("custom_factor", -0.5 * z ** 2)
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(z, 1.0), obs=data)
```
Direct Construction of Unit
```python
import torch
import pyro.distributions as dist
# Create a Unit distribution with a scalar log factor
unit = dist.Unit(log_factor=torch.tensor(-2.0))
print(unit.batch_shape)  # torch.Size([])
print(unit.event_shape)  # torch.Size([0])
# Samples are empty tensors
sample = unit.sample()
print(sample.shape)  # torch.Size([0])
print(sample.numel())  # 0
# log_prob returns the log_factor
log_p = unit.log_prob(sample)
print(log_p)  # tensor(-2.)
```
Batched Unit Distribution
```python
import torch
import pyro.distributions as dist
# Batched log factors
log_factors = torch.tensor([-1.0, -2.0, -3.0])
unit = dist.Unit(log_factor=log_factors)
print(unit.batch_shape)  # torch.Size([3])
print(unit.event_shape)  # torch.Size([0])
# Sample returns empty tensors with correct batch dimensions
samples = unit.sample(torch.Size([10]))
print(samples.shape)  # torch.Size([10, 3, 0])
# log_prob returns the batched factors
log_p = unit.log_prob(samples)
print(log_p.shape)  # torch.Size([10, 3])
print(log_p[0])  # tensor([-1., -2., -3.])
```
Controlling has_rsample
```python
import torch
import pyro.distributions as dist
# Default: has_rsample is not explicitly set, so the class default applies
unit_default = dist.Unit(log_factor=torch.tensor(0.0))
# Explicitly enable reparameterized sampling
unit_reparam = dist.Unit(log_factor=torch.tensor(0.0), has_rsample=True)
print(unit_reparam.has_rsample)  # True
# Explicitly disable reparameterized sampling
unit_no_reparam = dist.Unit(log_factor=torch.tensor(0.0), has_rsample=False)
print(unit_no_reparam.has_rsample)  # False
```
Expanding a Unit Distribution
```python
import torch
import pyro.distributions as dist
unit = dist.Unit(log_factor=torch.tensor(-1.5))
expanded = unit.expand(torch.Size([4, 3]))
print(expanded.batch_shape)  # torch.Size([4, 3])
print(expanded.log_factor)  # tensor of shape (4, 3) filled with -1.5
log_p = expanded.log_prob(expanded.sample())
print(log_p.shape)  # torch.Size([4, 3])
```
Related Pages
- Delta -- Another specialized distribution for deterministic values
- Distributions_Init -- Central registry of all Pyro distributions including Unit
- SoftLaplace -- A proper normalized distribution, in contrast to Unit's non-normalized nature
- TransformModule -- Another infrastructure component in Pyro's distribution system