Implementation:Pyro ppl Pyro Arrowhead

From Leeroopedia


Property Value
Module pyro.ops.arrowhead
Source pyro/ops/arrowhead.py
Lines 128
Named Tuples SymmArrowhead, TriuArrowhead
Functions sqrt, triu_inverse, triu_matvecmul, triu_gram
Status EXPERIMENTAL
Dependencies torch

Overview

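This module provides linear algebra operations on arrowhead matrices. These are square matrices that are diagonal except for a dense leading block of head_size rows (and, in the symmetric case, the matching head_size columns). Arrowhead structure appears in MCMC mass matrix adaptation (e.g., for HMC and NUTS). For a fixed head_size, it allows square roots and inverses to be computed in O(N * head_size^2) time and matrix-vector products in O(N * head_size) time. A general dense matrix needs O(N^3) and O(N^2) time for the same operations.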
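
The block structure can be written out as follows. The labels A, B, D, S, L and U are introduced here for illustration, with h = head_size and N = total_size. The stored fields are top = [A B] and bottom_diag = d. The factorization shown is one upper-triangular square root consistent with the Schur complement approach described under Code Reference. Multiplying out U U^T gives L L^T + B D^{-1} B^T = A in the top-left block, B in the top-right block and D in the bottom-right block. As a result, only the h x h matrix S ever needs a Cholesky factorization.

```latex
\begin{aligned}
M &= \begin{pmatrix} A & B \\ B^{\top} & D \end{pmatrix},
  \qquad A \in \mathbb{R}^{h \times h},\; B \in \mathbb{R}^{h \times (N-h)},\; D = \operatorname{diag}(d),\; d \in \mathbb{R}^{N-h}, \\
S &= A - B D^{-1} B^{\top} = L L^{\top}, \qquad L \text{ upper triangular}, \\
U &= \begin{pmatrix} L & B D^{-1/2} \\ 0 & D^{1/2} \end{pmatrix},
  \qquad U U^{\top} = \begin{pmatrix} L L^{\top} + B D^{-1} B^{\top} & B \\ B^{\top} & D \end{pmatrix} = M.
\end{aligned}
```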

An arrowhead matrix is represented as a named tuple with:

  • top: The dense top rows of shape (head_size, total_size)
  • bottom_diag: The diagonal of the bottom-right block (the remaining total_size - head_size rows and columns), of shape (total_size - head_size,)

Two named tuple types are used:

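  • SymmArrowhead(top, bottom_diag): Symmetric arrowhead matrix. The lower-left block is not stored because, by symmetry, it equals the transpose of top[:, head_size:].
  • TriuArrowhead(top, bottom_diag): Upper-triangular arrowhead matrix. The lower-left block is zero, and the head block top[:, :head_size] is itself upper triangular.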

Code Reference

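  • sqrt(x): Computes the upper-triangular square root U of a symmetric arrowhead matrix x, such that U @ U.T equals x. It uses a Schur complement decomposition, so only the head_size x head_size Schur complement is Cholesky-factorized. The complexity is O(N * head_size^2). If the Cholesky factorization fails, it retries after shrinking the off-diagonal block. In that case U @ U.T only approximates x.
  • triu_inverse(x): Computes the inverse of an upper-triangular arrowhead matrix, which is again an upper-triangular arrowhead matrix. It uses torch.linalg.solve_triangular to invert the top-left head block A. It then computes the top-right block in closed form as -A^{-1} B D^{-1}, where B is the off-diagonal block and D = diag(bottom_diag). The new bottom diagonal is 1 / bottom_diag. Complexity is O(N * head_size^2).
  • triu_matvecmul(x, y, transpose=False): Computes the matrix-vector product x @ y, or x.T @ y when transpose=True. It exploits the diagonal structure of the bottom-right block, so the cost is O(N * head_size), i.e. linear in N for a fixed head_size.
  • triu_gram(x): Computes the Gram matrix x.T @ x of an upper-triangular arrowhead matrix. The result is a dense (N, N) tensor, not an arrowhead matrix, because the bottom-right block B.T @ B + D^2 is dense. Complexity is O(N^2 * head_size). For U = sqrt(x), triu_gram(U) returns U.T @ U, which in general differs from x = U @ U.T.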

I/O Contract

Function                               Input                              Output
sqrt(x)                                SymmArrowhead(top, bottom_diag)    TriuArrowhead(top, bottom_diag)
triu_inverse(x)                        TriuArrowhead(top, bottom_diag)    TriuArrowhead(top, bottom_diag)
triu_matvecmul(x, y, transpose=False)  TriuArrowhead, Tensor(N,), bool    Tensor(N,)
triu_gram(x)                           TriuArrowhead(top, bottom_diag)    Tensor(N, N) (dense matrix)

Usage Examples

import torch
from pyro.ops.arrowhead import SymmArrowhead, sqrt, triu_gram, triu_inverse, triu_matvecmul

# Build a symmetric positive-definite arrowhead matrix with head_size=2, total_size N=5.
# top has shape (head_size, N): the dense head block A followed by the coupling block B.
head_size = 2
N = 5
A = torch.randn(head_size, head_size)
A = A @ A.T + torch.eye(head_size)  # positive-definite head block
B = 0.1 * torch.randn(head_size, N - head_size)  # small coupling helps keep A - B D^-1 B.T positive definite
top = torch.cat([A, B], dim=-1)
bottom_diag = torch.rand(N - head_size) + 1.0  # positive diagonal of the bottom-right block

symm = SymmArrowhead(top=top, bottom_diag=bottom_diag)
triu = sqrt(symm)  # upper-triangular square root U, with U @ U.T equal to symm in dense form

# Inverse of U (again a TriuArrowhead)
triu_inv = triu_inverse(triu)

# Matrix-vector products U @ y and U.T @ y
y = torch.randn(N)
result = triu_matvecmul(triu, y)
result_t = triu_matvecmul(triu, y, transpose=True)

# Dense Gram matrix U.T @ U, a Tensor of shape (N, N)
gram = triu_gram(triu)
