Implementation:Sgl project Sglang CPU MoE INT4
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, CPU Kernels, Quantization |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Implements the INT4 weight-quantized (W4A8) variant of the fused Mixture-of-Experts kernel for CPU inference, pairing 4-bit quantized expert weights with dynamically quantized 8-bit activations.
Description
This file follows the fused MoE two-GEMM pattern with INT4 expert weights packed into uint8 storage (two 4-bit values per byte). The weights are unpacked and dequantized with per-group zero points and scales, then combined with dynamically quantized uint8 activations via the INT4 GEMM infrastructure from gemm_int4.cpp. After weight dequantization, the computation leverages AVX-512 VNNI int8 dot-product instructions.
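A minimal scalar sketch of the per-group dequantization described above (the function name and nibble ordering are assumptions for illustration; the real kernel performs this with AVX-512 intrinsics):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar reference for unpacking INT4 weights.
// Each byte holds two 4-bit values; a per-group zero point and
// scale map them back to float: w = (q - zp) * scale.
std::vector<float> dequant_int4_group(const uint8_t* packed, size_t n,
                                      int8_t zero_point, float scale) {
  std::vector<float> out(n);
  for (size_t i = 0; i < n; ++i) {
    uint8_t byte = packed[i / 2];
    // low nibble first, then high nibble (assumed packing order)
    int q = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    out[i] = (static_cast<float>(q) - static_cast<float>(zero_point)) * scale;
  }
  return out;
}
```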
Key helper functions provided:
- copy_stub -- two overloads: same-type copy and float-to-scalar conversion copy
- copy_mul_stub -- copies with topk weight scaling (float input to scalar output)
- sum_stub -- reduces [topk, K] to [K] by accumulating across top-k experts
- add_mul_stub -- computes out = input + input2 * scale for residual connections
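The reduction and residual helpers can be illustrated with scalar reference versions (names suffixed `_ref` are hypothetical; the actual stubs are vectorized and templated on scalar_t):

```cpp
#include <cstdint>

// Scalar sketch of sum_stub: reduce [topk, K] -> [K] by
// accumulating the per-expert rows for one token.
void sum_stub_ref(float* out, const float* in, int64_t topk, int64_t K) {
  for (int64_t k = 0; k < K; ++k) out[k] = 0.f;
  for (int64_t t = 0; t < topk; ++t)
    for (int64_t k = 0; k < K; ++k) out[k] += in[t * K + k];
}

// Scalar sketch of add_mul_stub: out = input + input2 * scale,
// used for residual connections with a weighted expert output.
void add_mul_stub_ref(float* out, const float* input, const float* input2,
                      float scale, int64_t K) {
  for (int64_t k = 0; k < K; ++k) out[k] = input[k] + input2[k] * scale;
}
```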
The kernel processes three stages: (1) align and sort tokens by expert assignment, (2) execute per-expert fused GEMM1-SiLU-GEMM2 with INT4 weight dequantization, and (3) accumulate weighted expert outputs. Template instantiations are provided for at::BFloat16 and at::Half.
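The activation step inside stage (2) can be sketched as follows. GEMM1 produces 2N columns per token, gate in [0, N) and up in [N, 2N), following the usual fused-MoE convention; the exact layout here is an assumption, and the real kernel fuses this with vectorized code:

```cpp
#include <cmath>
#include <cstdint>

// Scalar sketch of the SiLU-and-multiply step between GEMM1 and GEMM2:
// out[j] = SiLU(gate[j]) * up[j], where SiLU(x) = x / (1 + e^-x).
void silu_and_mul_ref(float* out, const float* ic, int64_t N) {
  for (int64_t j = 0; j < N; ++j) {
    float gate = ic[j];      // first half of the GEMM1 output row
    float up = ic[N + j];    // second half of the GEMM1 output row
    out[j] = (gate / (1.f + std::exp(-gate))) * up;
  }
}
```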
Usage
Use this kernel when deploying large MoE models on memory-constrained CPU systems. INT4 quantization provides approximately 4x memory reduction per expert weight compared to BFloat16, which is critical for fitting models like Mixtral-8x7B or DeepSeek-V3 in available memory. The W4A8 scheme balances compression ratio with inference quality.
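The memory-reduction claim can be checked with back-of-envelope arithmetic. The dimensions below are illustrative (roughly Mixtral-8x7B: K = 4096, N = 14336) and the metadata overhead assumes one float scale and one int8 zero point per quantization group:

```cpp
#include <cstdint>

// Per-expert element count: w1 is [2N, K] (gate+up), w2 is [K, N].
constexpr int64_t expert_elems(int64_t N, int64_t K) {
  return 2 * N * K + K * N;
}

// BFloat16 storage: 2 bytes per element.
constexpr int64_t bf16_bytes(int64_t N, int64_t K) {
  return expert_elems(N, K) * 2;
}

// INT4 storage: half a byte per weight, plus one float scale (4 B)
// and one int8 zero point (1 B) per quantization group.
constexpr int64_t int4_bytes(int64_t N, int64_t K, int64_t group_size) {
  const int64_t elems = expert_elems(N, K);
  return elems / 2 + (elems / group_size) * 5;
}
```

With N = 14336, K = 4096, and group_size = 128, this gives about 336 MiB per expert in BFloat16 versus about 91 MiB in INT4, a roughly 3.7x reduction once per-group metadata is counted.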
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/cpu/moe_int4.cpp
- Lines: 1-484
Signature
template <typename scalar_t>
void fused_experts_int4_w4a8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
uint8_t* __restrict__ A_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
int32_t* __restrict__ Azp_tmp,
float* __restrict__ C_tmp,
int8_t* __restrict__ dqB_tmp,
const scalar_t* __restrict__ input,
const uint8_t* __restrict__ packed_w1,
const uint8_t* __restrict__ packed_w2,
const int8_t* __restrict__ w1z,
const int8_t* __restrict__ w2z,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int group_size,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M, int64_t N, int64_t K,
int64_t E, int64_t topk,
int64_t num_tokens_post_pad);
Import
#include "common.h"
#include "gemm.h"
#include "vec.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input | scalar_t* [M, K] | Yes | Input hidden states (BFloat16 or Half) |
| packed_w1 | uint8_t* [E, 2N, K/2] | Yes | Gate+up projection weights in packed INT4 format (2 values per byte) |
| packed_w2 | uint8_t* [E, K, N/2] | Yes | Down projection weights in packed INT4 format |
| w1z | int8_t* | Yes | Per-group zero points for w1 dequantization |
| w2z | int8_t* | Yes | Per-group zero points for w2 dequantization |
| w1s | float* | Yes | Per-group dequantization scales for w1 |
| w2s | float* | Yes | Per-group dequantization scales for w2 |
| group_size | int | Yes | Number of elements per quantization group |
| topk_weights | float* [M, topk] | Yes | Routing weights for selected experts |
| sorted_ids | int32_t* | Yes | Token indices sorted by expert assignment |
| expert_ids | int32_t* | Yes | Expert assignment per sorted block |
| offsets | int32_t* | Yes | Starting offsets for each M block |
| M, N, K | int64_t | Yes | Matrix dimensions: tokens, intermediate size, hidden size |
| E | int64_t | Yes | Number of experts |
| topk | int64_t | Yes | Number of selected experts per token |
| num_tokens_post_pad | int64_t | Yes | Token count after padding sorted tokens to expert block boundaries |
Outputs
| Name | Type | Description |
|---|---|---|
| output | scalar_t* [M, K] | MoE output after INT4-dequantized expert computation and accumulation |
Usage Examples
// INT4 MoE is dispatched via the main fused_experts_cpu() entry point
at::Tensor output = fused_experts_cpu(
hidden_states, // [M, K] BFloat16
w1_int4, // [E, groups, 2, N/2, K] packed INT4
w2_int4, // [E, groups, 2, K/2, N] packed INT4
topk_weights, // [M, topk] float32
topk_ids, // [M, topk] int32
/*inplace=*/false,
/*moe_comp_method=*/CPUQuantMethod::INT4_W4A8,
w1_scale, w2_scale, // per-group scales
w1_zero, w2_zero, // per-group zero points
/*block_size=*/std::nullopt,
/*is_vnni=*/true);