
Implementation:Sgl project Sglang CPU QKV Projection

From Leeroopedia


Knowledge Sources
Domains Machine Learning, CPU Kernels
Last Updated 2026-02-10 00:00 GMT

Overview

Implements a fused QKV projection kernel with weight absorption and RoPE (Rotary Position Embedding) for DeepSeek-style MLA (Multi-head Latent Attention) architectures on CPU.

Description

This kernel fuses multiple operations that would otherwise require separate GEMM calls and memory passes for the DeepSeek MLA attention mechanism:

  • q_a_proj and kv_a_proj_with_mqa are fused into a single segmented GEMM via segment_gemm_kernel_impl, which computes [C0, C1] = A @ [B0, B1] where B0 and B1 are different weight matrices. The segmented GEMM parallelizes across [MB, NB0 + NB1] blocks, dispatching tiles to either B0 or B1 based on the block index.
  • Subsequent q_a_layernorm and kv_a_layernorm are fused into a single parallel loop via rms_norm_kernel_impl.
  • rotary_emb_kernel_impl applies rotary position embedding to the key PE component.
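The segmented GEMM in the first bullet can be sketched with a naive scalar reference. This is an illustration of the block-dispatch idea only, not the kernel's actual code: the row-major [K x N] weight layout here is a simplification of the VNNI-packed layout the real kernel expects, and the real implementation tiles across [MB, NB0 + NB1] blocks with tinygemm/brgemm rather than looping per element.

```cpp
#include <cassert>
#include <cstdint>

// Naive scalar sketch of the segmented GEMM [C0, C1] = A @ [B0, B1].
// Each output column is routed to weight B0 or B1 by column index alone,
// which is the tile-dispatch idea of segment_gemm_kernel_impl in miniature.
// All matrices are row-major; B0 is [K x N0], B1 is [K x N1] (a
// simplification: the real kernel takes VNNI-packed [N x K] weights).
void segment_gemm_ref(float* C0, float* C1,
                      const float* A, const float* B0, const float* B1,
                      int64_t M, int64_t N0, int64_t N1, int64_t K) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N0 + N1; ++n) {
      // Route this output column to segment 0 (B0/C0) or segment 1 (B1/C1).
      const bool in_seg0 = n < N0;
      const float* B = in_seg0 ? B0 : B1;
      const int64_t col = in_seg0 ? n : n - N0;
      const int64_t ldb = in_seg0 ? N0 : N1;  // row-major leading dimension
      float acc = 0.f;
      for (int64_t k = 0; k < K; ++k)
        acc += A[m * K + k] * B[k * ldb + col];  // A row is shared by both segments
      (in_seg0 ? C0 : C1)[m * ldb + col] = acc;
    }
  }
}
```

The point of the fusion is visible even in this toy: both output segments consume the same row of A, so the activation is traversed once rather than once per projection.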

Three variants of segment_gemm_kernel_impl are provided:

  • BFloat16/Half using tinygemm_kernel with optional brgemm for large M.
  • INT8 W8A8 using int8_scaled_mm_with_quant infrastructure with separate activation and weight scales.
  • FP8 W8A16 using fp8_scaled_mm_cpu with per-block scales.
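The dequantization step behind the INT8 W8A8 variant can be sketched as follows. The helper name and the per-tensor scale granularity are illustrative assumptions, not the int8_scaled_mm_with_quant code; the idea shown is only that accumulation happens in int32 and the separate activation and weight scales are applied once per output element.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Toy sketch (assumed helper, not the kernel's actual code) of the W8A8
// idea: int8 activations and weights multiply into an int32 accumulator,
// then the product of the activation scale and the weight scale converts
// the accumulator back to floating point in a single step.
float int8_scaled_dot(const int8_t* a, const int8_t* w, int64_t K,
                      float a_scale, float w_scale) {
  int32_t acc = 0;
  for (int64_t k = 0; k < K; ++k)
    acc += int32_t(a[k]) * int32_t(w[k]);  // integer accumulation
  return float(acc) * a_scale * w_scale;   // dequantize once per output
}
```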

The public API function qkv_proj_with_rope accepts hidden states, all projection and normalization weights, position indices, and cos/sin cache, returning a tuple of (query, key, value) tensors. It expects weights to be pre-packed in VNNI format.

Usage

Use this kernel for DeepSeek-V2/V3 model inference on CPU. MLA compresses keys and values into a shared latent representation, which must be projected through absorption matrices before attention. Because q_a_proj and kv_a_proj_with_mqa consume the same input activation, fusing them into a single segmented GEMM lets each row of hidden_states be read once instead of once per projection, reducing memory-bandwidth usage.

Code Reference

Source Location

Signature

std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
    at::Tensor& hidden_states,
    at::Tensor& q_a_proj_weight,
    at::Tensor& q_b_proj_weight,
    at::Tensor& kv_a_proj_weight,
    at::Tensor& w_kc,
    at::Tensor& q_a_layernorm_weight,
    at::Tensor& kv_a_layernorm_weight,
    at::Tensor& positions,
    at::Tensor& cos_sin_cache,
    double eps,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    std::optional<at::Tensor> q_a_proj_scale,
    std::optional<at::Tensor> q_b_proj_scale,
    std::optional<at::Tensor> kv_a_proj_scale,
    bool is_vnni,
    std::optional<std::vector<int64_t>> block_size);

// Internal: segmented GEMM computing [C0, C1] = A @ [B0, B1]
template <typename scalar_t>
void segment_gemm_kernel_impl(
    scalar_t* __restrict__ C0,
    scalar_t* __restrict__ C1,
    const scalar_t* __restrict__ A,
    const scalar_t* __restrict__ B0,
    const scalar_t* __restrict__ B1,
    int64_t M, int64_t N0, int64_t N1, int64_t K);

Import

#include "common.h"
#include "gemm.h"
#include "vec.h"

I/O Contract

Inputs

Name Type Required Description
hidden_states at::Tensor [num_seqs, hidden_size] Yes Input hidden states (e.g., [1, 7168] for DeepSeek R1)
q_a_proj_weight at::Tensor [q_lora_rank, hidden_size] Yes First-stage query projection weight (e.g., [1536, 7168])
q_b_proj_weight at::Tensor [num_heads * qk_head_dim, q_lora_rank] Yes Second-stage query projection weight (e.g., [4224, 1536])
kv_a_proj_weight at::Tensor [kv_lora_rank + qk_rope_head_dim, hidden_size] Yes KV projection weight (e.g., [576, 7168])
w_kc at::Tensor [num_heads, kv_lora_rank, qk_nope_head_dim] Yes Key absorption weight (e.g., [22, 512, 128])
q_a_layernorm_weight at::Tensor [q_lora_rank] Yes RMSNorm weight for query latent (e.g., [1536])
kv_a_layernorm_weight at::Tensor [kv_lora_rank] Yes RMSNorm weight for KV latent (e.g., [512])
positions at::Tensor [num_seqs] Yes Token position indices for RoPE (int64)
cos_sin_cache at::Tensor [max_pos, rotary_dim] Yes Precomputed cos/sin values for RoPE
eps double Yes Epsilon for layer normalization
use_int8_w8a8 bool Yes Whether to use INT8 quantized weights
use_fp8_w8a16 bool Yes Whether to use FP8 quantized weights
q_a_proj_scale std::optional<at::Tensor> No Quantization scales for q_a_proj (required on the INT8/FP8 paths)
q_b_proj_scale std::optional<at::Tensor> No Quantization scales for q_b_proj (required on the INT8/FP8 paths)
kv_a_proj_scale std::optional<at::Tensor> No Quantization scales for kv_a_proj (required on the INT8/FP8 paths)
is_vnni bool Yes Whether weights are pre-packed in VNNI format (must be true)
block_size std::optional<std::vector<int64_t>> No Per-block scale granularity for the FP8 W8A16 path

Outputs

Name Type Description
query at::Tensor [num_seqs, num_heads, qk_head_dim] Projected and RoPE-embedded query tensor
key at::Tensor [num_seqs, 1, kv_lora_rank + qk_rope_head_dim] Projected key tensor with RoPE on the PE component
value at::Tensor [num_seqs, 1, kv_lora_rank] Projected value tensor (shares storage with key's non-PE part)
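The PE slice of the key above is the part that receives RoPE. A minimal sketch of that rotation follows; the half-split (NeoX-style) pairing and the [cos half | sin half] cache row layout are assumptions for illustration, not a statement of the kernel's exact convention.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Toy sketch of what rotary_emb_kernel_impl does to the rotary slice:
// the token position selects a row of the cos/sin cache, and each pair
// (x[i], x[i + rot_dim/2]) is rotated by the cached angle. The half-split
// (NeoX-style) pairing here is an assumption for this sketch.
void apply_rope(float* x, int64_t rot_dim,
                const float* cos_sin_row /* [rot_dim]: cos half, sin half */) {
  const int64_t half = rot_dim / 2;
  for (int64_t i = 0; i < half; ++i) {
    const float c = cos_sin_row[i];
    const float s = cos_sin_row[half + i];
    const float x0 = x[i], x1 = x[half + i];
    x[i]        = x0 * c - x1 * s;  // 2D rotation of the (x0, x1) pair
    x[half + i] = x0 * s + x1 * c;
  }
}
```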

Usage Examples

// DeepSeek MLA fused QKV projection with RoPE
auto [query, key, value] = qkv_proj_with_rope(
    hidden_states,           // [num_seqs, 7168]
    q_a_proj_weight,         // [1536, 7168]
    q_b_proj_weight,         // [4224, 1536]
    kv_a_proj_weight,        // [576, 7168]
    w_kc,                    // [22, 512, 128]
    q_a_layernorm_weight,    // [1536]
    kv_a_layernorm_weight,   // [512]
    positions,               // [num_seqs]
    cos_sin_cache,           // [max_pos, rotary_dim]
    /*eps=*/1e-5,
    /*use_int8_w8a8=*/false,
    /*use_fp8_w8a16=*/false,
    /*q_a_proj_scale=*/std::nullopt,
    /*q_b_proj_scale=*/std::nullopt,
    /*kv_a_proj_scale=*/std::nullopt,
    /*is_vnni=*/true,
    /*block_size=*/std::nullopt);
