Implementation: vLLM CPU WNA16
| Knowledge Sources | |
|---|---|
| Domains | CPU_Inference, Quantization, GEMM |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements 4-bit weight-only quantized GEMM (WNA16) for GPTQ and AWQ formats on CPU, with LUT-based dequantization and ISA-specific micro-GEMM dispatch.
Description
This file provides a CPU-optimized weight-N-bit, activation-16-bit (WNA16) matrix multiplication kernel that dequantizes 4-bit packed integer weights on the fly during GEMM computation. The Dequantizer4b template class extracts 4-bit values from int32 packed words using a lookup-table (LUT) approach with FP32Vec16, applying per-group scales and optional zero-points to support both GPTQ (symmetric) and AWQ (asymmetric) quantization formats. The kernel also supports GPTQ's descending activation order (desc_act) via the g_idx mapping, and dispatches to ISA-specific micro-GEMM backends (AMX or VEC) with cache-aware N-dimension tiling and OpenMP parallelism.
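The core dequantization step can be illustrated with a minimal scalar sketch. This is not the vectorized FP32Vec16/LUT path from the file; the function names (`unpack4`, `dequant_word`) and the explicit loop are illustrative assumptions. It shows only the arithmetic: each int32 word packs 32/4 = 8 nibbles, and each nibble becomes `(q - zero_point) * scale` for its group (GPTQ uses an implicit symmetric zero-point of 8; AWQ supplies per-group zero-points).

```cpp
#include <cstdint>

// One packed int32 word holds 32 / 4 = 8 quantized 4-bit values.
constexpr int kPackNum = 32 / 4;

// Extract nibble i (0..7) from a packed int32 word.
inline int unpack4(uint32_t word, int i) {
  return static_cast<int>((word >> (4 * i)) & 0xF);
}

// Dequantize one packed word into 8 floats for a given group:
//   w = (q - zero_point) * scale
// GPTQ (symmetric): zero_point is the implicit midpoint 8.
// AWQ (asymmetric): zero_point comes from the per-group zeros tensor.
inline void dequant_word(uint32_t word, float scale, int zero_point,
                         float* out) {
  for (int i = 0; i < kPackNum; ++i) {
    out[i] = static_cast<float>(unpack4(word, i) - zero_point) * scale;
  }
}
```

The real kernel performs the same math 16 lanes at a time via a LUT of precomputed nibble values, which avoids per-element shifts and masks in the inner loop.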
Usage
This code is compiled as part of the vLLM CPU backend and is invoked when running quantized models (GPTQ/AWQ with 4-bit weights) on CPU. It is called via the cpu_gemm_wna16 torch extension function from the Python quantization layer.
Code Reference
Source Location
- Repository: vllm
- File: csrc/cpu/cpu_wna16.cpp
- Lines: 1-402
Signature
void cpu_gemm_wna16(
const torch::Tensor& input, // [M, K]
const torch::Tensor& q_weight, // [N / 16, K * 16 / pack_factor], packed as int32
torch::Tensor& output, // [M, N]
const torch::Tensor& scales, // [group_num, N]
const std::optional<torch::Tensor>& zeros, // [group_num, N / pack_factor]
const std::optional<torch::Tensor>& g_idx, // [K]
const std::optional<torch::Tensor>& bias, // [N]
const int64_t pack_factor,
const std::string& isa_hint);
template <typename scalar_t, ISA isa, bool has_zp, bool use_desc_act>
class Dequantizer4b {
public:
constexpr static int32_t pack_num = 32 / 4;
static void dequant(int32_t* q_weight, scalar_t* weight,
scalar_t* scales, int32_t* zeros,
int32_t* g_idx, const int64_t scales_stride,
const int64_t zeros_stride, const int32_t k_size,
const int32_t group_size);
};
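The `use_desc_act` template parameter changes only how a weight row is mapped to its quantization group. A hypothetical sketch of that lookup (the helper name `group_of` and the example mapping are illustrative, not from the file): without desc_act, groups are contiguous blocks of `group_size` rows; with GPTQ desc_act, `g_idx` supplies an arbitrary per-row group index, because rows were reordered by descending activation magnitude at quantization time.

```cpp
#include <cstdint>

// Hypothetical helper: pick the quantization group for weight row k.
// g_idx == nullptr models the contiguous (no desc_act) case.
inline int32_t group_of(int32_t k, int32_t group_size,
                        const int32_t* g_idx) {
  return g_idx ? g_idx[k] : k / group_size;
}

// Illustrative desc_act mapping for 4 rows and 2 groups.
constexpr int32_t kExampleGidx[4] = {1, 1, 0, 0};
```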
Import
#include "cpu/cpu_types.hpp"
#include "cpu/utils.hpp"
#include "cpu/micro_gemm/cpu_micro_gemm_vec.hpp"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input | torch::Tensor | Yes | Activation matrix of shape [M, K], BF16 or FP16 |
| q_weight | torch::Tensor | Yes | 4-bit packed weight matrix of shape [N/16, K*16/pack_factor], stored as int32 |
| output | torch::Tensor | Yes | Output matrix of shape [M, N] to write results into |
| scales | torch::Tensor | Yes | Per-group quantization scales of shape [group_num, N] |
| zeros | std::optional<torch::Tensor> | No | Per-group zero-points for AWQ format, shape [group_num, N/pack_factor] |
| g_idx | std::optional<torch::Tensor> | No | Group index mapping for GPTQ descending activation order, shape [K] |
| bias | std::optional<torch::Tensor> | No | Optional bias vector of shape [N] |
| pack_factor | int64_t | Yes | Packing factor (must be 8 for 4-bit quantization) |
| isa_hint | std::string | Yes | ISA dispatch hint: "amx" or "vec" |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch::Tensor | In-place result of the quantized GEMM computation, shape [M, N] |
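The shape relationships in the tables above can be restated as a small sketch. The struct and function names here are assumptions for illustration; the arithmetic follows directly from the contract: for 4-bit weights `pack_factor = 32 / 4 = 8`, `q_weight` has `N / 16` rows of `K * 16 / pack_factor` int32 columns, and `scales` has `K / group_size` rows.

```cpp
#include <cstdint>

// Hypothetical helper reproducing the I/O-contract shape arithmetic.
struct Wna16Shapes {
  int64_t qweight_rows;  // N / 16
  int64_t qweight_cols;  // K * 16 / pack_factor
  int64_t group_num;     // K / group_size
};

inline Wna16Shapes wna16_shapes(int64_t N, int64_t K, int64_t group_size,
                                int64_t pack_factor = 8) {
  return {N / 16, K * 16 / pack_factor, K / group_size};
}
```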
Usage Examples
// Run 4-bit quantized GEMM on CPU (GPTQ format, no zero-points)
cpu_gemm_wna16(
input, // [M, K] BF16
q_weight, // [N/16, K*2] packed int32
output, // [M, N] BF16
scales, // [group_num, N]
std::nullopt, // no zero-points (GPTQ symmetric)
std::nullopt, // no g_idx
bias, // [N] optional bias
8, // pack_factor for 4-bit
"amx" // use AMX ISA
);
// Run with AWQ format (has zero-points)
cpu_gemm_wna16(
input, q_weight, output, scales,
zeros, // [group_num, N/8] AWQ zero-points
std::nullopt, // no g_idx for AWQ
std::nullopt, // no bias
8, "vec"
);