Implementation: Sgl_project_Sglang CPU QKV Projection
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, CPU Kernels |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Implements a fused QKV projection kernel with weight absorption and RoPE (Rotary Position Embedding) for DeepSeek-style MLA (Multi-head Latent Attention) architectures on CPU.
Description
This kernel fuses multiple operations that would otherwise require separate GEMM calls and memory passes for the DeepSeek MLA attention mechanism:
- q_a_proj and kv_a_proj_with_mqa are fused into a single segmented GEMM via segment_gemm_kernel_impl, which computes [C0, C1] = A @ [B0, B1] where B0 and B1 are different weight matrices. The segmented GEMM parallelizes across [MB, NB0 + NB1] blocks, dispatching tiles to either B0 or B1 based on the block index.
- Subsequent q_a_layernorm and kv_a_layernorm are fused into a single parallel loop via rms_norm_kernel_impl.
- rotary_emb_kernel_impl applies rotary position embedding to the key PE component.
Three variants of segment_gemm_kernel_impl are provided:
- BFloat16/Half using tinygemm_kernel with optional brgemm for large M.
- INT8 W8A8 using int8_scaled_mm_with_quant infrastructure with separate activation and weight scales.
- FP8 W8A16 using fp8_scaled_mm_cpu with per-block scales.
The public API function qkv_proj_with_rope accepts hidden states, all projection and normalization weights, position indices, and cos/sin cache, returning a tuple of (query, key, value) tensors. It expects weights to be pre-packed in VNNI format.
Usage
Use this kernel for DeepSeek-V2/V3 model inference on CPU. MLA uses compressed latent representations for keys and values, requiring projection through absorption matrices before attention. By fusing the Q and KV projections into a single segmented GEMM, the kernel streams the shared input activations once instead of twice, significantly reducing memory bandwidth usage.
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/cpu/qkv_proj.cpp
- Lines: 1-701
Signature
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
at::Tensor& hidden_states,
at::Tensor& q_a_proj_weight,
at::Tensor& q_b_proj_weight,
at::Tensor& kv_a_proj_weight,
at::Tensor& w_kc,
at::Tensor& q_a_layernorm_weight,
at::Tensor& kv_a_layernorm_weight,
at::Tensor& positions,
at::Tensor& cos_sin_cache,
double eps,
bool use_int8_w8a8,
bool use_fp8_w8a16,
std::optional<at::Tensor> q_a_proj_scale,
std::optional<at::Tensor> q_b_proj_scale,
std::optional<at::Tensor> kv_a_proj_scale,
bool is_vnni,
std::optional<std::vector<int64_t>> block_size);
// Internal: segmented GEMM computing [C0, C1] = A @ [B0, B1]
template <typename scalar_t>
void segment_gemm_kernel_impl(
scalar_t* __restrict__ C0,
scalar_t* __restrict__ C1,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B0,
const scalar_t* __restrict__ B1,
int64_t M, int64_t N0, int64_t N1, int64_t K);
Import
#include "common.h"
#include "gemm.h"
#include "vec.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| hidden_states | at::Tensor [num_seqs, hidden_size] | Yes | Input hidden states (e.g., [1, 7168] for DeepSeek R1) |
| q_a_proj_weight | at::Tensor [q_lora_rank, hidden_size] | Yes | First-stage query projection weight (e.g., [1536, 7168]) |
| q_b_proj_weight | at::Tensor [num_heads * qk_head_dim, q_lora_rank] | Yes | Second-stage query projection weight (e.g., [4224, 1536]) |
| kv_a_proj_weight | at::Tensor [kv_lora_rank + qk_rope_head_dim, hidden_size] | Yes | KV projection weight (e.g., [576, 7168]) |
| w_kc | at::Tensor [num_heads, kv_lora_rank, qk_nope_head_dim] | Yes | Key absorption weight (e.g., [22, 512, 128]) |
| q_a_layernorm_weight | at::Tensor [q_lora_rank] | Yes | RMSNorm weight for query latent (e.g., [1536]) |
| kv_a_layernorm_weight | at::Tensor [kv_lora_rank] | Yes | RMSNorm weight for KV latent (e.g., [512]) |
| positions | at::Tensor [num_seqs] | Yes | Token position indices for RoPE (int64) |
| cos_sin_cache | at::Tensor [max_pos, rotary_dim] | Yes | Precomputed cos/sin values for RoPE |
| eps | double | Yes | Epsilon for layer normalization |
| use_int8_w8a8 | bool | Yes | Whether to use INT8 quantized weights |
| use_fp8_w8a16 | bool | Yes | Whether to use FP8 quantized weights |
| is_vnni | bool | Yes | Whether weights are pre-packed in VNNI format (must be true) |
| q_a_proj_scale | std::optional&lt;at::Tensor&gt; | No | Quantization scale for q_a_proj_weight (used by the INT8/FP8 paths) |
| q_b_proj_scale | std::optional&lt;at::Tensor&gt; | No | Quantization scale for q_b_proj_weight (used by the INT8/FP8 paths) |
| kv_a_proj_scale | std::optional&lt;at::Tensor&gt; | No | Quantization scale for kv_a_proj_weight (used by the INT8/FP8 paths) |
| block_size | std::optional&lt;std::vector&lt;int64_t&gt;&gt; | No | Per-block scale granularity for the FP8 W8A16 path |
Outputs
| Name | Type | Description |
|---|---|---|
| query | at::Tensor [num_seqs, num_heads, qk_head_dim] | Projected and RoPE-embedded query tensor |
| key | at::Tensor [num_seqs, 1, kv_lora_rank + qk_rope_head_dim] | Projected key tensor with RoPE on the PE component |
| value | at::Tensor [num_seqs, 1, kv_lora_rank] | Projected value tensor (shares storage with key's non-PE part) |
Usage Examples
// DeepSeek MLA fused QKV projection with RoPE
auto [query, key, value] = qkv_proj_with_rope(
hidden_states, // [num_seqs, 7168]
q_a_proj_weight, // [1536, 7168]
q_b_proj_weight, // [4224, 1536]
kv_a_proj_weight, // [576, 7168]
w_kc, // [22, 512, 128]
q_a_layernorm_weight, // [1536]
kv_a_layernorm_weight, // [512]
positions, // [num_seqs]
cos_sin_cache, // [max_pos, rotary_dim]
/*eps=*/1e-5,
/*use_int8_w8a8=*/false,
/*use_fp8_w8a16=*/false,
/*q_a_proj_scale=*/std::nullopt,
/*q_b_proj_scale=*/std::nullopt,
/*kv_a_proj_scale=*/std::nullopt,
/*is_vnni=*/true,
/*block_size=*/std::nullopt);