Implementation: Sgl_project_Sglang CPU MoE FP8
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, CPU Kernels, Quantization |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Implements the FP8 (Float8_e4m3fn) weight-quantized variant of the fused Mixture-of-Experts kernel for CPU inference, combining FP8 weight dequantization with the standard two-GEMM MoE computation pipeline.
Description
This file follows the same fused MoE pattern as the BFloat16 variant but with FP8 expert weights (W8A16 quantization scheme). The expert weights are stored in at::Float8_e4m3fn format and dequantized to BFloat16 on-the-fly during the GEMM operations, reusing the FP8 dequantization infrastructure from gemm_fp8.cpp.
Key helper functions include:
- copy_stub -- vectorized memory copy with SIMD unrolling
- copy_mul_stub -- copies data while multiplying by a topk weight, with float-to-scalar conversion
- sum_stub -- accumulates across top-k experts for a single token, reducing [topk, K] to [K]
- add_mul_stub -- computes out = input + input2 * scale, useful for weighted expert output accumulation in the shared expert path
The kernel retains the same two-GEMM-with-fused-SiLU structure, applying per-block scales during the dequantization step according to the block_size_N and block_size_K parameters. Template instantiations are provided for both at::BFloat16 and at::Half scalar types.
Usage
Use this kernel for CPU inference of MoE models whose expert weights have been quantized to FP8. Since FP8 stores one byte per weight versus two for BFloat16, it roughly halves expert-weight memory, which is especially impactful for MoE models, where expert weights constitute the majority of model parameters.
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/cpu/moe_fp8.cpp
- Lines: 1-491
Signature
template <typename scalar_t>
void fused_experts_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M, int64_t N, int64_t K,
int64_t E, int64_t topk,
int64_t num_tokens_post_pad);
template <typename scalar_t>
void shared_expert_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M, int64_t N, int64_t K);
Import
#include "common.h"
#include "gemm.h"
#include "vec.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input | scalar_t* [M, K] | Yes | Input hidden states (BFloat16 or Half) |
| packed_w1 | Float8_e4m3fn* [E, 2N, K] | Yes | Gate+up projection weights in FP8 format |
| packed_w2 | Float8_e4m3fn* [E, K, N] | Yes | Down projection weights in FP8 format |
| w1s | float* | Yes | Per-block dequantization scales for w1 |
| w2s | float* | Yes | Per-block dequantization scales for w2 |
| block_size_N | int64_t | Yes | Block size along N dimension for block-wise quantization |
| block_size_K | int64_t | Yes | Block size along K dimension for block-wise quantization |
| topk_weights | float* [M, topk] | Yes | Routing weights for selected experts per token |
| sorted_ids | int32_t* | Yes | Token indices sorted by assigned expert |
| expert_ids | int32_t* | Yes | Expert assignment for each sorted block |
| offsets | int32_t* | Yes | Starting offsets for each M block in sorted order |
| M, N, K | int64_t | Yes | Matrix dimensions: tokens, intermediate size, hidden size |
| E | int64_t | Yes | Number of experts |
| topk | int64_t | Yes | Number of experts selected per token |
| num_tokens_post_pad | int64_t | Yes | Total number of sorted token slots after per-expert padding |
Outputs
| Name | Type | Description |
|---|---|---|
| output | scalar_t* [M, K] | MoE output after FP8-dequantized expert computation and weighted accumulation |
Usage Examples
// FP8 MoE is dispatched via the main fused_experts_cpu() entry point
// by setting moe_comp_method to the FP8 enum value and providing FP8 weights
at::Tensor output = fused_experts_cpu(
hidden_states, // [M, K] BFloat16
w1_fp8, // [E, 2N, K] Float8_e4m3fn
w2_fp8, // [E, K, N] Float8_e4m3fn
topk_weights, // [M, topk] float32
topk_ids, // [M, topk] int32
/*inplace=*/false,
/*moe_comp_method=*/CPUQuantMethod::FP8,
w1_scale, // per-block scales
w2_scale,
/*w1_zero=*/std::nullopt,
/*w2_zero=*/std::nullopt,
block_size, // {block_size_N, block_size_K}
/*is_vnni=*/true);