
Implementation:Sgl project Sglang CPU GEMM FP8

From Leeroopedia


Knowledge Sources
Domains: GEMM, Quantization, CPU Compute
Last Updated: 2026-02-10 00:00 GMT

Overview

Implements FP8 (Float8_e4m3fn) weight-only quantized GEMM for CPU inference, where FP8 weights are dequantized to BFloat16 on-the-fly during matrix multiplication.

Description

gemm_fp8.cpp provides the CPU implementation of W8A16 (8-bit weight, 16-bit activation) quantized GEMM using the FP8 (Float8_e4m3fn) data format. This enables approximately 2x memory reduction compared to BFloat16 weights while maintaining BF16 compute precision.

The core unpack_B function dequantizes FP8 weights from VNNI-packed format to BFloat16 using AVX-512 intrinsics:

  • Weights are stored in [K/2, N, 2] VNNI layout
  • FP8 values are loaded as uint16 pairs via _mm512_loadu_si512
  • Conversion pipeline: FP8 -> BF16 (via CVT_FP8_TO_BF16_EXT) -> FP32 (for scaling via CVT_BF16_TO_FP32) -> BF16 (for AMX compute via _mm512_cvtne2ps_pbh)
  • Per-tensor scale is applied as float multiplication: _mm512_mul_ps(value, scale * fp8_bias)
  • Block size is fixed at 32 (matching AMX tile width)
  • Prefetch distance of 64 iterations hides memory latency

Additional helper functions:

  • copy_stub: Converts float accumulator results to scalar_t output with SIMD vectorization
  • copy_add_stub: Fused bias addition and type conversion in a single pass

The dequantized BFloat16 output is then consumed by standard BF16 AMX brgemm kernels for the actual matrix multiplication.

Usage

Use this GEMM variant when serving models with FP8 quantized weights on CPU. The FP8 weights must be pre-packed in VNNI format. It is the CPU counterpart of the weight-only FP8 quantization schemes used on GPUs.

Code Reference

Source Location

Signature

// Dequantize FP8 weights from VNNI format to BFloat16
inline void unpack_B(
    at::BFloat16* __restrict__ Btmp,
    const at::Float8_e4m3fn* __restrict__ packed_B,
    int N, int K,
    int ldb, int ldb_tmp,
    float scale);

// Float-to-scalar conversion with SIMD
template <typename scalar_t>
inline void copy_stub(
    scalar_t* __restrict__ out,
    const float* __restrict__ input,
    int64_t size);

// Fused bias addition and type conversion
template <typename scalar_t>
inline void copy_add_stub(
    scalar_t* __restrict__ out,
    const float* __restrict__ input,
    const float* __restrict__ bias,
    int64_t size);

Import

#include "common.h"
#include "gemm.h"
#include "vec.h"

I/O Contract

Inputs

Name Type Required Description
packed_B Float8_e4m3fn* Yes FP8 weights in VNNI-packed format [K/2, N, 2]
N int Yes Number of output channels
K int Yes Number of input channels
ldb int Yes Leading dimension of packed FP8 weight buffer
ldb_tmp int Yes Leading dimension of output BFloat16 buffer
scale float Yes Per-tensor dequantization scale factor

Outputs

Name Type Description
Btmp BFloat16* Dequantized weights in BFloat16 VNNI format, ready for AMX brgemm

Usage Examples

FP8 Weight Dequantization

// Dequantize FP8 VNNI-packed weights to BFloat16 for AMX compute
at::BFloat16 Btmp[BLOCK_N * K];
unpack_B(
    Btmp,                // output: BF16 weights
    fp8_packed_weight,   // input: FP8 weights in VNNI format
    BLOCK_N,             // N: output channels in this block
    K,                   // K: input channels
    ldb,                 // leading dim of FP8 buffer
    ldb_tmp,             // leading dim of BF16 buffer
    weight_scale         // per-tensor scale
);
// Btmp is now ready for BF16 AMX brgemm

Fused Bias Addition

// Convert float accumulator to BFloat16 with bias in a single pass
float accum[N];
float bias[N];
at::BFloat16 output[N];
copy_add_stub<at::BFloat16>(output, accum, bias, N);
