Implementation: Sgl_project_Sglang CPU GEMM INT4
| Knowledge Sources | |
|---|---|
| Domains | GEMM, Quantization, CPU Compute |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Implements INT4 weight-only quantized GEMM (W4A8) for CPU inference: 4-bit weights are dequantized to int8 and multiplied with dynamically quantized 8-bit activations using Intel AVX-512 VNNI int8 dot-product instructions.
Description
gemm_int4.cpp provides the CPU implementation of W4A8 (4-bit weight, 8-bit activation) quantized GEMM, enabling approximately 4x memory compression compared to BFloat16 weights. This is critical for running large models (70B+ parameters) on CPU-only deployments.
The implementation uses a two-stage approach:
Stage 1: INT4 Weight Dequantization
The _dequant_weight_zp_only function unpacks INT4 weights from packed uint8 storage (two int4 values per byte) and subtracts zero points:
- load_uint4_as_int8: Extracts high and low nibbles from packed uint8 using _mm256_srli_epi16 and _mm256_and_si256 with mask 0x0f
- load_zps_4vnni: Broadcasts and shuffles zero points into VNNI-compatible layout using _mm256_set1_epi64x and _mm256_shuffle_epi8
- Zero-point subtraction via _mm256_sub_epi8
Stage 2: INT8 GEMM with AVX-512 VNNI
The tinygemm_kernel_nn struct provides an AVX-512 VNNI-accelerated micro-kernel that uses _mm512_dpbusd_epi32 for uint8 × int8 dot products with int32 accumulation. Block sizes are BLOCK_M = 128 and BLOCK_N = block_size_n().
Output Rescaling
The _dequant_and_store function handles the output dequantization pipeline:
- Subtracts zero-point compensation (comp_b)
- Applies activation scale (scale_a) and weight scale (scale_b)
- Supports both symmetric and asymmetric activation quantization via the sym_quant_act template parameter
Usage
Use this GEMM variant when serving models with INT4-quantized weights (e.g., GPTQ, AWQ) on CPU. Activations are dynamically quantized at runtime to uint8 (asymmetric) or int8 (symmetric).
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/cpu/gemm_int4.cpp
- Lines: 1-811
Signature
// Activation type selection based on quantization scheme
template <bool sym_quant_act>
struct ActDtype;
template <> struct ActDtype<true> { using type = int8_t; };
template <> struct ActDtype<false> { using type = uint8_t; };
// Load and unpack INT4 weights from packed uint8
inline std::array<m256i_wrapper, 2> load_uint4_as_int8(
const uint8_t* __restrict__ qB);
// Load and broadcast zero points for VNNI layout
inline std::array<m256i_wrapper, 2> load_zps_4vnni(
const int8_t* __restrict__ zps);
// Dequantize INT4 weights with zero-point subtraction
template <int64_t N, int64_t ldb>
void _dequant_weight_zp_only(
const uint8_t* __restrict__ B,
int8_t* dqB,
const int8_t* __restrict__ qzeros,
int64_t K);
// Output rescaling after int8 GEMM
template <bool accum, int64_t N, bool sym_quant_act>
void _dequant_and_store(
float* __restrict__ output,
const int32_t* __restrict__ input,
const float* __restrict__ scale_a,
const int32_t* __restrict__ zp_a,
const float* __restrict__ scale_b,
const int32_t* __restrict__ comp_b,
int M, int ldi, int ldo, int ldsa = 1);
// Core INT4 GEMM micro-kernel
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs,
const int32_t* __restrict__ Bcomp,
const float* __restrict__ bias,
int64_t K, int64_t lda, int64_t ldb, int64_t ldc);
};
Import
#include <torch/all.h>
#include "gemm.h"
#include "vec.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| A | uint8_t* | Yes | Quantized activation matrix (uint8 for asymmetric, int8 for symmetric) |
| B | uint8_t* | Yes | Packed INT4 weight matrix (two int4 values per byte), shape [K, N/2] |
| qzeros | int8_t* | Yes | Per-group zero points for INT4 weights |
| scale_a | float* | Yes | Per-token activation quantization scales |
| scale_b | float* | Yes | Per-channel weight quantization scales |
| zp_a | int32_t* | Conditional | Activation zero points (only for asymmetric quantization) |
| comp_b | int32_t* | Yes | Weight compensation values for zero-point correction |
| bias | float* | No | Optional bias vector |
| K | int64_t | Yes | Shared dimension (input channels) |
| M | int | Yes | Number of rows in activation (tokens) |
| N | int64_t | Yes | Number of columns in weight (output channels) |
Outputs
| Name | Type | Description |
|---|---|---|
| C | scalar_t* | Output matrix in BFloat16, shape [M, N] |
| output | float* | Rescaled float output from _dequant_and_store |
Usage Examples
INT4 Weight Dequantization
// Dequantize INT4 weights with zero-point subtraction
std::vector<int8_t> dequantized_weight(K * N);
_dequant_weight_zp_only</*N=*/32, /*ldb=*/16>(
    packed_int4_weight,         // B: packed uint8, shape [K, N/2]
    dequantized_weight.data(),  // dqB: output int8, shape [K, N]
    zero_points,                // qzeros: per-group zero points
    K                           // K: input channels
);
INT4 GEMM with Rescaling
// Perform W4A8 GEMM
tinygemm_kernel_nn<at::BFloat16, /*has_bias=*/false, /*BLOCK_M=*/4, /*BLOCK_N=*/32>::apply(
quant_activations, dequant_weights, output,
act_scales, weight_scales, weight_comp, nullptr,
K, lda, ldb, ldc);