Implementation:Pyro ppl Pyro Constraints
| Knowledge Sources | |
|---|---|
| Domains | Probability_Distributions |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Description
The constraints module extends PyTorch's torch.distributions.constraints with additional constraint objects specific to Pyro. It re-exports all upstream PyTorch constraints and adds the following custom constraint classes and instances:
- _Integer (integer) -- Constrains values to integers by checking that value % 1 == 0. Marked as discrete (is_discrete = True).
- _Sphere (sphere) -- Constrains values to lie on the Euclidean unit sphere of any dimension. The check verifies that the L2 norm of the vector equals 1 within a tolerance relative to floating-point epsilon. Has event_dim = 1.
- _CorrMatrix (corr_matrix) -- Constrains values to be valid correlation matrices (positive definite matrices with unit diagonal). Has event_dim = 2.
- _OrderedVector (ordered_vector) -- Constrains values to be real-valued tensors where elements are monotonically increasing along the event dimension. Has event_dim = 1.
- _PositiveOrderedVector (positive_ordered_vector) -- Constrains values to be positive and monotonically increasing along the event dimension.
- _SoftplusPositive (softplus_positive) -- A variant of the positive constraint intended for use with the softplus transform.
- _SoftplusLowerCholesky (softplus_lower_cholesky) -- A variant of the lower_cholesky constraint intended for use with the softplus transform.
- _UnitLowerCholesky (unit_lower_cholesky) -- Constrains values to be lower-triangular square matrices with all ones on the diagonal. Has event_dim = 2.
The module also maintains a deprecated alias corr_cholesky_constraint for corr_cholesky. All constraints (both upstream PyTorch and Pyro-specific) are collected into the __all__ list and sorted for documentation generation.
Usage
Constraints are used throughout Pyro to validate distribution parameters (via arg_constraints dictionaries) and to define the support of distributions. They are also used by transforms and reparameterizers to map unconstrained parameters to constrained spaces. Users typically reference constraint objects (e.g., constraints.positive, constraints.ordered_vector) when defining custom distributions or when configuring automatic parameter constraint handling in Pyro's inference machinery.
Code Reference
Source Location
pyro/distributions/constraints.py
Signature
# Custom constraint classes
class _Integer(Constraint): ...
class _Sphere(Constraint): ...
class _CorrMatrix(Constraint): ...
class _OrderedVector(Constraint): ...
class _PositiveOrderedVector(Constraint): ...
class _SoftplusPositive(type(positive)): ...
class _SoftplusLowerCholesky(type(lower_cholesky)): ...
class _UnitLowerCholesky(Constraint): ...
# Instantiated constraint objects
corr_matrix = _CorrMatrix()
integer = _Integer()
ordered_vector = _OrderedVector()
positive_ordered_vector = _PositiveOrderedVector()
sphere = _Sphere()
softplus_positive = _SoftplusPositive()
softplus_lower_cholesky = _SoftplusLowerCholesky()
unit_lower_cholesky = _UnitLowerCholesky()
Import
from pyro.distributions import constraints
# Or import specific constraint objects
from pyro.distributions.constraints import (
integer,
sphere,
corr_matrix,
ordered_vector,
positive_ordered_vector,
softplus_positive,
softplus_lower_cholesky,
unit_lower_cholesky,
)
I/O Contract
Inputs
| Constraint | Method | Parameter | Type | Description |
|---|---|---|---|---|
| All constraints | check(value) | value | torch.Tensor | The tensor value to check against the constraint. |
Outputs
| Method | Return Type | Description |
|---|---|---|
| check(value) | torch.Tensor (boolean) | A boolean tensor indicating whether each element (or event) satisfies the constraint. |
Usage Examples
import torch
from pyro.distributions import constraints
# Check if a value is an integer
val = torch.tensor([1.0, 2.5, 3.0])
print(constraints.integer.check(val)) # tensor([True, False, True])
# Check if a vector lies on the unit sphere
vec = torch.tensor([0.6, 0.8])
print(constraints.sphere.check(vec)) # tensor(True): the L2 norm equals 1 within tolerance
# Check if a vector is ordered (monotonically increasing)
ordered = torch.tensor([1.0, 2.0, 3.0])
unordered = torch.tensor([1.0, 3.0, 2.0])
print(constraints.ordered_vector.check(ordered)) # tensor(True)
print(constraints.ordered_vector.check(unordered)) # tensor(False)
# Check a correlation matrix
corr = torch.eye(3)
print(constraints.corr_matrix.check(corr)) # tensor(True)
# Use constraints in a distribution's arg_constraints
from pyro.distributions.torch_distribution import TorchDistribution
class MyDist(TorchDistribution):
    arg_constraints = {
        "loc": constraints.real,
        "scale": constraints.positive,
        "ordering": constraints.ordered_vector,
    }
Related Pages
- Pyro_ppl_Pyro_Distribution_Base -- Base distribution class that uses constraints for parameter validation
- Pyro_ppl_Pyro_AVFMultivariateNormal -- Distribution that declares arg_constraints using this module
- Pyro_ppl_Pyro_ConjugateDistributions -- Distributions that use positive, nonnegative_integer, and other constraints