Implementation: vLLM CPU MLA Decode
| Knowledge Sources | |
|---|---|
| Domains | CPU_Inference, MLA, Attention |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements the Multi-Latent Attention (MLA) decode-phase kernel for CPU, computing paged attention over compressed KV cache blocks.
Description
This file provides the CPU implementation of MLA decoding as used by DeepSeek-V3 style models. The core template mla_decode_block_head performs fused Q·K^T dot products with an online softmax and attention-weighted V accumulation, using AVX-512 vectorized operations over BF16, FP16, and FP32 element types. It processes two heads simultaneously (HEAD_UNROLL) for improved instruction-level parallelism and reuses FP32-converted KV cache blocks across query heads.
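The online-softmax accumulation described above can be illustrated with a scalar sketch (the real kernel operates on AVX-512 vectors and two heads at once; the function name and shapes here are illustrative, not from the source): a running maximum, running denominator, and un-normalized output are maintained, and the old state is rescaled whenever a new maximum is seen.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Scalar sketch of online softmax with value accumulation:
// maintain a running max (m), running denominator (s), and an
// un-normalized weighted sum (acc); rescale old state when the
// maximum increases so exp() never overflows.
float online_softmax_weighted_sum(const std::vector<float>& scores,
                                  const std::vector<float>& values) {
  float m = -INFINITY, s = 0.0f, acc = 0.0f;
  for (size_t i = 0; i < scores.size(); ++i) {
    float m_new = std::max(m, scores[i]);
    float correction = std::exp(m - m_new);  // rescales previous state
    float p = std::exp(scores[i] - m_new);
    s = s * correction + p;
    acc = acc * correction + p * values[i];
    m = m_new;
  }
  return acc / s;  // normalized attention-weighted sum
}
```

The single pass avoids materializing all scores before normalization, which is what lets the kernel stream over KV cache blocks.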
The outer function mla_decode_kvcache dispatches across sequences using OpenMP parallelism, with each thread processing blocks independently and merging partial attention states via the log-sum-exp trick (following Section 2.2 of the FlashInfer paper, arXiv:2501.01005).
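The log-sum-exp merge of two partial attention states can be sketched as follows for a single output element (the actual kernel merges full v_head_dim vectors; the function name is illustrative). Each partial state carries its normalized output acc and its lse = log(sum(exp(scores))) over a disjoint token range:

```cpp
#include <cassert>
#include <cmath>

// Merge two partial attention states computed over disjoint token
// ranges. acc_* is the normalized softmax-weighted value sum for its
// range; lse_* = log(sum(exp(scores))) for that range. The result is
// identical to attention computed over the union of both ranges.
void merge_attn_states(float& acc_a, float& lse_a,
                       float acc_b, float lse_b) {
  float lse_max = std::max(lse_a, lse_b);  // subtract max for stability
  float w_a = std::exp(lse_a - lse_max);
  float w_b = std::exp(lse_b - lse_max);
  float sum = w_a + w_b;
  acc_a = (acc_a * w_a + acc_b * w_b) / sum;  // renormalized output
  lse_a = lse_max + std::log(sum);            // combined log-sum-exp
}
```

Because the merge is associative, per-block partial states produced by different OpenMP threads can be combined in any order.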
Usage
This kernel is compiled as part of the vLLM CPU extension and is invoked during the decode phase of MLA-based models when running on CPU. It is specifically optimized for head_dim=576, v_head_dim=512, and block_size=16 (the DeepSeek-V3 configuration).
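A back-of-envelope sketch of what this configuration implies for cache sizing, assuming a BF16 cache (2 bytes per element) and one shared latent vector per token, as the [num_blocks, block_size, head_dim] cache layout suggests; the constants below restate the documented dimensions:

```cpp
#include <cassert>
#include <cstdint>

// DeepSeek-V3 style configuration stated above.
constexpr int64_t kHeadDim = 576;    // compressed latent + RoPE dims
constexpr int64_t kBlockSize = 16;   // tokens per cache block
constexpr int64_t kElemBytes = 2;    // BF16 (assumption)

// Bytes of compressed KV cache per block: one head_dim latent per token.
constexpr int64_t kBytesPerBlock = kBlockSize * kHeadDim * kElemBytes;

// Number of cache blocks needed for a sequence of seq_len tokens.
constexpr int64_t blocks_for(int64_t seq_len) {
  return (seq_len + kBlockSize - 1) / kBlockSize;  // ceiling division
}
```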
Code Reference
Source Location
- Repository: vllm
- File: csrc/cpu/mla_decode.cpp
- Lines: 1-400
Signature
void mla_decode_kvcache(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& kv_cache,
double scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens);
template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE>
void mla_decode_kvcache_cpu_impl(
scalar_t* out, // [num_seqs, num_heads, v_head_dim]
const scalar_t* q, // [num_seqs, num_heads, head_dim]
const scalar_t* kv_cache, // [num_blocks, block_size, head_dim]
const int num_heads, const float scale,
const int* block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const int o_stride, const int q_stride,
const int kv_stride, const int num_seqs);
template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, int HEAD_UNROLL,
typename qk_vec_type>
void mla_decode_block_head(
const qk_vec_type* q_vecs,
const qk_vec_type* k_vecs,
const vec_op::FP32Vec16* v_vecs_f32,
float* acc_out, float* acc_lse,
const float scale, const int num_tokens);
Import
#include "cpu_types.hpp"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| query | torch::Tensor [num_seqs, num_heads, head_dim] | Yes | Query vectors for each sequence and attention head |
| kv_cache | torch::Tensor [num_blocks, block_size, head_dim] | Yes | Paged compressed KV cache blocks containing latent key-value representations |
| scale | double | Yes | Attention scaling factor, typically 1/sqrt(head_dim) |
| block_tables | torch::Tensor [num_seqs, max_num_blocks_per_seq] | Yes | Mapping from logical to physical block indices per sequence |
| seq_lens | torch::Tensor [num_seqs] | Yes | Actual sequence length for each sequence (determines valid blocks) |
Outputs
| Name | Type | Description |
|---|---|---|
| out | torch::Tensor [num_seqs, num_heads, v_head_dim] | Attention output for each sequence and head with dimension v_head_dim (512) |
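The role of block_tables in the contract above can be shown with the index arithmetic implied by the [num_blocks, block_size, head_dim] cache layout; this is a hypothetical helper reconstructing the paged lookup, not a function from the source file:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// DeepSeek-V3 configuration from the document.
constexpr int kBlockSize = 16;
constexpr int kHeadDim = 576;

// Translate a token position within one sequence into a flat element
// offset into kv_cache [num_blocks, block_size, head_dim], using that
// sequence's row of block_tables for the logical->physical mapping.
int64_t kv_offset(const std::vector<int>& block_table, int token_pos) {
  int logical_block = token_pos / kBlockSize;       // which cache block
  int block_offset = token_pos % kBlockSize;        // token within block
  int physical_block = block_table[logical_block];  // paged indirection
  return static_cast<int64_t>(physical_block) * kBlockSize * kHeadDim +
         static_cast<int64_t>(block_offset) * kHeadDim;
}
```

Only positions below the sequence's entry in seq_lens are valid; the final block of a sequence is typically partially filled.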
Usage Examples
// MLA decode for DeepSeek-V3 on CPU
torch::Tensor out = torch::empty({num_seqs, num_heads, 512}, query.options());
mla_decode_kvcache(
out,
query, // [num_seqs, num_heads, 576]
kv_cache, // [num_blocks, 16, 576]
1.0 / std::sqrt(576.0),
block_tables, // [num_seqs, max_blocks]
seq_lens); // [num_seqs]