Implementation:Pyro ppl Pyro OrderedLogistic
| Knowledge Sources | |
|---|---|
| Domains | Probability_Distributions |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Description
OrderedLogistic is a distribution class in Pyro that provides an alternative parameterization of a categorical distribution, designed specifically for ordered categorical (ordinal) data. It subclasses Categorical (Pyro's wrapper around PyTorch's torch.distributions.Categorical) and is parameterized by two tensors: a predictor tensor and a cutpoints tensor.
The conversion from the ordered logistic parameterization to category probabilities works as follows:
- Cumulative probabilities are computed as $q = \sigma(\text{cutpoints} - \text{predictor})$, where $\sigma$ is the logistic sigmoid and the predictor is unsqueezed along its last dimension so that it broadcasts against the cutpoints vector.
- Category probabilities are derived by differencing the cumulative probabilities. For $K$ cutpoints there are $K+1$ categories, with probabilities:
  - $p_0 = q_0$ (first category)
  - $p_k = q_k - q_{k-1}$ for $1 \le k < K$ (interior categories)
  - $p_K = 1 - q_{K-1}$ (last category)
- These probabilities are passed to the parent Categorical constructor.
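The three steps above can be sketched in plain Python for a single scalar predictor. This is a minimal illustration of the formulas only, not Pyro's vectorized tensor code; the helper name `ordered_logistic_probs` is hypothetical.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    """Logistic sigmoid 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-x))


def ordered_logistic_probs(predictor: float, cutpoints: Sequence[float]) -> List[float]:
    """Hypothetical helper (not part of Pyro): map a scalar predictor and
    K increasing cutpoints to the K + 1 category probabilities."""
    # Cumulative probabilities q[k] = sigmoid(cutpoints[k] - predictor)
    q = [sigmoid(c - predictor) for c in cutpoints]
    K = len(q)
    probs = [q[0]]                                   # first category
    probs += [q[k] - q[k - 1] for k in range(1, K)]  # interior categories
    probs.append(1.0 - q[K - 1])                     # last category
    return probs


probs = ordered_logistic_probs(0.5, [-2.0, -1.0, 0.0, 1.5])
print(len(probs))             # 5 categories for 4 cutpoints
print(round(sum(probs), 12))  # 1.0
```

The sum is exactly the telescoping series $q_0 + (q_{K-1} - q_0) + (1 - q_{K-1}) = 1$, which is why no explicit normalization step is needed.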
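The cutpoints tensor must be monotonically increasing along its last dimension. This requirement is declared through the constraints.ordered_vector constraint and is checked when argument validation is enabled. Because the sigmoid is monotonic, ordered cutpoints yield increasing cumulative probabilities, so every category probability obtained by differencing is non-negative.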
The class also overrides the expand method, so that when the batch shape is expanded, the predictor and cutpoints tensors stored on the new instance are expanded along with the inherited category probabilities.
Usage
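OrderedLogistic is a common likelihood for ordinal regression (the ordered logit, or proportional-odds, model) in Bayesian modeling. It is used when the response variable has ordered categories (e.g., survey ratings from "strongly disagree" to "strongly agree", disease severity levels, or credit ratings). The predictor captures the effect of covariates on the ordinal outcome, while the cutpoints define the thresholds between adjacent categories. Because the cumulative probabilities are sigmoid(cutpoints - predictor), a larger predictor lowers every cumulative probability and therefore shifts probability mass toward higher categories.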
Code Reference
Source Location
- File: pyro/distributions/ordered_logistic.py
- Repository: pyro-ppl/pyro
Signature
```python
class OrderedLogistic(Categorical):
    def __init__(self, predictor, cutpoints, validate_args=None)
```
Import
```python
from pyro.distributions import OrderedLogistic
```
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| predictor | torch.Tensor | A tensor of linear-predictor values of arbitrary shape. Non-batched samples have the same shape as this tensor. |
| cutpoints | torch.Tensor | A tensor of ordered cutpoints (thresholds). Values must be monotonically increasing along the last dimension, and the first ndim - 1 dimensions must be broadcastable to the predictor shape. For K cutpoints, the distribution has K + 1 categories. |
| validate_args | bool or None | Whether to validate input arguments. Defaults to None, which defers to the global validation setting (e.g. pyro.enable_validation). |
Outputs
| Method | Return Type | Description |
|---|---|---|
| sample(sample_shape=torch.Size()) | torch.Tensor | Integer category indices in {0, ..., K} (inherited from Categorical). |
| log_prob(value) | torch.Tensor | Log probability of the given category indices (inherited from Categorical). |
| expand(batch_shape) | OrderedLogistic | A new instance with the expanded batch shape; its predictor and cutpoints tensors are expanded to match. |
Usage Examples
```python
import torch
from pyro.distributions import OrderedLogistic

# 4 cutpoints define 5 ordinal categories
cutpoints = torch.tensor([-2.0, -1.0, 0.0, 1.5])
predictor = torch.tensor(0.5)  # scalar predictor
dist = OrderedLogistic(predictor, cutpoints)

# Sample a category
sample = dist.sample()
print(sample)  # An integer in {0, 1, 2, 3, 4}

# Compute log probability
log_p = dist.log_prob(torch.tensor(2))
print(log_p)

# Access the underlying category probabilities
print(dist.probs)  # 5-element probability vector summing to 1
```
OrderedLogistic can also serve as the likelihood in a Bayesian ordinal regression model:

```python
import pyro
import pyro.distributions as dist
import torch

# Bayesian ordinal regression model
# X: 1-D tensor of covariate values; y: optional labels in {0, 1, 2, 3}
def model(X, y=None):
    # Regression coefficient
    beta = pyro.sample("beta", dist.Normal(0.0, 1.0))
    # Three ordered cutpoints (four categories): OrderedTransform maps an
    # unconstrained real vector to an increasing one
    cutpoints = pyro.sample(
        "cutpoints",
        dist.TransformedDistribution(
            dist.Normal(torch.zeros(3), torch.ones(3)).to_event(1),
            dist.transforms.OrderedTransform(),
        ),
    )
    # Linear predictor
    predictor = X * beta
    with pyro.plate("data", len(X)):
        pyro.sample("y", dist.OrderedLogistic(predictor, cutpoints), obs=y)
```
Related Pages
- Pyro_ppl_Pyro_OneTwoMatching - Another discrete distribution with structured constraints
- Pyro_ppl_Pyro_ImproperUniform - A distribution with special support handling, illustrating Pyro's flexible distribution system