Implementation: vLLM SGL-kernels MoE FP8
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Mixture of Experts |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements FP8-weight (w8a16) fused Mixture-of-Experts kernels for CPU inference, converting FP8 (E4M3) weights to BF16 on the fly during expert GEMM computation.
Description
This file extends the MoE computation pipeline to support FP8-quantized expert weights. The main kernel, fused_experts_fp8_kernel_impl, dequantizes FP8 expert weights to BF16 via cvt_e4m3_bf16_intrinsic before performing AMX-accelerated GEMM, significantly reducing memory footprint while preserving computational throughput. A companion kernel, shared_expert_fp8_kernel_impl, handles the shared-expert path with FP8 weights and fuses accumulation of the routed-expert output into the final result.
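As context for the dequantization step, the following is a scalar sketch of the E4M3-to-float decode that cvt_e4m3_bf16_intrinsic performs lane-wise; the function name e4m3_to_float and the scalar form are illustrative only, since the real kernel operates on AVX512 vectors and produces BF16 rather than float:

```cpp
#include <cmath>
#include <cstdint>

// Scalar reference decode of one FP8 E4M3FN byte: 1 sign bit, 4 exponent
// bits (bias 7), 3 mantissa bits. The "fn" variant has no infinities; only
// the all-ones magnitude pattern encodes NaN, so the largest finite value
// is 448. Illustrative sketch, not the vectorized kernel code.
float e4m3_to_float(uint8_t v) {
  int sign = (v >> 7) & 1;
  int exp  = (v >> 3) & 0xF;
  int man  = v & 0x7;
  float s = sign ? -1.0f : 1.0f;
  if (exp == 0xF && man == 0x7) return NAN;              // NaN pattern
  if (exp == 0) return s * std::ldexp((float)man, -9);   // subnormal: man * 2^-9
  return s * std::ldexp(1.0f + man / 8.0f, exp - 7);     // normal: (1 + man/8) * 2^(exp-7)
}
```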
Usage
This code is compiled as part of the vLLM CPU SGL-kernels extension. It is invoked when running Mixture-of-Experts models with FP8-quantized weights (use_fp8_w8a16=True) on CPU backends with AVX512 support.
Code Reference
Source Location
- Repository: vllm
- File: csrc/cpu/sgl-kernels/moe_fp8.cpp
- Lines: 1-502
Signature
template <typename scalar_t>
void fused_experts_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
scalar_t* __restrict__ A_tmp,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M, int64_t N, int64_t K,
int64_t E, int64_t topk,
int64_t num_tokens_post_pad);
template <typename scalar_t>
void shared_expert_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M, int64_t N, int64_t K);
Import
#include "common.h"
#include "gemm.h"
#include "vec.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| output | scalar_t* | Yes | Pre-allocated output buffer for expert results |
| input | const scalar_t* | Yes | Hidden states input tensor of shape [M, K] |
| packed_w1 | const Float8_e4m3fn* | Yes | FP8-quantized gate/up weights for all experts |
| packed_w2 | const Float8_e4m3fn* | Yes | FP8-quantized down projection weights for all experts |
| w1s | const float* | Yes | Per-block dequantization scales for w1 |
| w2s | const float* | Yes | Per-block dequantization scales for w2 |
| block_size_N | int64_t | Yes | Block size along N dimension for quantization |
| block_size_K | int64_t | Yes | Block size along K dimension for quantization |
| topk_weights | const float* | Yes | Gating weights for top-k expert selection |
| sorted_ids | const int32_t* | Yes | Token-to-expert mapping sorted by expert |
| expert_ids | const int32_t* | Yes | Expert IDs for each sorted token group |
| offsets | const int32_t* | Yes | Offsets into sorted_ids for each expert group |
| M | int64_t | Yes | Number of tokens |
| N | int64_t | Yes | Intermediate dimension (half of gate+up) |
| K | int64_t | Yes | Hidden dimension |
| E | int64_t | Yes | Number of experts |
| topk | int64_t | Yes | Number of experts selected per token |
| num_tokens_post_pad | int64_t | Yes | Total token count after padding expert token groups to block boundaries |
Outputs
| Name | Type | Description |
|---|---|---|
| output | scalar_t* | Weighted sum of expert outputs, shape [M, K] |
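The per-block scales w1s and w2s are indexed by the quantization block a given (n, k) weight element falls into. The helper below sketches that index math under the assumption that scales are stored row-major as [E, ceil(N_out / block_size_N), ceil(K / block_size_K)] per expert; the layout and the function name scale_for are assumptions for illustration, not confirmed by the source:

```cpp
#include <cstdint>

// Hypothetical lookup of the dequantization scale covering weight element
// (n, k) of expert e, assuming a row-major [E, nb, kb] scale layout where
// nb and kb are the number of quantization blocks along each dimension.
float scale_for(const float* scales, int64_t e, int64_t n, int64_t k,
                int64_t N_out, int64_t K, int64_t bN, int64_t bK) {
  int64_t nb = (N_out + bN - 1) / bN;  // scale blocks along the output dim
  int64_t kb = (K + bK - 1) / bK;      // scale blocks along the K dim
  return scales[(e * nb + n / bN) * kb + k / bK];
}
```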
Usage Examples
// Invoked internally through the fused_experts_cpu dispatch path
// with use_fp8_w8a16=true.
// Instantiated for BFloat16:
fused_experts_fp8_kernel_impl<at::BFloat16>(
output_ptr, ic0_ptr, ic1_ptr, ic2_ptr,
A_tmp_ptr, B_tmp_ptr, C_tmp_ptr,
input_ptr, packed_w1_ptr, packed_w2_ptr,
w1_scales_ptr, w2_scales_ptr,
block_size_N, block_size_K,
topk_weights_ptr, sorted_ids_ptr,
expert_ids_ptr, offsets_ptr,
M, N, K, E, topk, num_tokens_post_pad);
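For intuition on how topk_weights enters the result, here is a minimal sketch of the final combine a fused-MoE kernel performs per token: each token's output row is the weighted sum, by its top-k gating weights, of the down-projection outputs of its selected experts. Names and the separate expert_out buffer are illustrative; the actual kernel fuses this accumulation into the GEMM epilogue:

```cpp
#include <cstdint>

// Illustrative top-k combine: out[m, :] = sum_t topk_weights[m, t] *
// expert_out[m, t, :]. A naive reference, not the fused AMX path.
void combine_topk(float* out, const float* expert_out,
                  const float* topk_weights,
                  int64_t M, int64_t topk, int64_t K) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t k = 0; k < K; ++k) {
      float acc = 0.f;
      for (int64_t t = 0; t < topk; ++t)
        acc += topk_weights[m * topk + t] *
               expert_out[(m * topk + t) * K + k];
      out[m * K + k] = acc;
    }
  }
}
```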