Overview
Portable SIMD abstraction layer providing unified AVX2/AVX512 intrinsics interface for CPU-optimized deep learning operations.
Description
This header provides a comprehensive SIMD abstraction layer that unifies AVX2 (256-bit) and AVX512 (512-bit) vector operations under a common interface. It includes macros for all essential operations (load, store, arithmetic, logical), type-safe template functions for mixed-precision operations (FP32, FP16, BFloat16), and specialized conversion routines for BFloat16 with proper rounding. The abstraction enables writing optimizer code once while automatically leveraging the best available SIMD instruction set (AVX512 > AVX256 > scalar), with compile-time selection based on architecture flags.
Usage
Include this header in CPU optimizer implementations to achieve portable SIMD acceleration across different x86-64 architectures with AVX support.
Code Reference
Source Location
Signature
// Core SIMD macros (example for AVX512)
// Unaligned FP32 vector store/load and scalar broadcast.
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
// Elementwise arithmetic; SIMD_FMA computes x * y + c as a single fused op.
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
// Bitwise AND/XOR applied to the raw bit patterns of the FP32 lanes.
#define SIMD_AND(x, y) _mm512_and_ps(x, y)
#define SIMD_XOR(x, y) _mm512_xor_ps(x, y)
#define SIMD_WIDTH 16 // AVX512: 16 floats, AVX256: 8 floats
// AVX_Data union for type conversion
// Wraps the architecture-specific vector register type so templated helpers
// can be written once; the active member is chosen at compile time by the
// architecture flag passed to the build.
union AVX_Data {
#if defined(__AVX512__)
__m512 data; // 512-bit register: 16 FP32 lanes
#elif defined(__AVX256__)
__m256 data; // 256-bit register: 8 FP32 lanes
#endif
};
// Template functions for mixed precision
// `span` is the number of AVX_Data vectors processed per call; `T` selects
// the element type (float / c10::Half / c10::BFloat16) and the matching
// convert-on-load / convert-on-store path.
// Load `span` vectors from src into dst, widening T to FP32 as needed.
template <int span, typename T>
void simd_load(AVX_Data* dst, T* src);
// Store `span` vectors from src into dst, narrowing FP32 to T as needed.
template <int span, typename T>
void simd_store(T* dst, AVX_Data* src);
// dst = src_m_l * src_m_r + src_a (fused multiply-add per vector).
// NOTE(review): src_m_r is taken by value here while the usage example below
// passes an AVX_Data array — presumably an overload taking AVX_Data* also
// exists; confirm against the full header.
template <int span>
void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a);
// dst = src_a_l + src_a_r, elementwise over `span` vectors.
template <int span>
void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r);
// dst = src_a_l * src_a_r, elementwise over `span` vectors.
template <int span>
void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r);
// dst = src_a_l / src_a_r, elementwise over `span` vectors.
template <int span>
void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r);
// dst = sqrt(src), elementwise over `span` vectors.
template <int span>
void simd_sqrt(AVX_Data* dst, AVX_Data* src);
// BFloat16 conversion (AVX512 only)
// Widens 16 packed bf16 values to a full FP32 vector.
#define SIMD_LOAD_BF16(x) load_16_bf16_as_f32(x)
// Narrows FP32 back to bf16; rounding mode is nearest per the helper's name.
#define SIMD_STORE_BF16(x, d) store_16_f32_as_bf16_nearest(d, x)
// FP16 conversion
// NOTE(review): the load below uses an UNALIGNED load (_mm256_loadu_ps) but
// the store uses an ALIGNED store (_mm256_store_ps), so `x` must be 32-byte
// aligned on the store path — confirm this asymmetry is intentional.
#define SIMD_LOAD_FP16(x) _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x)))
#define SIMD_STORE_FP16(x, d) _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
Import
I/O Contract
SIMD Width Constants

| Constant | AVX512 | AVX256 | Description |
| --- | --- | --- | --- |
| SIMD_WIDTH | 16 | 8 | Number of FP32 elements per vector |
| TILE | 128MB | 128MB | Processing tile size for cache optimization |
Template Functions

| Function | Parameters | Description |
| --- | --- | --- |
| simd_load<span, T> | AVX_Data* dst, T* src | Load span vectors from src to dst |
| simd_store<span, T> | T* dst, AVX_Data* src | Store span vectors from src to dst |
| simd_fma | dst, src_m_l, src_m_r, src_a | dst = src_m_l * src_m_r + src_a |
| simd_add | dst, src_a_l, src_a_r | dst = src_a_l + src_a_r |
| simd_mul | dst, src_a_l, src_a_r | dst = src_a_l * src_a_r |
| simd_div | dst, src_a_l, src_a_r | dst = src_a_l / src_a_r |
| simd_sqrt | dst, src | dst = sqrt(src) |
| simd_and | dst, src_a_l, src_a_r | dst = src_a_l & src_a_r |
| simd_xor | dst, src_a_l, src_a_r | dst = src_a_l ^ src_a_r |
Supported Types

| Type | AVX512 | AVX256 | Description |
| --- | --- | --- | --- |
| float | Yes | Yes | 32-bit floating point |
| c10::Half | Yes | Yes | 16-bit floating point (IEEE 754) |
| c10::BFloat16 | Yes | No | 16-bit brain floating point |
Span Factors

| Span | Vectors | FP32 Elements (AVX512) | FP32 Elements (AVX256) |
| --- | --- | --- | --- |
| 1 | 1 | 16 | 8 |
| 4 | 4 | 64 | 32 |
| 8 | 8 | 128 | 64 |
Usage Examples
#include "simd.h"
// Example: Vectorized FMA operation (c = a * b + c)
void vectorized_fma(float* a, float* b, float* c, size_t size) {
constexpr int span = 4;
size_t simd_size = ROUND_DOWN(size, SIMD_WIDTH * span);
for (size_t i = 0; i < simd_size; i += SIMD_WIDTH * span) {
AVX_Data a_vec[span];
AVX_Data b_vec[span];
AVX_Data c_vec[span];
// Load vectors
simd_load<span>(a_vec, a + i);
simd_load<span>(b_vec, b + i);
simd_load<span>(c_vec, c + i);
// Fused multiply-add: c = a * b + c
simd_fma<span>(c_vec, a_vec, b_vec, c_vec);
// Store result
simd_store<span>(c + i, c_vec);
}
// Handle remaining elements (scalar fallback)
for (size_t i = simd_size; i < size; i++) {
c[i] = a[i] * b[i] + c[i];
}
}
// Example: Mixed precision load/store (FP16 to FP32)
// Converts `size` half-precision values in `src` to single precision in `dst`.
void convert_fp16_to_fp32(c10::Half* src, float* dst, size_t size) {
    constexpr int span = 1;
    size_t simd_size = ROUND_DOWN(size, SIMD_WIDTH);
    for (size_t i = 0; i < simd_size; i += SIMD_WIDTH) {
        AVX_Data vec[span];
        simd_load<span>(vec, src + i);  // Automatically converts FP16->FP32
        simd_store<span>(dst + i, vec); // Store as FP32
    }
    // Scalar tail: without this, the last (size % SIMD_WIDTH) elements were
    // silently dropped — unlike the vectorized_fma example, which covers them.
    for (size_t i = simd_size; i < size; i++) {
        dst[i] = static_cast<float>(src[i]);
    }
}
Related Pages