Implementation:Pyro ppl Pyro LKJ
| Knowledge Sources | |
|---|---|
| Domains | Probability_Distributions |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Description
LKJ is a distribution over correlation matrices implemented in Pyro. The LKJ distribution is parameterized by a dimension dim and a concentration parameter concentration (often denoted as eta). The probability of a correlation matrix M is proportional to:
det(M) ** (concentration - 1)
The distribution is constructed as a TransformedDistribution that combines a base LKJCholesky distribution (which produces lower-triangular Cholesky factors of correlation matrices) with the inverse of CorrMatrixCholeskyTransform. This transform maps Cholesky factors back to full correlation matrices.
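The arithmetic behind this mapping can be sketched in plain Python, with no Pyro or PyTorch dependency. Mathematically, a Cholesky factor L maps to the correlation matrix M = L Lᵀ; when L is lower triangular and each row has unit Euclidean norm, M has ones on its diagonal. The helper name cholesky_to_corr and the example factor below are illustrative assumptions, not part of Pyro's API.

```python
def cholesky_to_corr(L):
    """Return L @ L^T for a square matrix given as nested lists."""
    n = len(L)
    return [
        [sum(L[i][k] * L[j][k] for k in range(n)) for j in range(n)]
        for i in range(n)
    ]

# Hypothetical Cholesky factor: lower triangular, positive diagonal,
# and every row has unit Euclidean norm (e.g. 0.6**2 + 0.8**2 == 1).
L = [
    [1.0, 0.0, 0.0],
    [0.6, 0.8, 0.0],
    [0.0, 0.6, 0.8],
]

corr = cholesky_to_corr(L)

# Each diagonal entry is the squared norm of a row of L, so it is 1 up to
# floating-point rounding; off-diagonal entries are the correlations.
for row in corr:
    print([round(v, 4) for v in row])
```

The unit diagonal comes directly from the unit-norm rows of L, and symmetry follows from the product form, which is why sampling Cholesky factors first always yields a valid correlation matrix.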
The concentration parameter controls the distribution's behavior:
- When concentration == 1, the distribution is uniform over all valid correlation matrices.
- When concentration > 1, the distribution favors correlation matrices with large determinants (i.e., variables that are less correlated).
- When concentration < 1, the distribution favors correlation matrices with small determinants (i.e., variables that are more correlated).
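These three regimes can be illustrated with a small worked example in plain Python. For a 2x2 correlation matrix with off-diagonal entry r, the determinant is 1 - r², so the unnormalized density is (1 - r²) ** (concentration - 1). The function name unnormalized_lkj_density and the chosen values of r are assumptions made for illustration; the normalizing constant is omitted because it cancels in the ratio.

```python
def unnormalized_lkj_density(r, concentration):
    """det(M) ** (concentration - 1) for the 2x2 matrix [[1, r], [r, 1]]."""
    det = 1.0 - r * r
    return det ** (concentration - 1.0)

weak, strong = 0.1, 0.9  # weakly vs. strongly correlated pair

# Ratio of the density at weak correlation to the density at strong correlation.
ratios = {
    eta: unnormalized_lkj_density(weak, eta) / unnormalized_lkj_density(strong, eta)
    for eta in (0.5, 1.0, 4.0)
}

for eta, ratio in ratios.items():
    print(f"concentration={eta}: density(r=0.1) / density(r=0.9) = {ratio:.3f}")
```

The ratio is below 1 for concentration 0.5 (strong correlation is favored), exactly 1 for concentration 1 (uniform), and above 1 for concentration 4 (weak correlation is favored).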
The file also contains a deprecated LKJCorrCholesky class that emits a FutureWarning and delegates to the standard LKJCholesky distribution.
The mean property returns the identity matrix of the appropriate dimension. This is the expected value of the LKJ distribution and corresponds to uncorrelated variables.
Usage
The LKJ distribution is commonly used as a prior over correlation matrices in Bayesian models. It is especially useful in multivariate models where one needs to infer the correlation structure between variables while ensuring the resulting matrix is a valid correlation matrix (symmetric, positive definite, with unit diagonal).
Code Reference
Source Location
- File: pyro/distributions/lkj.py
- Repository: pyro-ppl/pyro
Signature
class LKJ(TransformedDistribution):
    def __init__(self, dim, concentration=1.0, validate_args=None)
Import
from pyro.distributions import LKJ
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| dim | int | The dimension of the correlation matrices, i.e., the number of variables. Produces matrices of shape (dim, dim). |
| concentration | float or torch.Tensor | Concentration parameter (eta) controlling the distribution shape. Must be positive. Defaults to 1.0 (uniform over correlation matrices). |
| validate_args | bool or None | Whether to validate input arguments. Defaults to None. |
Outputs
| Method | Return Type | Description |
|---|---|---|
| sample(sample_shape) | torch.Tensor | Returns a correlation matrix of shape sample_shape + batch_shape + (dim, dim). |
| log_prob(value) | torch.Tensor | Returns the log probability of a correlation matrix value. |
| mean | torch.Tensor | Returns the identity matrix of shape batch_shape + (dim, dim), representing the expected correlation matrix. |
Usage Examples
import torch
from pyro.distributions import LKJ
# Create a uniform LKJ distribution over 3x3 correlation matrices
lkj = LKJ(dim=3, concentration=1.0)
# Sample a correlation matrix
corr_matrix = lkj.sample()
print(corr_matrix.shape) # torch.Size([3, 3])
print(corr_matrix) # A valid 3x3 correlation matrix
# Compute log probability
log_p = lkj.log_prob(corr_matrix)
print(log_p.shape) # torch.Size([])
import torch
import pyro
import pyro.distributions as dist

# Using LKJ as a prior over correlation matrices in a Pyro model
def model(data):
    # Prior favoring less correlated variables (concentration > 1)
    corr = pyro.sample("corr", dist.LKJ(dim=4, concentration=2.0))
    # Per-variable scales (standard deviations), one for each of the 4 variables
    scale = pyro.sample("scale", dist.HalfCauchy(torch.ones(4)).to_event(1))
    # Build the covariance matrix: diag(scale) @ corr @ diag(scale)
    scale_diag = torch.diag(scale)
    cov = scale_diag @ corr @ scale_diag
    # data is expected to have shape (N, 4)
    with pyro.plate("obs", len(data)):
        pyro.sample("x", dist.MultivariateNormal(torch.zeros(4), cov), obs=data)
Related Pages
- Pyro_ppl_Pyro_MultivariateStudentT - Multivariate Student's t-distribution parameterized with a scale (Cholesky) matrix
- Pyro_ppl_Pyro_OMTMultivariateNormal - Multivariate Normal distribution that also uses Cholesky parameterization