Implementation:Pyro ppl Pyro Arrowhead
| Property | Value |
|---|---|
| Module | pyro.ops.arrowhead |
| Source | pyro/ops/arrowhead.py |
| Lines | 128 |
| Named Tuples | SymmArrowhead, TriuArrowhead |
| Functions | sqrt, triu_inverse, triu_matvecmul, triu_gram |
| Status | EXPERIMENTAL |
| Dependencies | torch |
Overview
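This module provides linear algebra operations on arrowhead matrices. These are matrices whose first `head_size` rows are dense (mirrored as the first `head_size` columns in the symmetric case), while the remaining bottom-right block is diagonal. Arrowhead structure appears in certain MCMC mass matrix adaptations (e.g., in the NUTS algorithm). Apart from `triu_gram`, whose output is dense, the operations here scale linearly in N for a fixed `head_size` (e.g. O(N * head_size^2) for `sqrt`), instead of the O(N^3) cost of general dense matrices.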
An arrowhead matrix is represented as a named tuple with two fields:
- `top`: the dense top rows, of shape `(head_size, total_size)`.
- `bottom_diag`: the diagonal of the remaining rows, of shape `(total_size - head_size,)`.
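Two named tuple types are used:
- `SymmArrowhead(top, bottom_diag)`: a symmetric arrowhead matrix. `top` holds `[A, B]`, where the head block `A` is symmetric, and by symmetry the left columns below the head are `B.T`.
- `TriuArrowhead(top, bottom_diag)`: an upper-triangular arrowhead matrix. The head block of `top` is upper triangular, and every entry below the top rows, other than the diagonal, is zero.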
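To make the storage layout concrete, the following sketch expands both tuple types into dense arrays. The helpers `symm_arrowhead_to_dense` and `triu_arrowhead_to_dense` are hypothetical (they are not part of `pyro.ops.arrowhead`), and NumPy stands in for torch so the snippet runs on its own.

```python
import numpy as np


def symm_arrowhead_to_dense(top, bottom_diag):
    # Hypothetical helper (not Pyro's API): expand a symmetric arrowhead
    # (top = [A, B], bottom_diag = D) into a dense N x N array.
    head_size, n = top.shape
    dense = np.zeros((n, n), dtype=top.dtype)
    dense[:head_size, :] = top                            # dense top rows [A, B]
    dense[head_size:, :head_size] = top[:, head_size:].T  # mirrored left columns B.T
    dense[head_size:, head_size:] = np.diag(bottom_diag)  # diagonal tail D
    return dense


def triu_arrowhead_to_dense(top, bottom_diag):
    # Hypothetical helper (not Pyro's API): same layout, but zeros below the
    # top rows except on the diagonal.
    head_size, n = top.shape
    dense = np.zeros((n, n), dtype=top.dtype)
    dense[:head_size, :] = top
    dense[head_size:, head_size:] = np.diag(bottom_diag)
    return dense


# head_size = 2, total_size = 4: 2 * 4 + 2 = 10 stored numbers instead of 16.
symm_top = np.array([[4.0, 1.0, 0.5, 0.2],
                     [1.0, 3.0, 0.1, 0.3]])
triu_top = np.array([[2.0, 1.0, 0.5, 0.2],
                     [0.0, 3.0, 0.1, 0.3]])
bottom_diag = np.array([2.0, 5.0])
M = symm_arrowhead_to_dense(symm_top, bottom_diag)
U = triu_arrowhead_to_dense(triu_top, bottom_diag)
```

The storage saving grows with `total_size`: for a fixed `head_size`, the tuple holds O(N) numbers while the dense form holds N^2.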
Code Reference
`sqrt(x)`: Computes the upper-triangular square root of a symmetric arrowhead matrix using a Schur complement decomposition, in O(N * head_size^2) time. If the Cholesky factorization of the Schur complement fails, it retries after shrinking the off-diagonal part of the input.
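The Schur complement step can be sketched as follows, assuming the input is positive definite. Write the symmetric matrix as `[[A, B], [B.T, D]]` with `D` diagonal. Take the factor `U = [[T, B D^{-1/2}], [0, D^{1/2}]]`, where `T` is an upper-triangular factor of the Schur complement `S = A - B D^{-1} B.T` (so `S = T T.T`). Multiplying out shows that `U @ U.T` equals the input, which is the convention this sketch assumes. The helper `arrowhead_sqrt` is hypothetical, uses NumPy rather than torch, and omits the Cholesky retry logic.

```python
import numpy as np


def arrowhead_sqrt(top, bottom_diag):
    # Hypothetical helper (not Pyro's API): upper-triangular U with U @ U.T == M,
    # where M = [[A, B], [B.T, diag(D)]] is given as top = [A, B], bottom_diag = D.
    # No retry on Cholesky failure: a non-positive-definite S raises LinAlgError.
    head_size = top.shape[0]
    A, B = top[:, :head_size], top[:, head_size:]
    d_sqrt = np.sqrt(bottom_diag)        # D^{1/2}, O(N)
    B_dsqrt = B / d_sqrt                 # B D^{-1/2}: column j scaled by 1/sqrt(D[j])
    schur = A - B_dsqrt @ B_dsqrt.T      # S = A - B D^{-1} B.T, O(N * head_size^2)
    # S = T @ T.T with T upper triangular: Cholesky of the index-reversed
    # matrix gives a lower factor; reversing its indices makes it upper.
    T = np.linalg.cholesky(schur[::-1, ::-1])[::-1, ::-1]
    return np.concatenate([T, B_dsqrt], axis=1), d_sqrt


# Usage on a 4 x 4 example with head_size = 2.
top = np.array([[4.0, 1.0, 0.5, 0.2],
                [1.0, 3.0, 0.1, 0.3]])
bottom_diag = np.array([2.0, 5.0])
u_top, u_diag = arrowhead_sqrt(top, bottom_diag)

# Dense copies, used only to check the factorization.
M = np.block([[top], [top[:, 2:].T, np.diag(bottom_diag)]])
U = np.block([[u_top], [np.zeros((2, 2)), np.diag(u_diag)]])
```

Reversing indices before and after the Cholesky call is what makes the head factor upper rather than lower triangular, so the whole result stays an upper-triangular arrowhead.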
`triu_inverse(x)`: Computes the inverse of an upper-triangular arrowhead matrix, which is again an upper-triangular arrowhead matrix. Uses `torch.linalg.solve_triangular` for the top-left block and computes the top-right block analytically. Complexity is O(N * head_size^2).
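An upper-triangular arrowhead matrix `U = [[T, C], [0, E]]` (with `E` diagonal) is block upper triangular. Its inverse therefore has the closed form `[[T^{-1}, -T^{-1} C E^{-1}], [0, E^{-1}]]`, so only the small head block needs a solve. Below is a NumPy sketch of that formula. The helper name `arrowhead_triu_inverse` is hypothetical.

```python
import numpy as np


def arrowhead_triu_inverse(top, bottom_diag):
    # Hypothetical helper (not Pyro's API): inverse of U = [[T, C], [0, diag(E)]],
    # stored as top = [T, C], bottom_diag = E, returned in the same layout.
    head_size = top.shape[0]
    T, C = top[:, :head_size], top[:, head_size:]
    # Invert the small head_size x head_size block. np.linalg.solve ignores the
    # triangular structure (scipy.linalg.solve_triangular would use it); the
    # block is tiny, so this only costs a constant factor.
    T_inv = np.linalg.solve(T, np.eye(head_size))
    top_right = -T_inv @ (C / bottom_diag)   # -T^{-1} C E^{-1}, O(N * head_size^2)
    return np.concatenate([T_inv, top_right], axis=1), 1.0 / bottom_diag


top = np.array([[2.0, 1.0, 0.5, 0.2],
                [0.0, 3.0, 0.1, 0.3]])
bottom_diag = np.array([4.0, 5.0])
inv_top, inv_diag = arrowhead_triu_inverse(top, bottom_diag)

# Dense copies, used only to check the result.
U = np.block([[top], [np.zeros((2, 2)), np.diag(bottom_diag)]])
U_inv = np.block([[inv_top], [np.zeros((2, 2)), np.diag(inv_diag)]])
```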
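`triu_matvecmul(x, y, transpose)`: Computes the matrix-vector product `x @ y` of an upper-triangular arrowhead matrix with a vector, or `x.T @ y` when `transpose` is true. Only the dense top rows cost O(N * head_size). The diagonal bottom-right part contributes O(N), so the product is linear in N for a fixed `head_size`.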
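Below is a NumPy sketch of both branches, assuming the `(top, bottom_diag)` layout described in the Overview. The name `arrowhead_triu_matvec` is hypothetical. In the transposed branch, only the first `head_size` entries of `y` meet the dense rows; the diagonal tail is then added elementwise.

```python
import numpy as np


def arrowhead_triu_matvec(top, bottom_diag, y, transpose=False):
    # Hypothetical helper (not Pyro's API): U @ y, or U.T @ y if transpose,
    # for U = [[T, C], [0, diag(E)]] stored as top = [T, C], bottom_diag = E.
    head_size = top.shape[0]
    if transpose:
        # U.T = [[T.T, 0], [C.T, diag(E)]]: the dense part only sees y[:head_size].
        z = top.T @ y[:head_size]                      # O(N * head_size)
        return np.concatenate([z[:head_size],
                               z[head_size:] + bottom_diag * y[head_size:]])
    head = top @ y                                     # dense top rows, O(N * head_size)
    tail = bottom_diag * y[head_size:]                 # diagonal tail, O(N)
    return np.concatenate([head, tail])


top = np.array([[2.0, 1.0, 0.5, 0.2],
                [0.0, 3.0, 0.1, 0.3]])
bottom_diag = np.array([4.0, 5.0])
y = np.array([1.0, -2.0, 0.5, 3.0])
Uy = arrowhead_triu_matvec(top, bottom_diag, y)
Uty = arrowhead_triu_matvec(top, bottom_diag, y, transpose=True)

# Dense copy, used only to check the results.
U = np.block([[top], [np.zeros((2, 2)), np.diag(bottom_diag)]])
```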
`triu_gram(x)`: Computes the Gram matrix `x.T @ x` of an upper-triangular arrowhead matrix. The result is a dense N x N tensor, not an arrowhead matrix, because its bottom-right block picks up a dense term. Complexity is O(N^2 * head_size).
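Multiplying out `U.T @ U` for `U = [[T, C], [0, E]]` gives `[[T.T T, T.T C], [C.T T, C.T C + E^2]]`. The `C.T C` block is a dense `(N - head_size) x (N - head_size)` matrix, which is where the O(N^2 * head_size) cost comes from. Below is a NumPy sketch with the hypothetical helper `arrowhead_triu_gram`.

```python
import numpy as np


def arrowhead_triu_gram(top, bottom_diag):
    # Hypothetical helper (not Pyro's API): dense U.T @ U for
    # U = [[T, C], [0, diag(E)]] stored as top = [T, C], bottom_diag = E.
    head_size = top.shape[0]
    T, C = top[:, :head_size], top[:, head_size:]
    upper = T.T @ top                                   # [T.T T, T.T C]
    lower_left = upper[:, head_size:].T                 # C.T T, by symmetry
    lower_right = C.T @ C + np.diag(bottom_diag ** 2)   # dense block, O(N^2 * head_size)
    return np.block([[upper], [lower_left, lower_right]])


top = np.array([[2.0, 1.0, 0.5, 0.2],
                [0.0, 3.0, 0.1, 0.3]])
bottom_diag = np.array([4.0, 5.0])
G = arrowhead_triu_gram(top, bottom_diag)

# Dense copy, used only to check the result.
U = np.block([[top], [np.zeros((2, 2)), np.diag(bottom_diag)]])
```

Here `G[2, 3]` is nonzero even though the corresponding entry of `U` is zero, which is why the output cannot be stored as an arrowhead tuple.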
I/O Contract
| Function | Input | Output |
|---|---|---|
| sqrt(x) | SymmArrowhead(top, bottom_diag) | TriuArrowhead(top, bottom_diag) |
| triu_inverse(x) | TriuArrowhead(top, bottom_diag) | TriuArrowhead(top, bottom_diag) |
| triu_matvecmul(x, y, transpose) | TriuArrowhead, Tensor(N,), bool | Tensor(N,) |
| triu_gram(x) | TriuArrowhead(top, bottom_diag) | Tensor(N, N) (dense matrix) |
Usage Examples
```python
import torch
from pyro.ops.arrowhead import SymmArrowhead, sqrt, triu_inverse, triu_matvecmul

# Build a symmetric positive definite arrowhead matrix with head_size=2, N=5.
# `top` has shape (head_size, N): a symmetric head block A followed by B.
head_size = 2
N = 5
A = torch.randn(head_size, head_size)
A = A @ A.T + torch.eye(head_size)  # symmetric positive definite head block
B = 0.1 * torch.randn(head_size, N - head_size)  # small coupling helps keep the matrix positive definite
top = torch.cat([A, B], dim=-1)
bottom_diag = torch.rand(N - head_size) + 1.0  # positive diagonal tail
symm = SymmArrowhead(top=top, bottom_diag=bottom_diag)

# Upper-triangular square root (a TriuArrowhead)
triu = sqrt(symm)

# Inverse of the triangular factor (also a TriuArrowhead)
triu_inv = triu_inverse(triu)

# Matrix-vector product triu @ y, then undo it with the inverse
y = torch.randn(N)
result = triu_matvecmul(triu, y)
recovered = triu_matvecmul(triu_inv, result)  # should match y up to rounding error
```
Related Pages
- Pyro_ppl_Pyro_WelfordCovariance -- `WelfordArrowheadCovariance` estimates arrowhead covariance
- Pyro_ppl_Pyro_TensorUtils -- General linear algebra utilities