Implementation:Pyro ppl Pyro TensorUtils

From Leeroopedia


Property | Value
--- | ---
Module | pyro.ops.tensor_utils
Source | pyro/ops/tensor_utils.py
Lines | 504
Functions | as_complex, block_diag_embed, block_diagonal, periodic_repeat, periodic_cumsum, periodic_features, next_fast_len, convolve, repeated_matmul, dct, idct, haar_transform, inverse_haar_transform, safe_cholesky, cholesky_solve, matmul, matvecmul, triangular_solve, precision_to_scale_tril, safe_normalize, broadcast_tensors_without_dim
Dependencies | torch, pyro.settings

Overview

This module provides a collection of low-level tensor utility functions used throughout Pyro. It includes:

  • Linear algebra helpers: Safe Cholesky decomposition with adaptive jitter, matrix products and solves with fast paths for 1x1 matrices, and precision-to-scale conversion.
  • Signal processing: FFT-based convolution, DCT/IDCT (Discrete Cosine Transform), the Haar transform, and next_fast_len for choosing FFT-friendly padding lengths (useful, for example, when computing autocorrelations via FFT).
  • Time series utilities: Periodic repeat, periodic cumulative sum, periodic features for seasonality modeling.
  • Block matrix operations: Block diagonal embedding and extraction.

A key design feature is that 1x1 matrices (size-1 trailing dimensions) are handled by elementwise fast paths in matmul, matvecmul, triangular_solve, safe_cholesky, and cholesky_solve, so the scalar case avoids a general linear-algebra call.
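To make the fast-path idea concrete, here is a minimal sketch, assuming the shortcut applies when the shared (inner) dimension has size 1. The helper names `matmul_sketch` and `triangular_solve_sketch` are hypothetical and are not Pyro's API. With an inner dimension of 1 a matrix product collapses to a broadcasted elementwise product, and a 1x1 triangular system collapses to a division.

```python
import torch


def matmul_sketch(x, y):
    """Hypothetical matmul with a size-1 fast path (not Pyro's implementation)."""
    if x.size(-1) == 1:
        # (..., M, 1) * (..., 1, N) broadcasts to the (..., M, N) outer product,
        # which is exactly x @ y when the inner dimension is 1.
        return x * y
    return x @ y


def triangular_solve_sketch(b, A, upper=False):
    """Hypothetical solve of A @ z = b for triangular A, with a 1x1 fast path."""
    if A.size(-1) == 1:
        # A 1x1 triangular system is a scalar division.
        return b / A
    return torch.linalg.solve_triangular(A, b, upper=upper)


torch.manual_seed(0)
x, y = torch.randn(4, 3, 1), torch.randn(4, 1, 5)
print(matmul_sketch(x, y).shape)  # torch.Size([4, 3, 5])
```

The fast path trades a batched BLAS call for a single broadcasted multiply, which matters when many tiny (for example, univariate) systems are processed in a batch.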

Code Reference

Linear Algebra

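  • safe_cholesky(x): Cholesky decomposition that adds a small diagonal jitter, scaled relative to the entries of x and controlled by CHOLESKY_RELATIVE_JITTER (configurable via pyro.settings). For 1x1 matrices, it returns sqrt(clamp(x)) without calling a Cholesky routine.
  • cholesky_solve(x, y): Solves a linear system given the Cholesky factor y of its matrix, with a 1x1 fast path.
  • matmul(x, y): Matrix multiplication with a 1x1 fast path that uses elementwise multiplication.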
  • matvecmul(x, y): Matrix-vector multiply with 1x1 fast path.
  • triangular_solve(x, y, upper, transpose): Triangular solve with 1x1 fast path. Handles transposition internally.
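  • precision_to_scale_tril(P): Converts a precision matrix P to a lower-triangular scale matrix L satisfying L @ L.T = inverse(P). It takes the Cholesky decomposition of P with its rows and columns reversed ("flipped Cholesky") and then does a triangular solve, so P is never inverted explicitly.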
  • safe_normalize(x, p): Safely normalizes a vector to the unit sphere, mapping zero vectors to [1, 0, ..., 0].
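The flipped-Cholesky trick and the jitter idea can be sketched as follows. This is an illustration under stated assumptions, not Pyro's code: `precision_to_scale_tril_sketch` follows the reversal argument described in the list, while `jittered_cholesky_sketch` and its `relative_jitter=4.0` default are hypothetical, and the per-row scaling rule is an assumption about what "relative" jitter means.

```python
import torch


def precision_to_scale_tril_sketch(P):
    """Return lower-triangular L with L @ L.T == inverse(P), without inverting P."""
    # Cholesky of P with rows and columns reversed: flip(P) = Lf @ Lf.T.
    Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
    # Flipping Lf back yields an upper-triangular U with P = U @ U.T.
    U = torch.flip(Lf, (-2, -1))
    # Then inverse(P) = inv(U.T) @ inv(U.T).T and inv(U.T) is lower triangular,
    # so one triangular solve against the identity gives the scale_tril.
    eye = torch.eye(P.size(-1), dtype=P.dtype, device=P.device)
    return torch.linalg.solve_triangular(U.transpose(-2, -1), eye, upper=False)


def jittered_cholesky_sketch(x, relative_jitter=4.0):
    """Hypothetical: add diagonal jitter = relative_jitter * eps * row magnitude."""
    eps = torch.finfo(x.dtype).eps
    row_scale = x.abs().amax(dim=-1)  # largest magnitude in each row
    jitter = relative_jitter * eps * row_scale
    return torch.linalg.cholesky(x + torch.diag_embed(jitter))


torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64)
P = A @ A.T + 4 * torch.eye(4, dtype=torch.float64)  # a well-conditioned precision matrix
L = precision_to_scale_tril_sketch(P)
print(torch.allclose(L @ L.T, torch.linalg.inv(P)))  # True
```

Because the Cholesky factor with a positive diagonal is unique, the result coincides with cholesky(inverse(P)); the flipped route simply avoids forming the inverse.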

Signal Processing

  • next_fast_len(size): Returns the smallest integer >= size whose prime factors are only 2, 3, or 5 (lengths that FFTs handle efficiently). Results are cached.
  • convolve(signal, kernel, mode): FFT-based 1D convolution supporting 'full', 'valid', and 'same' modes.
  • dct(x, dim): Orthonormal Type-II Discrete Cosine Transform, equivalent to scipy.fftpack.dct(norm="ortho").
  • idct(x, dim): Inverse of dct (an orthonormal Type-III DCT), so idct(dct(x, dim), dim) recovers x.
  • haar_transform(x): Recursive Haar wavelet transform along the final dimension.
  • inverse_haar_transform(x): Inverse Haar transform.
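  • as_complex(x): Views a real tensor whose last dimension has size 2 as a complex tensor, like torch.view_as_complex, but copies the data when its strides are not multiples of two.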
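As a reference for what the orthonormal DCT-II and the Haar transform compute, here is a small sketch. It builds the DCT as an explicit O(N^2) matrix rather than via FFT (Pyro's dct is FFT-based, so this is a swapped-in reference technique, not its method), and the Haar sketch assumes a power-of-two length and places coarse coefficients first; that coefficient ordering is an assumption. All function names are hypothetical.

```python
import math

import torch


def dct_matrix(N, dtype=torch.float64):
    """Orthonormal DCT-II matrix C with C[k, n] = s_k * cos(pi * (2n + 1) * k / (2N))."""
    n = torch.arange(N, dtype=dtype)
    k = n.unsqueeze(-1)
    C = torch.cos(math.pi * (2 * n + 1) * k / (2 * N))
    C[0] *= math.sqrt(1.0 / N)   # s_0 = sqrt(1/N)
    C[1:] *= math.sqrt(2.0 / N)  # s_k = sqrt(2/N) for k > 0
    return C


def dct_sketch(x):
    # X_k = sum_n C[k, n] x_n along the last dimension.
    return x @ dct_matrix(x.size(-1), x.dtype).T


def idct_sketch(X):
    # C is orthogonal, so its transpose inverts it.
    return X @ dct_matrix(X.size(-1), X.dtype)


def haar_sketch(x):
    """Recursive orthonormal Haar transform along the last dim (power-of-two length)."""
    n = x.size(-1)
    assert n > 0 and n & (n - 1) == 0, "this sketch assumes a power-of-two length"
    if n == 1:
        return x
    even, odd = x[..., 0::2], x[..., 1::2]
    lo = (even + odd) / math.sqrt(2)  # local averages, transformed recursively
    hi = (even - odd) / math.sqrt(2)  # local differences, kept as detail coefficients
    return torch.cat([haar_sketch(lo), hi], dim=-1)


def inverse_haar_sketch(y):
    n = y.size(-1)
    if n == 1:
        return y
    lo = inverse_haar_sketch(y[..., : n // 2])
    hi = y[..., n // 2 :]
    even = (lo + hi) / math.sqrt(2)
    odd = (lo - hi) / math.sqrt(2)
    # Interleave even and odd samples back into their original positions.
    return torch.stack([even, odd], dim=-1).reshape(y.shape)


x = torch.randn(8, dtype=torch.float64)
print(torch.allclose(idct_sketch(dct_sketch(x)), x))  # True
```

Both transforms are orthonormal, which is why their inverses are cheap and why they preserve the Euclidean norm of the input.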

Time Series

  • periodic_repeat(tensor, size, dim): Tiles a tensor periodically up to a given size along a dimension. Useful for static seasonality.
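  • periodic_cumsum(tensor, period, dim): Computes a cumulative sum along dim that accumulates each phase of the period separately, so each output element is the sum of the inputs at positions t, t - period, t - 2*period, and so on. Useful for drifting seasonality.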
  • periodic_features(duration, max_period, min_period): Creates sin/cos feature matrices for regression-based seasonality modeling.
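The two seasonality helpers above can be sketched with reshapes, under the assumption that periodic_repeat is "tile then truncate" and periodic_cumsum is "cumsum across whole periods". The names `periodic_repeat_sketch` and `periodic_cumsum_sketch` are hypothetical; Pyro's versions may differ in argument checking and tracing support.

```python
import torch


def periodic_repeat_sketch(tensor, size, dim):
    """Tile `tensor` along `dim`, then truncate that dimension to `size`."""
    period = tensor.size(dim)
    repeats = [1] * tensor.dim()
    repeats[dim] = -(-size // period)  # ceil(size / period) copies
    return tensor.repeat(*repeats).narrow(dim, 0, size)


def periodic_cumsum_sketch(tensor, period, dim):
    """out[t] = tensor[t] + tensor[t - period] + tensor[t - 2 * period] + ... along `dim`."""
    dim = dim % tensor.dim()  # normalize a negative dim
    size = tensor.size(dim)
    num_periods = -(-size // period)
    # Zero-pad along `dim` so its length is a multiple of `period`.
    pad_shape = list(tensor.shape)
    pad_shape[dim] = num_periods * period - size
    padded = torch.cat([tensor, tensor.new_zeros(pad_shape)], dim=dim)
    # Split `dim` into (num_periods, period) and accumulate across periods.
    split_shape = tuple(tensor.shape[:dim]) + (num_periods, period) + tuple(tensor.shape[dim + 1 :])
    result = padded.reshape(split_shape).cumsum(dim).reshape(padded.shape)
    return result.narrow(dim, 0, size)


print(periodic_repeat_sketch(torch.tensor([1, 2, 3]), 7, 0).tolist())  # [1, 2, 3, 1, 2, 3, 1]
print(periodic_cumsum_sketch(torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]), 2, 0).tolist())  # [1.0, 2.0, 4.0, 6.0, 9.0]
```

A repeated seasonal pattern stays fixed over time, while a periodic cumsum of per-step noise lets each phase of the season drift like an independent random walk.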

Block Matrix Operations

  • block_diag_embed(mat): Converts (..., B, M, N) to block diagonal (..., B*M, B*N).
  • block_diagonal(mat, block_size): Extracts diagonal blocks from a block diagonal matrix.
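A minimal sketch of both block operations, assuming blocks are laid out along the diagonal in batch order. Embedding multiplies by a broadcasted identity over block indices; extraction reshapes into a (block row, row, block column, column) grid and takes its diagonal. The names are hypothetical, and the parameter `num_blocks` (the number of blocks B) is our own naming choice.

```python
import torch


def block_diag_embed_sketch(mat):
    """(..., B, M, N) -> block-diagonal (..., B*M, B*N)."""
    B, M, N = mat.shape[-3:]
    eye = torch.eye(B, dtype=mat.dtype, device=mat.device).reshape(B, 1, B, 1)
    # out[..., b, m, c, n] = mat[..., b, m, n] if b == c else 0
    out = mat.unsqueeze(-2) * eye
    return out.reshape(tuple(mat.shape[:-3]) + (B * M, B * N))


def block_diagonal_sketch(mat, num_blocks):
    """Extract the diagonal blocks of a (..., B*M, B*N) matrix as (..., B, M, N)."""
    B = num_blocks
    M, N = mat.size(-2) // B, mat.size(-1) // B
    grid = mat.reshape(tuple(mat.shape[:-2]) + (B, M, B, N))
    # torch.diagonal over the two block-index dims appends the diagonal at the end.
    diag = torch.diagonal(grid, dim1=-4, dim2=-2)  # (..., M, N, B)
    return diag.movedim(-1, -3)  # (..., B, M, N)


torch.manual_seed(0)
blocks = torch.randn(3, 2, 4)
big = block_diag_embed_sketch(blocks)
print(big.shape)  # torch.Size([6, 12])
```

Both directions use only reshapes and broadcasting, so they work unchanged on arbitrary leading batch dimensions.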

Other

  • repeated_matmul(M, n): Computes [M, M^2, ..., M^n] using doubling with O(log n) parallel cost.
  • broadcast_tensors_without_dim(tensors, dim): Broadcasts tensors without changing size along a specified dimension.
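The doubling idea behind repeated_matmul, and the shape logic of broadcast_tensors_without_dim, can be sketched as follows. Assumptions: the sketch stacks powers along a new leading dimension, and the broadcasting helper accepts only a negative `dim`; both function names are hypothetical. Each doubling step is one batched matmul, M^i @ M^k = M^(i+k) for i = 1..k, so about log2(n) sequential steps are needed.

```python
import torch


def repeated_matmul_sketch(M, n):
    """Return a (n, ..., N, N) stack of M, M^2, ..., M^n via repeated doubling."""
    powers = M.unsqueeze(0)  # holds M^1 .. M^k, with k = powers.size(0)
    while powers.size(0) < n:
        # One batched matmul doubles k: M^i @ M^k = M^(i + k) for i = 1..k.
        powers = torch.cat([powers, powers @ powers[-1]], dim=0)
    return powers[:n]


def broadcast_tensors_without_dim_sketch(tensors, dim):
    """Broadcast every dim except `dim` (negative), which keeps each tensor's own size."""
    assert dim < 0, "this sketch only supports negative dims"
    shapes = []
    for t in tensors:
        shape = list(t.shape)
        shape[dim] = 1  # exclude `dim` from broadcasting
        shapes.append(tuple(shape))
    common = list(torch.broadcast_shapes(*shapes))
    results = []
    for t in tensors:
        target = list(common)
        target[dim] = t.size(dim)
        results.append(t.expand(target))
    return results


torch.manual_seed(0)
M = torch.randn(2, 3, 3, dtype=torch.float64) / 2
print(repeated_matmul_sketch(M, 5).shape)  # torch.Size([5, 2, 3, 3])
```

For n = 5 the loop runs three times (1 -> 2 -> 4 -> 8 powers) and discards the extras, trading a little wasted work for logarithmic depth.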

I/O Contract

Function | Input | Output
--- | --- | ---
safe_cholesky(x) | Tensor(..., N, N), positive semi-definite | Tensor(..., N, N), lower triangular
precision_to_scale_tril(P) | Tensor(..., N, N), precision matrix | Tensor(..., N, N), lower-triangular scale
convolve(signal, kernel, mode) | Tensor(..., M), Tensor(..., N), mode string | Tensor(..., L), where L = M + N - 1 for "full", max(M, N) for "same", and max(M, N) - min(M, N) + 1 for "valid"
dct(x, dim) | Tensor, dimension int | Tensor of the same shape (DCT coefficients)
next_fast_len(size) | Positive int | Smallest int >= size whose prime factors are all in {2, 3, 5}
periodic_features(duration, ...) | duration: int; optional max_period and min_period | Tensor(duration, num_features)
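The convolve and next_fast_len rows of the table can be illustrated together. This is a minimal sketch with hypothetical names, assuming the usual FFT approach: zero-pad both inputs to an FFT-friendly length of at least M + N - 1 so that circular convolution equals linear convolution, then keep the centered slice whose length the mode dictates (the same length conventions as numpy.convolve).

```python
import functools

import torch


@functools.lru_cache(maxsize=None)
def next_fast_len_sketch(size):
    """Smallest integer >= size whose only prime factors are 2, 3 and 5."""
    assert isinstance(size, int) and size > 0
    candidate = size
    while True:
        remaining = candidate
        for p in (2, 3, 5):
            while remaining % p == 0:
                remaining //= p
        if remaining == 1:
            return candidate
        candidate += 1


def convolve_sketch(signal, kernel, mode="full"):
    """FFT convolution along the last dim; output length depends on `mode`."""
    m, n = signal.size(-1), kernel.size(-1)
    full = m + n - 1
    length = {"full": full, "same": max(m, n), "valid": abs(m - n) + 1}[mode]
    # Pad to an FFT-friendly length >= m + n - 1 so no circular wrap-around occurs.
    fast = next_fast_len_sketch(full)
    spectrum = torch.fft.rfft(signal, n=fast) * torch.fft.rfft(kernel, n=fast)
    result = torch.fft.irfft(spectrum, n=fast)[..., :full]
    # Keep the centered `length` entries of the full linear convolution.
    start = (full - length) // 2
    return result[..., start : start + length]


s, k = torch.randn(100), torch.randn(10)
for mode in ("full", "same", "valid"):
    print(mode, convolve_sketch(s, k, mode).shape)
# full torch.Size([109]), same torch.Size([100]), valid torch.Size([91])
```

Padding to a 2/3/5-smooth length rather than exactly M + N - 1 keeps FFT cost low when M + N - 1 has large prime factors.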

Usage Examples

import torch
from pyro.ops.tensor_utils import (
    safe_cholesky,
    convolve,
    dct,
    idct,
    periodic_features,
    next_fast_len,
)

# Safe Cholesky with adaptive jitter
P = torch.eye(3) + 0.1 * torch.randn(3, 3)
P = P @ P.T  # make positive definite
L = safe_cholesky(P)

# FFT-based convolution
signal = torch.randn(100)
kernel = torch.randn(10)
result = convolve(signal, kernel, mode="same")
print(result.shape)  # torch.Size([100])

# DCT round-trip
x = torch.randn(64)
assert torch.allclose(idct(dct(x)), x, atol=1e-5)

# Periodic features for yearly seasonality
features = periodic_features(365, max_period=365.25, min_period=7)
print(features.shape)  # torch.Size([365, 104]), i.e. 2 * ceil(365.25 / 7) - 2 features

# Find FFT-efficient size
print(next_fast_len(100))  # 100 (already efficient: 2^2 * 5^2)
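To show where the feature count of periodic_features comes from, here is a minimal sketch, assuming frequencies are k / max_period for integers k >= 1 whose period max_period / k stays above min_period (min_period defaults to the Nyquist limit of 2). The name `periodic_features_sketch` and its column ordering (all cosines, then all sines) are our own choices and may not match Pyro's layout. Under these assumptions there are ceil(max_period / min_period) - 1 frequencies when the ratio is not an integer, each contributing a cosine and a sine column.

```python
import math

import torch


def periodic_features_sketch(duration, max_period=None, min_period=None):
    """Cos/sin features for each period max_period / k that exceeds min_period."""
    max_period = duration if max_period is None else max_period
    min_period = 2 if min_period is None else min_period  # 2 = Nyquist limit
    t = torch.arange(float(duration)).unsqueeze(-1)  # (duration, 1)
    # Integer multiples k = 1, 2, ... with k < max_period / min_period.
    k = torch.arange(1, max_period / min_period)
    angle = t * (2 * math.pi * k / max_period)  # (duration, num_freqs)
    return torch.cat([torch.cos(angle), torch.sin(angle)], dim=-1)


features = periodic_features_sketch(365, max_period=365.25, min_period=7)
print(features.shape)  # torch.Size([365, 104])
```

These columns are typically used as regressors, so a linear model over them can represent any smooth seasonal curve whose shortest component period is above min_period.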
