Implementation: Sgl_project_Sglang CPU GEMM FP8
| Knowledge Sources | |
|---|---|
| Domains | GEMM, Quantization, CPU Compute |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Implements FP8 (Float8_e4m3fn) weight-only quantized GEMM for CPU inference, where FP8 weights are dequantized to BFloat16 on-the-fly during matrix multiplication.
Description
gemm_fp8.cpp provides the CPU implementation of W8A16 (8-bit weight, 16-bit activation) quantized GEMM using the FP8 (Float8_e4m3fn) data format. This enables approximately 2x memory reduction compared to BFloat16 weights while maintaining BF16 compute precision.
The core unpack_B function dequantizes FP8 weights from VNNI-packed format to BFloat16 using AVX-512 intrinsics:
- Weights are stored in [K/2, N, 2] VNNI layout
- FP8 values are loaded as uint16 pairs via _mm512_loadu_si512
- Conversion pipeline: FP8 -> BF16 (via CVT_FP8_TO_BF16_EXT) -> FP32 (for scaling via CVT_BF16_TO_FP32) -> BF16 (for AMX compute via _mm512_cvtne2ps_pbh)
- Per-tensor scale is applied as a float multiplication: _mm512_mul_ps(value, scale * fp8_bias)
- Block size is fixed at 32, matching the AMX tile width
- A software prefetch distance of 64 iterations helps hide memory latency
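The conversion pipeline above can be illustrated with a scalar C++ sketch. This is not the kernel's AVX-512 path (which uses the CVT_* macros and 512-bit vectors); it is a per-element reference for what one FP8 -> FP32 -> scale -> BF16 step computes. All function names here are hypothetical.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Decode one Float8_e4m3fn byte to float: 1 sign bit, 4 exponent bits
// (bias 7), 3 mantissa bits; no infinities, 0x7F/0xFF encode NaN.
float fp8_e4m3fn_to_float(uint8_t v) {
  int sign = (v >> 7) & 1;
  int exp = (v >> 3) & 0xF;
  int man = v & 0x7;
  float s = sign ? -1.0f : 1.0f;
  if (exp == 0xF && man == 0x7) return NAN;             // e4m3fn NaN encoding
  if (exp == 0) return s * std::ldexp((float)man, -9);  // subnormal: man/8 * 2^-6
  return s * std::ldexp(1.0f + man / 8.0f, exp - 7);    // normal
}

// Round a float to BFloat16 (round-to-nearest-even), returned as the
// raw 16-bit payload -- a scalar stand-in for _mm512_cvtne2ps_pbh.
uint16_t float_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint32_t lsb = (bits >> 16) & 1;
  bits += 0x7FFF + lsb;  // round to nearest, ties to even
  return (uint16_t)(bits >> 16);
}

// One dequantization step: FP8 -> FP32, apply per-tensor scale, -> BF16.
uint16_t dequant_one(uint8_t fp8, float scale) {
  return float_to_bf16(fp8_e4m3fn_to_float(fp8) * scale);
}
```

The round-to-FP32 intermediate matters: the per-tensor scale is applied in full float precision before the final narrowing to BF16, matching the pipeline order described above.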
Additional helper functions:
- copy_stub: Converts float accumulator results to scalar_t output with SIMD vectorization
- copy_add_stub: Fused bias addition and type conversion in a single pass
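The two helpers reduce to the following scalar reference (a sketch only; the real kernels vectorize with at::vec, and the `_ref` names are hypothetical):

```cpp
#include <cstdint>

// Scalar reference for copy_stub: narrow the FP32 accumulator tile to
// the output dtype. scalar_t is assumed convertible from float
// (e.g. at::BFloat16, at::Half, or float itself).
template <typename scalar_t>
void copy_stub_ref(scalar_t* out, const float* input, int64_t size) {
  for (int64_t i = 0; i < size; ++i)
    out[i] = static_cast<scalar_t>(input[i]);
}

// Scalar reference for copy_add_stub: bias addition and dtype
// conversion fused into one pass over the row.
template <typename scalar_t>
void copy_add_stub_ref(scalar_t* out, const float* input,
                       const float* bias, int64_t size) {
  for (int64_t i = 0; i < size; ++i)
    out[i] = static_cast<scalar_t>(input[i] + bias[i]);
}
```

Fusing the bias add with the conversion avoids a second traversal of the accumulator buffer, which is the point of having a dedicated copy_add_stub.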
The dequantized BFloat16 output is then consumed by standard BF16 AMX brgemm kernels for the actual matrix multiplication.
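What the downstream brgemm computes over the VNNI-ordered Btmp can be sketched as a scalar reference GEMM. Floats stand in for BF16 here, and the ldb_tmp stride convention (row-pair stride of 2 * ldb_tmp) is an assumption for illustration; this is not the AMX code path.

```cpp
// Scalar reference: C[m, n] += A[m, k] * B[k, n], where B is read from
// a [K/2, N, 2] VNNI buffer. Element (k, n) of the logical [K, N]
// weight sits at (k/2) * 2 * ldb_tmp + n * 2 + (k % 2).
void ref_gemm_vnni(float* C, const float* A, const float* Btmp,
                   int M, int N, int K, int ldb_tmp) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k)
        acc += A[m * K + k] * Btmp[(k / 2) * 2 * ldb_tmp + n * 2 + (k % 2)];
      C[m * N + n] = acc;
    }
  }
}
```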
Usage
Use this GEMM variant when serving models with FP8 quantized weights on CPU. The FP8 weights must be pre-packed in VNNI format. This is the CPU equivalent of GPU FP8 quantization schemes.
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/cpu/gemm_fp8.cpp
- Lines: 1-555
Signature
// Dequantize FP8 weights from VNNI format to BFloat16
inline void unpack_B(
at::BFloat16* __restrict__ Btmp,
const at::Float8_e4m3fn* __restrict__ packed_B,
int N, int K,
int ldb, int ldb_tmp,
float scale);
// Float-to-scalar conversion with SIMD
template <typename scalar_t>
inline void copy_stub(
scalar_t* __restrict__ out,
const float* __restrict__ input,
int64_t size);
// Fused bias addition and type conversion
template <typename scalar_t>
inline void copy_add_stub(
scalar_t* __restrict__ out,
const float* __restrict__ input,
const float* __restrict__ bias,
int64_t size);
Import
#include "common.h"
#include "gemm.h"
#include "vec.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| packed_B | Float8_e4m3fn* | Yes | FP8 weights in VNNI-packed format [K/2, N, 2] |
| N | int | Yes | Number of output channels |
| K | int | Yes | Number of input channels |
| ldb | int | Yes | Leading dimension of packed FP8 weight buffer |
| ldb_tmp | int | Yes | Leading dimension of output BFloat16 buffer |
| scale | float | Yes | Per-tensor dequantization scale factor |
Outputs
| Name | Type | Description |
|---|---|---|
| Btmp | BFloat16* | Dequantized weights in BFloat16 VNNI format, ready for AMX brgemm |
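The [K/2, N, 2] VNNI layout pairs two consecutive K elements per column so a 32-bit lane holds both halves of a dot-product pair. A sketch of the packing index math (assuming a leading dimension of N for the packed buffer; helper names are hypothetical):

```cpp
#include <cstdint>
#include <vector>

// Offset of logical weight element (k, n) inside a [K/2, N, 2] VNNI
// buffer with leading dimension ldb (assumed equal to N here).
inline int64_t vnni_offset(int k, int n, int ldb) {
  return (int64_t)(k / 2) * 2 * ldb + (int64_t)n * 2 + (k % 2);
}

// Pack a row-major [K, N] byte buffer into VNNI order. K is assumed even.
std::vector<uint8_t> pack_vnni(const std::vector<uint8_t>& B, int K, int N) {
  std::vector<uint8_t> packed(B.size());
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n)
      packed[vnni_offset(k, n, /*ldb=*/N)] = B[(int64_t)k * N + n];
  return packed;
}
```

For example, a row-major [4, 2] buffer {0..7} packs to {0, 2, 1, 3, 4, 6, 5, 7}: each output pair interleaves elements from two adjacent K rows of the same column.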
Usage Examples
FP8 Weight Dequantization
// Dequantize FP8 VNNI-packed weights to BFloat16 for AMX compute
at::BFloat16 Btmp[BLOCK_N * K];
unpack_B(
Btmp, // output: BF16 weights
fp8_packed_weight, // input: FP8 weights in VNNI format
BLOCK_N, // N: output channels in this block
K, // K: input channels
ldb, // leading dim of FP8 buffer
ldb_tmp, // leading dim of BF16 buffer
weight_scale // per-tensor scale
);
// Btmp is now ready for BF16 AMX brgemm
Fused Bias Addition
// Convert float accumulator to BFloat16 with bias in a single pass
float accum[N];
float bias[N];
at::BFloat16 output[N];
copy_add_stub<at::BFloat16>(output, accum, bias, N);