Principle:Pyro ppl Pyro Tensor Operations

From Leeroopedia


Knowledge Sources
Domains Tensor Algebra, Numerical Computing, Probabilistic Computing
Last Updated 2026-02-09 09:00 GMT

Overview

Tensor operations provide specialized manipulation utilities for multi-dimensional arrays, including advanced indexing, packed representations, special mathematical functions, and broadcasting patterns tailored for probabilistic computing.

Description

Probabilistic programming involves extensive computation with multi-dimensional tensors representing batches of distributions, samples, and parameters. Standard tensor libraries provide basic operations, but probabilistic computing requires additional specialized utilities.

Tensor utilities: General-purpose functions for manipulating tensor shapes, performing safe mathematical operations, and handling the batch/event shape conventions of probabilistic programming. These include reshaping operations that respect plate structure, safe log-sum-exp operations, and utilities for aligning tensors with different batch shapes.
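The batch/event shape convention can be illustrated with a small sketch. It assumes, as in Pyro and `torch.distributions`, that the rightmost `event_dim` dimensions of a tensor form its event shape and the remaining leftmost dimensions form its batch shape. `broadcast_batch` is a hypothetical helper written for illustration, not a function taken from Pyro.

```python
import torch

def broadcast_batch(*pairs):
    """Hypothetical helper: expand tensors to a common batch shape.

    Each argument is a (tensor, event_dim) pair. The rightmost event_dim dims
    of each tensor are its event shape and are left untouched; the dims to
    their left are batch dims and are broadcast across all tensors.
    """
    batch_shapes = [t.shape[: t.dim() - e] for t, e in pairs]
    batch_shape = torch.broadcast_shapes(*batch_shapes)
    expanded = [t.expand(batch_shape + t.shape[t.dim() - e:]) for t, e in pairs]
    return batch_shape, expanded

# A vector-valued mean (event shape (2,)) and a matrix-valued scale (event shape (2, 2)).
loc = torch.zeros(5, 2)                    # batch shape (5,)
scale = torch.eye(2).expand(3, 1, 2, 2)    # batch shape (3, 1)
batch_shape, (loc_b, scale_b) = broadcast_batch((loc, 1), (scale, 2))
```

Only the batch dimensions take part in broadcasting, so `loc` and `scale` end up with batch shape `(3, 5)` while keeping their own event shapes. `expand` returns views, so no data is copied.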

Arrowhead matrices: Sparse matrix structures where non-zero elements appear only on the diagonal and in one row and one column (forming an "arrowhead" shape). These arise naturally in hierarchical models and can be factorized much more efficiently than dense matrices. Operations on arrowhead matrices (solve, log-determinant) have O(n) complexity instead of O(n^3).

Packed tensors: A representation for tensors with special structure (e.g., triangular, symmetric, banded) that stores only the unique elements. This reduces memory usage and can speed up operations that exploit the structure. For example, a symmetric n x n matrix has only n(n+1)/2 unique elements.

Special mathematical functions: Functions from mathematical analysis that probability distributions need, such as the log-gamma function, digamma function, beta function, Bessel functions, and other special functions. Some are built into standard tensor libraries (PyTorch provides `torch.lgamma` and `torch.digamma`); others must be assembled from these primitives or implemented separately, with care for numerical stability.
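A minimal sketch of how such functions fit together, assuming only PyTorch's built-in `torch.lgamma` and `torch.digamma`: it computes the log-beta function from log-gamma, compares the leading terms of Stirling's series with `torch.lgamma`, and checks with autograd that digamma is the derivative of log-gamma. `log_beta` and `stirling_lgamma` are hypothetical helpers for illustration, not Pyro's implementations.

```python
import math
import torch

def log_beta(a, b):
    # log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b)
    return torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b)

def stirling_lgamma(x):
    # First terms of Stirling's series for log Gamma(x); accurate for large x.
    return (x - 0.5) * torch.log(x) - x + 0.5 * math.log(2 * math.pi) + 1 / (12 * x)

# Digamma is the derivative of log-gamma: compare autograd with torch.digamma.
x = torch.tensor([0.5, 2.0, 10.0], dtype=torch.float64, requires_grad=True)
(grad,) = torch.autograd.grad(torch.lgamma(x).sum(), x)
digamma_matches = torch.allclose(grad, torch.digamma(x.detach()))

# B(2, 3) = Gamma(2) * Gamma(3) / Gamma(5) = 2 / 24 = 1 / 12
lb = log_beta(torch.tensor(2.0, dtype=torch.float64), torch.tensor(3.0, dtype=torch.float64))
```

Working in log space keeps `log_beta` finite for arguments where the gamma function itself would overflow.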

Vindex (vectorized indexing): An advanced indexing utility that extends standard tensor indexing with Pyro-aware broadcasting semantics. It handles the common pattern of gathering values from a tensor using indices that have plate (batch) dimensions, ensuring correct broadcasting behavior.

Usage

Use tensor operations when:

  • Manipulating tensor shapes to conform to batch/event shape conventions.
  • Working with structured matrices (arrowhead, triangular, banded) that can be handled more efficiently than dense matrices.
  • Reducing memory usage by packing structured tensors.
  • Computing special mathematical functions needed for distribution log-probabilities.
  • Performing advanced indexing with proper broadcasting in plate contexts.

Theoretical Basis

Arrowhead matrix operations:

# Arrowhead matrix structure (n x n):
# A = [[d_1, 0, ..., 0, a_1],
#      [0, d_2, ..., 0, a_2],
#      [...                 ],
#      [0, 0, ..., d_{n-1}, a_{n-1}],
#      [a_1, a_2, ..., a_{n-1}, d_n]]

# Only 3n - 2 non-zero elements (vs n^2 for dense)

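# Determinant (O(n) instead of O(n^3)), with i and j ranging over 1..n-1:
# det(A) = d_n * product_i d_i - sum_i a_i^2 * product_{j!=i} d_j
#        = (product_i d_i) * (d_n - sum_i a_i^2 / d_i)    (if every d_i != 0)

# Solve Ax = b (O(n) instead of O(n^3)):
# Block elimination with the scalar Schur complement s = d_n - sum_i a_i^2 / d_i
# (equivalent to applying the Sherman-Morrison-Woodbury formula to A written
# as a diagonal matrix plus a rank-2 update):
# x_n = (b_n - sum_i a_i * b_i / d_i) / s
# x_i = (b_i - a_i * x_n) / d_i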
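The O(n) determinant and solve can be sketched in PyTorch by eliminating the arm of the arrowhead through the scalar Schur complement s = d_n - sum_i a_i^2 / d_i. The sketch assumes a symmetric arrowhead with scalar entries, nonzero d_1..d_{n-1}, and s != 0. `arrowhead_det_solve` and `arrowhead_dense` are hypothetical helpers, not taken from Pyro; the dense builder exists only to check the result against `torch.linalg`.

```python
import torch

def arrowhead_dense(d, a):
    # Dense n x n arrowhead: diagonal d (length n), arm a (length n-1)
    # in the last row and column. Used only to check the O(n) routine.
    A = torch.diag(d)
    A[:-1, -1] = a
    A[-1, :-1] = a
    return A

def arrowhead_det_solve(d, a, b):
    """O(n) determinant and solve for a symmetric arrowhead matrix.

    Assumes d[:-1] has no zeros and the Schur complement s is nonzero.
    """
    D, d_n = d[:-1], d[-1]
    s = d_n - (a * a / D).sum()        # Schur complement of the diagonal block
    det = torch.prod(D) * s
    b_head, b_n = b[:-1], b[-1]
    x_n = (b_n - (a * b_head / D).sum()) / s
    x_head = (b_head - a * x_n) / D
    return det, torch.cat([x_head, x_n.reshape(1)])

d = torch.tensor([2.0, 3.0, 4.0, 5.0], dtype=torch.float64)
a = torch.tensor([1.0, -1.0, 0.5], dtype=torch.float64)
b = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
det, x = arrowhead_det_solve(d, a, b)
```

Every step is an elementwise operation or a single reduction over the n - 1 diagonal entries, which is where the O(n) cost comes from. A production version would work with the log-determinant to avoid overflow in `torch.prod`.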

Packed tensor representation:

# Symmetric matrix packing:
# Full: n x n matrix with n^2 elements
# Packed: n*(n+1)/2 elements (upper or lower triangle)

# Packing: pack(M) -> v, where v[k] = M[i,j] for i <= j
# (0-based indices, upper triangle stored row by row):
# k = i*n - i*(i-1)/2 + j - i

# Unpacking: unpack(v) -> M  (reconstruct full symmetric matrix)

# Operations on packed tensors:
# Cholesky, matrix-vector product, log-determinant
# can be implemented directly on packed representation
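A sketch of symmetric packing in PyTorch, assuming 0-based indices and the upper triangle stored row by row (the order returned by `torch.triu_indices`). `pack_symmetric`, `unpack_symmetric`, and `packed_index` are hypothetical helpers for illustration; `packed_index` evaluates the offset formula for k given above.

```python
import torch

def pack_symmetric(M):
    # Keep only the upper triangle (i <= j), in row-major order.
    n = M.shape[-1]
    i, j = torch.triu_indices(n, n)
    return M[..., i, j]

def unpack_symmetric(v, n):
    # Scatter the packed values into both triangles of an n x n matrix.
    i, j = torch.triu_indices(n, n)
    M = v.new_zeros(v.shape[:-1] + (n, n))
    M[..., i, j] = v
    M[..., j, i] = v
    return M

def packed_index(i, j, n):
    # Offset of M[i, j] (i <= j) in the packed vector; i*(i-1) is always even.
    return i * n - i * (i - 1) // 2 + j - i

S = torch.tensor([[1.0, 2.0, 3.0], [2.0, 4.0, 5.0], [3.0, 5.0, 6.0]])
v = pack_symmetric(S)   # tensor([1., 2., 3., 4., 5., 6.])
```

For n = 3 the packed vector holds 3 * 4 / 2 = 6 values instead of 9. The leading `...` lets both helpers work on batches of matrices.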

Vectorized indexing (Vindex):

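<syntaxhighlight lang="python">
import torch
from pyro.ops.indexing import Vindex

# Problem: standard indexing does not broadcast over plate (batch) dimensions.
batch, n, m, d = 4, 5, 3, 2
x = torch.randn(batch, n, d)          # x has shape (batch, n, d)
idx = torch.randint(n, (batch, m))    # idx has shape (batch, m)
# Want: result[b, k, :] = x[b, idx[b, k], :], with shape (batch, m, d)

# Standard PyTorch: requires manual broadcasting and gather
manual = x.gather(1, idx.unsqueeze(-1).expand(batch, m, d))

# Vindex: "..." stands for the batch dims of x, which are broadcast against
# the shape of idx. A singleton dim lets x's batch shape (batch, 1) broadcast
# with idx's shape (batch, m).
result = Vindex(x.unsqueeze(-3))[..., idx, :]
assert result.shape == (batch, m, d)
assert torch.equal(result, manual)
</syntaxhighlight>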

Special functions:

# Log-gamma function: log Gamma(x)
# Needed for: Beta, Dirichlet, Student-t, etc.
# log Gamma(x) = (x - 0.5) * log(x) - x + 0.5 * log(2*pi) + ...  (Stirling)

# Digamma function: psi(x) = d/dx log Gamma(x)
# Needed for: gradients of distributions involving Gamma functions

# Log-beta function: log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a+b)

# Log-sum-exp (numerically stable):
# logsumexp(x_1, ..., x_n) = max(x) + log(sum_i exp(x_i - max(x)))
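The log-sum-exp identity translates directly into code. The sketch below, assuming PyTorch, adds the one guard a "safe" variant needs: when every entry is -inf, subtracting the max would produce nan, so the shift falls back to 0. `safe_logsumexp` is a hypothetical name for illustration; PyTorch's built-in `torch.logsumexp` covers the common case.

```python
import math
import torch

def safe_logsumexp(x, dim=-1):
    # Shift by the max along dim so that exp() cannot overflow.
    m = x.max(dim=dim, keepdim=True).values
    # If all entries are -inf, x - m would be nan (-inf minus -inf), so shift
    # by 0 instead; the result is then log(0) = -inf, as it should be.
    m = torch.where(torch.isfinite(m), m, torch.zeros_like(m))
    return m.squeeze(dim) + torch.log(torch.exp(x - m).sum(dim))

x = torch.tensor([[1000.0, 1000.0], [-math.inf, -math.inf], [0.0, math.log(3.0)]])
naive = torch.log(torch.exp(x).sum(-1))   # first entry overflows to inf
stable = safe_logsumexp(x)
```

The naive formula overflows on the first row, while the shifted version returns 1000 + log 2. The all-(-inf) row, which arises for example when every component of a mixture is masked out, stays at -inf instead of becoming nan.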
