

Implementation:Pyro ppl Pyro LKJ

From Leeroopedia


Knowledge Sources
Domains Probability_Distributions
Last Updated 2026-02-09 09:00 GMT

Overview

Description

LKJ is a distribution over correlation matrices implemented in Pyro. It is parameterized by a dimension dim and a concentration parameter concentration (often denoted eta). The probability density of a correlation matrix M is proportional to:

det(M) ** (concentration - 1)

The distribution is constructed as a TransformedDistribution: a base LKJCholesky distribution, which produces lower-triangular Cholesky factors L of correlation matrices, is pushed through the inverse of CorrMatrixCholeskyTransform. CorrMatrixCholeskyTransform maps a correlation matrix to its Cholesky factor, so its inverse maps each factor L back to the full correlation matrix L @ L.T.

The concentration parameter controls the distribution's behavior:

  • When concentration == 1, the distribution is uniform over all valid correlation matrices.
  • When concentration > 1, the distribution favors correlation matrices with large determinants (i.e., variables that are less correlated).
  • When concentration < 1, the distribution favors correlation matrices with small determinants (i.e., variables that are more correlated).

The same source module also defines a deprecated LKJCorrCholesky class, which emits a FutureWarning when constructed and delegates to the standard LKJCholesky distribution.

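The mean property returns an identity matrix of the appropriate shape. This is the exact expected value of the LKJ distribution: the diagonal of every sample is 1, and by symmetry each off-diagonal correlation is equally likely to be positive or negative, so its expectation is 0.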

Usage

The LKJ distribution is commonly used as a prior over correlation matrices in Bayesian models. It is especially useful in multivariate models where one needs to infer the correlation structure between variables while ensuring the resulting matrix is a valid correlation matrix (symmetric, positive definite, with unit diagonal).

Code Reference

Source Location

Signature

class LKJ(TransformedDistribution):
    def __init__(self, dim, concentration=1.0, validate_args=None)

Import

from pyro.distributions import LKJ

I/O Contract

Inputs

Parameter | Type | Description
--- | --- | ---
dim | int | The dimension of the correlation matrices, i.e., the number of variables. Samples have shape (dim, dim).
concentration | float or torch.Tensor | Concentration parameter (eta) controlling the shape of the distribution. Must be positive. A tensor value creates a batch of distributions. Defaults to 1.0 (uniform over correlation matrices).
validate_args | bool or None | Whether to validate arguments and sample values. Defaults to None, which defers to PyTorch's global validation setting.

Outputs

Method or property | Return type | Description
--- | --- | ---
sample(sample_shape) | torch.Tensor | Correlation matrices of shape sample_shape + batch_shape + (dim, dim).
log_prob(value) | torch.Tensor | Log probability density of each correlation matrix in value; the trailing (dim, dim) dimensions are reduced away.
mean | torch.Tensor | The identity matrix, of shape batch_shape + (dim, dim), representing the expected correlation matrix.

Usage Examples

import torch
from pyro.distributions import LKJ

# Create a uniform LKJ distribution over 3x3 correlation matrices
lkj = LKJ(dim=3, concentration=1.0)

# Sample a correlation matrix
corr_matrix = lkj.sample()
print(corr_matrix.shape)  # torch.Size([3, 3])
print(corr_matrix)         # A valid 3x3 correlation matrix

# Compute log probability
log_p = lkj.log_prob(corr_matrix)
print(log_p.shape)  # torch.Size([])

import torch
import pyro
import pyro.distributions as dist

# Using LKJ as a prior over correlation matrices in a Pyro model.
# data is assumed to be a tensor of shape (num_observations, 4).
def model(data):
    # Prior favoring less correlated variables (concentration > 1)
    corr = pyro.sample("corr", dist.LKJ(dim=4, concentration=2.0))

    # Use the correlation matrix to build a covariance matrix
    scale = pyro.sample("scale", dist.HalfCauchy(torch.ones(4)).to_event(1))
    scale_diag = torch.diag(scale)
    cov = scale_diag @ corr @ scale_diag

    with pyro.plate("obs", len(data)):
        pyro.sample("x", dist.MultivariateNormal(torch.zeros(4), cov), obs=data)
