
Implementation:Sgl project Sglang CPU MoE INT4

From Leeroopedia


Knowledge Sources
Domains Machine Learning, CPU Kernels, Quantization
Last Updated 2026-02-10 00:00 GMT

Overview

Implements the INT4 weight-quantized (W4A8) variant of the fused Mixture-of-Experts kernel for CPU inference, combining 4-bit weight dequantization with dynamically quantized 8-bit activations.

Description

This file follows the fused MoE two-GEMM pattern with INT4 expert weights packed as uint8 storage (2 values per byte). The INT4 weights are unpacked, dequantized with per-group zero points and scales, and used with dynamically quantized uint8 activations via the INT4 GEMM infrastructure from gemm_int4.cpp. The computation leverages AVX-512 VNNI int8 dot product instructions after weight dequantization.
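The unpack-and-dequantize step described above can be sketched in scalar C++ as follows. This is an illustrative reference only: the function name, the low-nibble-first packing order, and the flat group layout are assumptions, not the kernel's actual (vectorized, VNNI-oriented) implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical scalar sketch of per-group INT4 dequantization: each uint8
// holds two 4-bit weights; each group of `group_size` consecutive elements
// shares one int8 zero point and one float scale.
// Packing order (low nibble = even index) is an assumption.
std::vector<float> dequant_int4_group(const uint8_t* packed, int64_t n,
                                      const int8_t* zps, const float* scales,
                                      int group_size) {
  std::vector<float> out(n);
  for (int64_t i = 0; i < n; ++i) {
    uint8_t byte = packed[i / 2];
    // even index -> low nibble, odd index -> high nibble (assumed order)
    int q = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    int64_t g = i / group_size;
    out[i] = (static_cast<float>(q) - static_cast<float>(zps[g])) * scales[g];
  }
  return out;
}
```

In the real kernel the dequantized weights land in the `dqB_tmp` scratch buffer as int8 so the subsequent GEMM can use VNNI int8 dot products; the float conversion here is only for clarity.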

Key helper functions provided:

  • copy_stub -- two overloads: same-type copy and float-to-scalar conversion copy
  • copy_mul_stub -- copies with topk weight scaling (float input to scalar output)
  • sum_stub -- reduces [topk, K] to [K] by accumulating across top-k experts
  • add_mul_stub -- computes out = input + input2 * scale for residual connections
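Scalar reference versions of two of the stubs above clarify their semantics. The real helpers are vectorized and templated on `scalar_t`; the `_ref` suffixed functions below are illustrative sketches, not the kernel's code.

```cpp
#include <cassert>
#include <cstdint>

// sum_stub semantics: reduce [topk, K] -> [K] by accumulating across
// the top-k expert axis.
void sum_stub_ref(float* out, const float* in, int64_t topk, int64_t K) {
  for (int64_t k = 0; k < K; ++k) {
    float acc = 0.f;
    for (int64_t t = 0; t < topk; ++t) acc += in[t * K + k];
    out[k] = acc;
  }
}

// add_mul_stub semantics: out = input + input2 * scale, as used for
// residual connections.
void add_mul_stub_ref(float* out, const float* input, const float* input2,
                      float scale, int64_t K) {
  for (int64_t k = 0; k < K; ++k) out[k] = input[k] + input2[k] * scale;
}
```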

The kernel processes three stages: (1) align and sort tokens by expert assignment, (2) execute per-expert fused GEMM1-SiLU-GEMM2 with INT4 weight dequantization, and (3) accumulate weighted expert outputs. Template instantiations are provided for at::BFloat16 and at::Half.
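The gating applied between GEMM1 and GEMM2 in stage (2) is SiLU (x * sigmoid(x)) over the gate half of the gate+up projection, multiplied elementwise by the up half. A scalar sketch, assuming the common layout where the first N values of the GEMM1 output are the gate and the next N are the up projection:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// SiLU activation: x * sigmoid(x).
float silu(float x) { return x / (1.0f + std::exp(-x)); }

// Fused SiLU-and-mul between GEMM1 and GEMM2. `gate_up` holds [2N]:
// first N gate values, then N up values (assumed layout).
void silu_and_mul_ref(float* out, const float* gate_up, int64_t N) {
  for (int64_t n = 0; n < N; ++n)
    out[n] = silu(gate_up[n]) * gate_up[N + n];
}
```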

Usage

Use this kernel when deploying large MoE models on memory-constrained CPU systems. INT4 quantization provides approximately 4x memory reduction per expert weight compared to BFloat16, which is critical for fitting models like Mixtral-8x7B or DeepSeek-V3 in available memory. The W4A8 scheme balances compression ratio with inference quality.
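The ~4x figure follows directly from the bit widths: 4-bit weights versus 16-bit BFloat16, ignoring the per-group scale and zero-point overhead. A back-of-the-envelope helper (illustrative only):

```cpp
#include <cassert>
#include <cstdint>

// Weight storage in bytes for a given parameter count and bit width.
// Per-group scale/zero-point overhead is ignored for simplicity.
int64_t weight_bytes(int64_t params, int bits_per_weight) {
  return params * bits_per_weight / 8;
}
```

For any parameter count, `weight_bytes(p, 16)` is exactly four times `weight_bytes(p, 4)`, which is what makes INT4 packing decisive for fitting many-expert models in CPU memory.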

Code Reference

Source Location

Signature

template <typename scalar_t>
void fused_experts_int4_w4a8_kernel_impl(
    scalar_t* __restrict__ output,
    scalar_t* __restrict__ ic0,
    scalar_t* __restrict__ ic1,
    scalar_t* __restrict__ ic2,
    uint8_t* __restrict__ A_tmp,
    uint8_t* __restrict__ Aq_tmp,
    float* __restrict__ As_tmp,
    int32_t* __restrict__ Azp_tmp,
    float* __restrict__ C_tmp,
    int8_t* __restrict__ dqB_tmp,
    const scalar_t* __restrict__ input,
    const uint8_t* __restrict__ packed_w1,
    const uint8_t* __restrict__ packed_w2,
    const int8_t* __restrict__ w1z,
    const int8_t* __restrict__ w2z,
    const float* __restrict__ w1s,
    const float* __restrict__ w2s,
    int group_size,
    const float* __restrict__ topk_weights,
    const int32_t* __restrict__ sorted_ids,
    const int32_t* __restrict__ expert_ids,
    const int32_t* __restrict__ offsets,
    int64_t M, int64_t N, int64_t K,
    int64_t E, int64_t topk,
    int64_t num_tokens_post_pad);

Import

#include "common.h"
#include "gemm.h"
#include "vec.h"

I/O Contract

Inputs

Name Type Required Description
input scalar_t* [M, K] Yes Input hidden states (BFloat16 or Half)
packed_w1 uint8_t* [E, 2N, K/2] Yes Gate+up projection weights in packed INT4 format (2 values per byte)
packed_w2 uint8_t* [E, K, N/2] Yes Down projection weights in packed INT4 format
w1z int8_t* Yes Per-group zero points for w1 dequantization
w2z int8_t* Yes Per-group zero points for w2 dequantization
w1s float* Yes Per-group dequantization scales for w1
w2s float* Yes Per-group dequantization scales for w2
group_size int Yes Number of elements per quantization group
topk_weights float* [M, topk] Yes Routing weights for selected experts
sorted_ids int32_t* Yes Token indices sorted by expert assignment
expert_ids int32_t* Yes Expert assignment per sorted block
offsets int32_t* Yes Starting offsets for each M block
M, N, K int64_t Yes Matrix dimensions: token count, intermediate size, hidden size
E int64_t Yes Number of experts
topk int64_t Yes Number of selected experts per token
num_tokens_post_pad int64_t Yes Token count after padding each expert's block of sorted tokens

Outputs

Name Type Description
output scalar_t* [M, K] MoE output after INT4-dequantized expert computation and accumulation

Usage Examples

// INT4 MoE is dispatched via the main fused_experts_cpu() entry point
at::Tensor output = fused_experts_cpu(
    hidden_states,              // [M, K] BFloat16
    w1_int4,                    // [E, groups, 2, N/2, K] packed INT4
    w2_int4,                    // [E, groups, 2, K/2, N] packed INT4
    topk_weights,               // [M, topk] float32
    topk_ids,                   // [M, topk] int32
    /*inplace=*/false,
    /*moe_comp_method=*/CPUQuantMethod::INT4_W4A8,
    w1_scale, w2_scale,         // per-group scales
    w1_zero, w2_zero,           // per-group zero points
    /*block_size=*/std::nullopt,
    /*is_vnni=*/true);
