Implementation:Pyro ppl Pyro TensorUtils
| Property | Value |
|---|---|
| Module | pyro.ops.tensor_utils |
| Source | pyro/ops/tensor_utils.py |
| Lines | 504 |
| Functions | as_complex, block_diag_embed, block_diagonal, periodic_repeat, periodic_cumsum, periodic_features, next_fast_len, convolve, repeated_matmul, dct, idct, haar_transform, inverse_haar_transform, safe_cholesky, cholesky_solve, matmul, matvecmul, triangular_solve, precision_to_scale_tril, safe_normalize, broadcast_tensors_without_dim |
| Dependencies | torch, pyro.settings |
Overview
This module provides a collection of low-level tensor utility functions used throughout Pyro. It includes:
- Linear algebra helpers: Safe Cholesky decomposition with adaptive jitter, matrix-vector operations with scalar fast paths, precision-to-scale conversion.
- Signal processing: FFT-based convolution, DCT/IDCT (Discrete Cosine Transform), Haar transform, and autocorrelation support via next_fast_len.
- Time series utilities: Periodic repeat, periodic cumulative sum, periodic features for seasonality modeling.
- Block matrix operations: Block diagonal embedding and extraction.
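A key design feature is that 1x1 matrices (size-1 trailing dimensions) take special fast paths in matmul, matvecmul, triangular_solve, safe_cholesky, and cholesky_solve, where the general linear-algebra routine is replaced by cheap elementwise operations such as mul or sqrt.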
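A minimal sketch of why the 1x1 fast paths are valid, assuming the fast path triggers when the trailing (contracted) dimension has size 1. The `_sketch` functions below are hypothetical stand-ins written for illustration, not Pyro's source: each size-1 case reduces to an elementwise operation that agrees with the general PyTorch call.

```python
import torch

def matmul_sketch(x, y):
    # When the contracted dimension has size 1, x @ y is a broadcast
    # elementwise product (a scalar product for 1x1 matrices).
    if x.size(-1) == 1:
        return x * y
    return x @ y

def matvecmul_sketch(x, y):
    # x: (..., N, N), y: (..., N). For N == 1 this is a scalar multiply.
    if y.size(-1) == 1:
        return x.squeeze(-1) * y
    return (x @ y.unsqueeze(-1)).squeeze(-1)

def cholesky_sketch(x):
    # The Cholesky factor of a positive 1x1 matrix [[a]] is [[sqrt(a)]].
    if x.size(-1) == 1:
        return x.sqrt()
    return torch.linalg.cholesky(x)
```

Skipping the batched linear-algebra kernel for the size-1 case is presumably what makes these paths "fast"; the results are the same up to floating-point error.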
Code Reference
Linear Algebra
- safe_cholesky(x): Cholesky decomposition with adaptive jitter controlled by CHOLESKY_RELATIVE_JITTER (configurable via pyro.settings). For 1x1 matrices, returns sqrt(clamp(x)).
- cholesky_solve(x, y): Cholesky solve with a 1x1 fast path.
- matmul(x, y): Matrix multiply with a 1x1 fast path using element-wise mul.
- matvecmul(x, y): Matrix-vector multiply with a 1x1 fast path.
- triangular_solve(x, y, upper, transpose): Triangular solve with a 1x1 fast path. Handles transposition internally.
- precision_to_scale_tril(P): Converts a precision matrix to a lower-triangular scale matrix using a flipped Cholesky decomposition.
- safe_normalize(x, p): Safely normalizes a vector to the unit sphere, mapping zero vectors to [1, 0, ..., 0].
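The flipped Cholesky used by precision_to_scale_tril can be sketched as follows. This is an illustration under assumptions: precision_to_scale_tril_sketch is a hypothetical name, and the code uses public torch.linalg calls rather than Pyro's source. Factoring the row- and column-reversed precision matrix yields an upper-triangular U with P = U U^T, so the lower-triangular L = U^{-T} satisfies L L^T = P^{-1} without forming P^{-1} explicitly.

```python
import torch

def precision_to_scale_tril_sketch(P: torch.Tensor) -> torch.Tensor:
    # Reverse rows and columns: flip(P) = J P J, with J the reversal permutation.
    Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
    # U = J Lf J is upper triangular and satisfies P = U @ U^T.
    U = torch.flip(Lf, (-2, -1))
    # L = U^{-T} is lower triangular and L @ L^T = (U @ U^T)^{-1} = P^{-1}.
    eye = torch.eye(P.size(-1), dtype=P.dtype, device=P.device)
    return torch.linalg.solve_triangular(U.transpose(-2, -1), eye, upper=False)
```

A single triangular solve against the identity typically replaces a general inverse followed by a second Cholesky factorization.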
Signal Processing
- next_fast_len(size): Finds the next integer whose prime factors are only 2, 3, or 5 (efficient for FFT). Results are cached.
- convolve(signal, kernel, mode): FFT-based 1D convolution supporting 'full', 'valid', and 'same' modes.
- dct(x, dim): Orthonormal Type-II Discrete Cosine Transform, equivalent to scipy.fftpack.dct(norm="ortho").
- idct(x, dim): Inverse DCT.
- haar_transform(x): Recursive Haar wavelet transform along the final dimension.
- inverse_haar_transform(x): Inverse Haar transform.
- as_complex(x): Converts a real tensor to complex, handling stride alignment.
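The FFT convolution and next_fast_len entries can be sketched together. This sketch assumes numpy-style output lengths (m + n - 1 for 'full', max(m, n) - min(m, n) + 1 for 'valid', max(m, n) for 'same') and centers the 'same' slice; both function names are hypothetical and Pyro's exact slicing may differ. Zero-padding both inputs to at least m + n - 1 points turns the FFT's circular convolution into a linear one, and rounding that length up to a 2-3-5-smooth size keeps the FFT fast.

```python
import functools
import torch

@functools.lru_cache(maxsize=None)  # memoize, mirroring the caching described for next_fast_len
def next_fast_len_sketch(size: int) -> int:
    # Smallest n >= size whose only prime factors are 2, 3 and 5.
    assert size >= 1
    n = size
    while True:
        rem = n
        for p in (2, 3, 5):
            while rem % p == 0:
                rem //= p
        if rem == 1:
            return n
        n += 1

def fft_convolve_sketch(signal, kernel, mode="full"):
    m, n = signal.size(-1), kernel.size(-1)
    if mode == "full":
        out_len = m + n - 1
    elif mode == "valid":
        out_len = max(m, n) - min(m, n) + 1
    elif mode == "same":
        out_len = max(m, n)
    else:
        raise ValueError(f"unknown mode: {mode}")
    full_len = m + n - 1
    # Pad to a fast FFT length that is long enough to avoid circular wraparound.
    fft_len = next_fast_len_sketch(full_len)
    spectrum = torch.fft.rfft(signal, n=fft_len) * torch.fft.rfft(kernel, n=fft_len)
    full = torch.fft.irfft(spectrum, n=fft_len)[..., :full_len]
    # Take the centered slice of the requested length.
    start = (full_len - out_len) // 2
    return full[..., start:start + out_len]
```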
Time Series
- periodic_repeat(tensor, size, dim): Tiles a tensor periodically up to a given size along a dimension. Useful for static seasonality.
- periodic_cumsum(tensor, period, dim): Computes a periodic cumulative sum. Useful for drifting seasonality.
- periodic_features(duration, max_period, min_period): Creates sin/cos feature matrices for regression-based seasonality modeling.
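A minimal sketch of the two periodic helpers, restricted to the last dimension (the documented functions take a dim argument; the `_sketch` names are hypothetical). periodic_repeat tiles and truncates; periodic_cumsum pads time to a whole number of periods, reshapes it to (cycle, phase), and takes a cumulative sum over cycles, so each output is the running total of all earlier values at the same phase.

```python
import torch
import torch.nn.functional as F

def periodic_repeat_sketch(x: torch.Tensor, size: int) -> torch.Tensor:
    # Tile along the last dim until it covers `size`, then truncate.
    period = x.size(-1)
    repeats = (size + period - 1) // period
    return x.repeat(*([1] * (x.dim() - 1)), repeats)[..., :size]

def periodic_cumsum_sketch(x: torch.Tensor, period: int) -> torch.Tensor:
    # out[..., t] = x[..., t] + x[..., t - period] + x[..., t - 2 * period] + ...
    size = x.size(-1)
    repeats = (size + period - 1) // period
    x = F.pad(x, (0, repeats * period - size))  # zero-pad the last dim on the right
    # View time as (cycle, phase) and accumulate across cycles.
    out = x.reshape(*x.shape[:-1], repeats, period).cumsum(dim=-2)
    return out.reshape(*x.shape[:-1], repeats * period)[..., :size]
```

For example, periodic_cumsum_sketch(torch.arange(7.), 3) accumulates each phase separately, giving [0, 1, 2, 3, 5, 7, 9].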
Block Matrix Operations
- block_diag_embed(mat): Converts (..., B, M, N) to block diagonal (..., B*M, B*N).
- block_diagonal(mat, block_size): Extracts diagonal blocks from a block diagonal matrix.
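The block layout can be sketched with broadcasting: multiplying the blocks by an identity over the block index zeroes every off-diagonal block, and a reshape then flattens (block, row) and (block, column) pairs. Both functions are hypothetical illustrations; in particular, num_blocks is an assumed parameter name and may not match the meaning of Pyro's block_size.

```python
import torch

def block_diag_embed_sketch(mat: torch.Tensor) -> torch.Tensor:
    # (..., B, M, N) -> (..., B*M, B*N) with the B blocks on the diagonal.
    B, M, N = mat.shape[-3:]
    eye = torch.eye(B, dtype=mat.dtype, device=mat.device)
    # out[..., b, i, c, j] = mat[..., b, i, j] if b == c else 0
    out = mat.unsqueeze(-2) * eye.reshape(B, 1, B, 1)
    return out.reshape(*mat.shape[:-3], B * M, B * N)

def block_diagonal_sketch(mat: torch.Tensor, num_blocks: int) -> torch.Tensor:
    # (..., B*M, B*N) -> (..., B, M, N), keeping only the diagonal blocks.
    B = num_blocks
    M, N = mat.size(-2) // B, mat.size(-1) // B
    blocks = mat.reshape(*mat.shape[:-2], B, M, B, N)
    # diagonal() over the two block axes appends the block index last: (..., M, N, B).
    return blocks.diagonal(dim1=-4, dim2=-2).movedim(-1, -3)
```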
Other
- repeated_matmul(M, n): Computes [M, M^2, ..., M^n] using doubling with O(log n) parallel cost.
- broadcast_tensors_without_dim(tensors, dim): Broadcasts tensors without changing their size along a specified dimension.
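The doubling scheme behind repeated_matmul can be sketched as follows (hypothetical function name, not Pyro's code). If the stack already holds [M, ..., M^k], one batched matmul by M^k produces [M^(k+1), ..., M^(2k)], so about log2(n) sequential rounds suffice and the work inside each round is parallel.

```python
import torch

def repeated_matmul_sketch(M: torch.Tensor, n: int) -> torch.Tensor:
    # Returns a stack [M, M^2, ..., M^n] of shape (n, ..., N, N).
    assert n >= 1 and M.size(-1) == M.size(-2)
    powers = M.unsqueeze(0)  # holds [M^1, ..., M^k] with k = powers.size(0)
    while powers.size(0) < n:
        # Multiply every stored power by the highest one, M^k, in one batched
        # matmul: this appends [M^(k+1), ..., M^(2k)] and doubles k.
        top = powers[-1]
        powers = torch.cat([powers, torch.matmul(top, powers)], dim=0)
    return powers[:n]
```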
I/O Contract
| Function | Input | Output |
|---|---|---|
| safe_cholesky(x) | Tensor(..., N, N) (positive semi-definite) | Tensor(..., N, N) (lower triangular) |
| precision_to_scale_tril(P) | Tensor(..., N, N) (precision matrix) | Tensor(..., N, N) (lower triangular scale) |
| convolve(signal, kernel, mode) | Tensor(..., M), Tensor(..., N), mode string | Tensor(..., L) where L depends on mode |
| dct(x, dim) | Tensor, dimension int | Tensor (same shape, DCT coefficients) |
| next_fast_len(size) | Positive int | Int >= size with prime factors in {2, 3, 5} |
| periodic_features(duration, ...) | duration: int, optional periods | Tensor(duration, num_features) |
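To make the dct contract concrete, here is a sketch of the orthonormal Type-II DCT as an explicit N x N matrix, C[k, n] = s_k cos(pi (2n + 1) k / (2N)) with s_0 = sqrt(1/N) and s_k = sqrt(2/N) otherwise. The helper names are hypothetical and this O(N^2) version only illustrates the math; it is not Pyro's implementation. Because C is orthogonal, the inverse transform is its transpose, so the output keeps the input's shape and the dct/idct round trip is exact up to floating-point error.

```python
import math
import torch

def dct_matrix_sketch(N: int, dtype=torch.float64) -> torch.Tensor:
    # Rows index frequency k, columns index sample n.
    n = torch.arange(N, dtype=dtype)
    k = n.unsqueeze(-1)
    C = torch.cos(math.pi * (2 * n + 1) * k / (2 * N))
    C[0] *= math.sqrt(1.0 / N)   # s_0
    C[1:] *= math.sqrt(2.0 / N)  # s_k for k > 0
    return C

def dct_sketch(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    C = dct_matrix_sketch(x.size(dim), dtype=x.dtype)
    # Move `dim` last, apply C, move it back: the shape is unchanged.
    return torch.movedim(torch.movedim(x, dim, -1) @ C.T, -1, dim)

def idct_sketch(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    C = dct_matrix_sketch(x.size(dim), dtype=x.dtype)
    # Orthogonality means the inverse is multiplication by C^T.
    return torch.movedim(torch.movedim(x, dim, -1) @ C, -1, dim)
```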
Usage Examples
```python
import torch
from pyro.ops.tensor_utils import (
    safe_cholesky,
    convolve,
    dct,
    idct,
    periodic_features,
    next_fast_len,
)

# Safe Cholesky with adaptive jitter
P = torch.eye(3) + 0.1 * torch.randn(3, 3)
P = P @ P.T  # make positive definite
L = safe_cholesky(P)

# FFT-based convolution
signal = torch.randn(100)
kernel = torch.randn(10)
result = convolve(signal, kernel, mode="same")
print(result.shape)  # torch.Size([100])

# DCT round-trip
x = torch.randn(64)
assert torch.allclose(idct(dct(x)), x, atol=1e-5)

# Periodic features for yearly seasonality
features = periodic_features(365, max_period=365.25, min_period=7)
print(features.shape)  # torch.Size([365, 104]), i.e. 2 * ceil(365.25 / 7) - 2 features

# Find FFT-efficient size
print(next_fast_len(100))  # 100 (already efficient: 2^2 * 5^2)
```
Related Pages
- Pyro_ppl_Pyro_Gaussian -- Uses safe_cholesky, matmul, matvecmul, triangular_solve
- Pyro_ppl_Pyro_GammaGaussian -- Uses precision_to_scale_tril
- Pyro_ppl_Pyro_Stats -- Uses next_fast_len for efficient autocorrelation
- Pyro_ppl_Pyro_DCTAdam -- Uses dct and idct for frequency-domain optimization
- Pyro_ppl_Pyro_Settings -- Controls CHOLESKY_RELATIVE_JITTER