Implementation:Turboderp org Exllamav2 AVX Mathfun
| Knowledge Sources | |
|---|---|
| Domains | SIMD, Math, Performance_Optimization |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
AVX/AVX2 SIMD-vectorized implementations of transcendental math functions (log, exp, sin, cos, sincos) operating on 8 packed single-precision floats simultaneously.
Description
avx_mathfun.h provides high-performance vectorized approximations of common math functions using Intel AVX and AVX2 intrinsics. The implementations are based on the Cephes library polynomial approximations, ported to AVX by Giovanni Garberoglio and further refined for use in ExLlamaV2.
The key functions are:
- log256_ps -- Computes the natural logarithm for 8 floats in parallel using a polynomial expansion. Returns NaN for inputs <= 0.
- exp256_ps -- Computes the exponential function for 8 floats. Clamps inputs to [-88.38, 88.38] so the result stays within single-precision range, then uses the identity exp(x) = 2^n * exp(g), where n = round(x / ln 2) and g = x - n * ln 2, with a degree-5 polynomial approximation for exp(g).
- sin256_ps -- Computes sine for 8 floats using extended-precision Cephes-style range reduction followed by polynomial evaluation on [0, Pi/4].
- cos256_ps -- Computes cosine for 8 floats using the same technique as sin256_ps with adjusted sign and phase logic.
- sincos256_ps -- Computes both sine and cosine simultaneously, returning results via output pointers. Nearly as fast as computing either alone since the polynomial evaluations share intermediate values.
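The range-reduction identity behind exp256_ps can be illustrated with a scalar, one-lane sketch. This is not the header's code: the name exp_sketch is hypothetical, std::ldexp stands in for the integer bit manipulation a SIMD version would use to build 2^n, and the coefficients are the Cephes expf values as commonly published, so they should be checked against the header before being relied on.

```cpp
#include <algorithm>
#include <cmath>

// Hypothetical scalar sketch of exp(x) = 2^n * exp(g), one lane only.
inline float exp_sketch(float x) {
    // Clamp so the result stays representable in single precision.
    x = std::min(std::max(x, -88.3762626647949f), 88.3762626647949f);

    // n = round(x / ln 2), computed as floor(x * log2(e) + 0.5).
    const float n = std::floor(x * 1.44269504088896341f + 0.5f);

    // g = x - n * ln 2, with ln 2 split into C1 + C2 to limit rounding error.
    const float C1 = 0.693359375f;
    const float C2 = -2.12194440e-4f;
    float g = x - n * C1;
    g -= n * C2;

    // Polynomial part: exp(g) ~= 1 + g + g^2 * P(g) on the reduced interval.
    float p = 1.9875691500e-4f;
    p = p * g + 1.3981999507e-3f;
    p = p * g + 8.3334519073e-3f;
    p = p * g + 4.1665795894e-2f;
    p = p * g + 1.6666665459e-1f;
    p = p * g + 5.0000001201e-1f;
    const float exp_g = p * g * g + g + 1.0f;

    // Scale by 2^n.
    return std::ldexp(exp_g, static_cast<int>(n));
}
```

Because n is chosen by rounding, the reduced argument g stays within roughly half of ln 2 of zero, which is what lets a short polynomial reach single-precision accuracy.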
The file defines custom type aliases (v8sf for __m256, v8si for __m256i) and includes fallback paths that emulate AVX2 integer operations using SSE2 pairs when __AVX2__ is not defined. Constants are declared using alignment macros (ALIGN32_BEG / ALIGN32_END) that adapt between GCC and MSVC alignment syntax.
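The alignment-macro pattern can be pictured roughly as follows. This is a sketch of the idea rather than a copy of the header: the SKETCH_ prefix, the constant kOnes, and the helper is_aligned32 are hypothetical names chosen so they cannot clash with the real macros.

```cpp
#include <cstdint>

// MSVC places the alignment specifier before the declaration;
// GCC and Clang place the attribute after it.
#ifdef _MSC_VER
#  define SKETCH_ALIGN32_BEG __declspec(align(32))
#  define SKETCH_ALIGN32_END
#else
#  define SKETCH_ALIGN32_BEG
#  define SKETCH_ALIGN32_END __attribute__((aligned(32)))
#endif

// A 32-byte aligned table of eight floats, suitable for an aligned 256-bit load.
static const SKETCH_ALIGN32_BEG float kOnes[8] SKETCH_ALIGN32_END =
    {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};

// Returns true when the pointer sits on a 32-byte boundary.
inline bool is_aligned32(const void* p) {
    return reinterpret_cast<std::uintptr_t>(p) % 32 == 0;
}
```

Only one of the two macros expands to anything on a given compiler, so the same declaration line works under both toolchains.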
This code is licensed under the zlib license.
Usage
Use these functions when you need fast, vectorized transcendental math on CPU in the ExLlamaV2 sampling or softmax pathways. They are included as a header-only library and called from the AVX2 sampling code to compute log-probabilities and softmax normalization. Because each call processes 8 floats at once, the functions can be considerably faster than looping over scalar std::log/std::exp calls.
Code Reference
Source Location
- Repository: Turboderp_org_Exllamav2
- File: exllamav2/exllamav2_ext/cpp/avx_mathfun.h
- Lines: 1-858
Signature
// Type aliases
typedef __m256 v8sf; // vector of 8 float (AVX)
typedef __m256i v8si; // vector of 8 int (AVX)
// Natural logarithm for 8 floats; returns NaN for x <= 0
v8sf log256_ps(v8sf x);
// Exponential for 8 floats; clamps to [-88.38, 88.38]
__m256 exp256_ps(__m256 x);
// Sine for 8 floats (any x, best precision for |x| < 8192)
v8sf sin256_ps(v8sf x);
// Cosine for 8 floats (any x, best precision for |x| < 8192)
v8sf cos256_ps(v8sf x);
// Combined sine and cosine; writes results to *s and *c
void sincos256_ps(v8sf x, v8sf *s, v8sf *c);
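The precision note on sin256_ps and cos256_ps comes from the Cephes-style range reduction, which subtracts a multiple of Pi/4 in three pieces. A scalar, one-lane sketch of that scheme for sine is shown below. The name sin_sketch is hypothetical, branches replace the mask-and-blend logic a SIMD version would need, and the constants are the Cephes sinf values as commonly published, so they should be verified against the header.

```cpp
#include <cmath>

// Hypothetical one-lane sketch of Cephes-style sine.
inline float sin_sketch(float x) {
    // sin is odd: remember the sign and work with |x|.
    bool negate = x < 0.0f;
    x = std::fabs(x);

    // j = index of the Pi/4 interval, rounded up to an even number so the
    // reduced argument lands in [-Pi/4, Pi/4].
    int j = static_cast<int>(x * 1.27323954473516f);  // 4 / Pi
    j = (j + 1) & ~1;
    const float y = static_cast<float>(j);

    // Extended-precision reduction: subtract y * Pi/4 in three pieces.
    x = ((x - y * 0.78515625f) - y * 2.4187564849853515625e-4f)
        - y * 3.77489497744594108e-8f;

    // The second half of the period flips the sign.
    if (j & 4) negate = !negate;

    const float z = x * x;
    float r;
    if (j & 2) {
        // Cosine polynomial: 1 - z/2 + z^2 * Q(z)
        r = 2.443315711809948e-5f;
        r = r * z - 1.388731625493765e-3f;
        r = r * z + 4.166664568298827e-2f;
        r = r * z * z - 0.5f * z + 1.0f;
    } else {
        // Sine polynomial: x + x * z * P(z)
        r = -1.9515295891e-4f;
        r = r * z + 8.3321608736e-3f;
        r = r * z - 1.6666654611e-1f;
        r = r * z * x + x;
    }
    return negate ? -r : r;
}
```

A vectorized version cannot branch per lane, so it would typically evaluate both polynomials and select between them with a mask, which is also why a combined sincos costs little more than either function alone.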
Import
// Header-only; included directly in C++ source files
#include "avx_mathfun.h"
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| x | v8sf / __m256 | Packed vector of 8 single-precision floats |
| *s (sincos256_ps) | v8sf* | Output pointer for sine results |
| *c (sincos256_ps) | v8sf* | Output pointer for cosine results |
Outputs
| Function | Return Type | Description |
|---|---|---|
| log256_ps | v8sf | Natural logarithm of each element; NaN for non-positive inputs |
| exp256_ps | __m256 | Exponential of each element; clamped to prevent overflow |
| sin256_ps | v8sf | Sine of each element in radians |
| cos256_ps | v8sf | Cosine of each element in radians |
| sincos256_ps | void | Results written to output pointers *s (sine) and *c (cosine) |
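The NaN-for-non-positive contract of log256_ps can be illustrated with a scalar, one-lane sketch of a Cephes-style logarithm. The name log_sketch is hypothetical, std::frexp stands in for the exponent-bit extraction a SIMD version would perform, and the coefficients are the Cephes logf values as commonly published, so they should be checked against the header.

```cpp
#include <cmath>
#include <limits>

// Hypothetical one-lane sketch of Cephes-style natural logarithm.
inline float log_sketch(float x) {
    // Mirror the documented contract: non-positive inputs yield NaN.
    if (!(x > 0.0f)) return std::numeric_limits<float>::quiet_NaN();

    // Split x = m * 2^e with m in [0.5, 1).
    int e = 0;
    const float m = std::frexp(x, &e);

    // Recentre so the mantissa lies in [sqrt(1/2), sqrt(2)), then t = mantissa - 1.
    float t;
    if (m < 0.707106781186547524f) { e -= 1; t = m + m - 1.0f; }
    else                           { t = m - 1.0f; }

    // log(1 + t) ~= t - t^2/2 + t^3 * P(t)
    const float z = t * t;
    float p = 7.0376836292e-2f;
    p = p * t - 1.1514610310e-1f;
    p = p * t + 1.1676998740e-1f;
    p = p * t - 1.2420140846e-1f;
    p = p * t + 1.4249322787e-1f;
    p = p * t - 1.6668057665e-1f;
    p = p * t + 2.0000714765e-1f;
    p = p * t - 2.4999993993e-1f;
    p = p * t + 3.3333331174e-1f;
    p = p * t * z;

    // Add e * ln 2, with ln 2 split into a small and a large part.
    const float fe = static_cast<float>(e);
    p += fe * -2.12194440e-4f;
    p -= 0.5f * z;
    return t + p + fe * 0.693359375f;
}
```

Note that under this contract log of zero is NaN rather than negative infinity, so callers that may pass zero probabilities need to handle that case themselves.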
Usage Examples
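#include <immintrin.h>
#include "avx_mathfun.h"
// Log-softmax over 8 floats. Assumes logit_ptr points to 8 readable floats
// and max_logit holds their maximum (subtracted for numerical stability).
__m256 logits = _mm256_loadu_ps(logit_ptr);
__m256 shifted = _mm256_sub_ps(logits, _mm256_set1_ps(max_logit));
__m256 exp_vals = exp256_ps(shifted);
// Horizontal sum of the 8 exponentials via a temporary array
float tmp[8];
_mm256_storeu_ps(tmp, exp_vals);
float sum_of_exps = 0.0f;
for (int i = 0; i < 8; ++i) sum_of_exps += tmp[i];
// log_softmax(x_i) = (x_i - max) - log(sum_j exp(x_j - max))
__m256 log_sum = log256_ps(_mm256_set1_ps(sum_of_exps));
__m256 log_probs = _mm256_sub_ps(shifted, log_sum);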
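When wiring up a vectorized log-softmax like the one above, a scalar reference is useful for checking results lane by lane. The following is a minimal sketch using only the standard library; log_softmax_ref is a hypothetical helper and is not part of avx_mathfun.h or ExLlamaV2.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical scalar reference:
// log_softmax(x)_i = (x_i - max) - log(sum_j exp(x_j - max))
inline std::vector<float> log_softmax_ref(const std::vector<float>& logits) {
    std::vector<float> out(logits.size());
    if (logits.empty()) return out;

    // Subtracting the maximum keeps every exponent argument <= 0.
    const float max_logit = *std::max_element(logits.begin(), logits.end());

    float sum = 0.0f;
    for (float v : logits) sum += std::exp(v - max_logit);
    const float log_sum = std::log(sum);

    for (std::size_t i = 0; i < logits.size(); ++i)
        out[i] = (logits[i] - max_logit) - log_sum;
    return out;
}
```

Comparing the eight lanes of log_probs against this reference with a small absolute tolerance is a reasonable sanity check, since the vectorized functions are polynomial approximations rather than bit-exact replacements for std::exp and std::log.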