
Implementation:Vllm project Vllm CPU MLA Decode

From Leeroopedia


Knowledge Sources
Domains CPU_Inference, MLA, Attention
Last Updated 2026-02-08 00:00 GMT

Overview

Implements the Multi-Latent Attention (MLA) decode-phase kernel for CPU, computing paged attention over compressed KV cache blocks.

Description

This file provides the CPU implementation of MLA decoding as used by DeepSeek-V3 style models. The core template mla_decode_block_head performs fused Q*K^T dot products with online softmax and attention-weighted V accumulation using vectorized operations (AVX-512 BF16, FP16, FP32). It processes two heads simultaneously for improved instruction-level parallelism and reuses FP32-converted KV cache blocks across query heads.
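The online-softmax pattern described above can be sketched in scalar form. This is a hypothetical, simplified illustration (names such as AttnState and online_softmax_step are invented here), not the vectorized AVX-512 kernel: for each token it updates a running max, rescales the previous accumulator, and adds the attention-weighted value row.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Scalar sketch of online softmax with weighted-V accumulation, the core
// operation mla_decode_block_head performs per (head, token) with vectors.
struct AttnState {
  float max_logit = -INFINITY;  // running max for numerical stability
  float sum_exp = 0.0f;         // running softmax denominator
  std::vector<float> acc;       // running attention-weighted V accumulator
};

void online_softmax_step(AttnState& st, float logit,
                         const std::vector<float>& v_row) {
  if (st.acc.empty()) st.acc.assign(v_row.size(), 0.0f);
  float new_max = std::fmax(st.max_logit, logit);
  float correction = std::exp(st.max_logit - new_max);  // rescale old terms
  float p = std::exp(logit - new_max);                  // weight of new token
  st.sum_exp = st.sum_exp * correction + p;
  for (size_t i = 0; i < st.acc.size(); ++i)
    st.acc[i] = st.acc[i] * correction + p * v_row[i];
  st.max_logit = new_max;
}
// The final attention output for the block is acc[i] / sum_exp.
```

Dividing the accumulator by sum_exp at the end yields exactly the softmax-weighted sum of the V rows, without ever materializing all logits.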

The outer function mla_decode_kvcache dispatches across sequences using OpenMP parallelism, with each thread processing blocks independently and merging partial attention states via the log-sum-exp trick (following Section 2.2 of the FlashInfer paper, arXiv:2501.01005).
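The merge step can be written down concretely. Below is a hedged sketch (the Partial struct and merge_partials are illustrative names, not the file's actual code) of combining two partial attention states, each holding a normalized partial output and a log-sum-exp value over a disjoint token range, following the rule from Section 2.2 of the FlashInfer paper:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// A partial attention state produced by one thread over one range of blocks.
struct Partial {
  std::vector<float> o;  // normalized partial attention output
  float lse;             // log(sum of exp(scaled logits)) for this range
};

// Merge two partial states over disjoint token sets. The merge is
// associative and commutative, so per-thread results can be combined
// in any order.
Partial merge_partials(const Partial& a, const Partial& b) {
  float m = std::fmax(a.lse, b.lse);   // shared max for stability
  float wa = std::exp(a.lse - m);      // relative weight of state a
  float wb = std::exp(b.lse - m);      // relative weight of state b
  Partial out;
  out.lse = m + std::log(wa + wb);
  out.o.resize(a.o.size());
  for (size_t i = 0; i < a.o.size(); ++i)
    out.o[i] = (wa * a.o[i] + wb * b.o[i]) / (wa + wb);
  return out;
}
```

Because each state carries its own log-sum-exp, the merged result equals the softmax over the union of both token sets, which is what lets OpenMP threads process blocks independently.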

Usage

This kernel is compiled as part of the vLLM CPU extension and is invoked during the decode phase of MLA-based models when running on CPU. It is specifically optimized for head_dim=576, v_head_dim=512, and block_size=16 (the DeepSeek-V3 configuration).

Code Reference

Source Location

Signature

void mla_decode_kvcache(
    torch::Tensor& out,
    torch::Tensor& query,
    torch::Tensor& kv_cache,
    double scale,
    torch::Tensor& block_tables,
    torch::Tensor& seq_lens);

template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE>
void mla_decode_kvcache_cpu_impl(
    scalar_t* out,             // [num_seqs, num_heads, v_head_dim]
    const scalar_t* q,         // [num_seqs, num_heads, head_dim]
    const scalar_t* kv_cache,  // [num_blocks, block_size, head_dim]
    const int num_heads, const float scale,
    const int* block_tables,   // [num_seqs, max_num_blocks_per_seq]
    const int* seq_lens,       // [num_seqs]
    const int max_num_blocks_per_seq,
    const int o_stride, const int q_stride,
    const int kv_stride, const int num_seqs);

template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, int HEAD_UNROLL,
          typename qk_vec_type>
void mla_decode_block_head(
    const qk_vec_type* q_vecs,
    const qk_vec_type* k_vecs,
    const vec_op::FP32Vec16* v_vecs_f32,
    float* acc_out, float* acc_lse,
    const float scale, const int num_tokens);

Import

#include "cpu_types.hpp"

I/O Contract

Inputs

Name          Type                                                Required  Description
query         torch::Tensor [num_seqs, num_heads, head_dim]       Yes       Query vectors for each sequence and attention head
kv_cache      torch::Tensor [num_blocks, block_size, head_dim]    Yes       Paged compressed KV cache blocks containing latent key-value representations
scale         double                                              Yes       Attention scaling factor, typically 1/sqrt(head_dim)
block_tables  torch::Tensor [num_seqs, max_num_blocks_per_seq]    Yes       Mapping from logical to physical block indices per sequence
seq_lens      torch::Tensor [num_seqs]                            Yes       Actual sequence length for each sequence (determines valid blocks)

Outputs

Name  Type                                              Description
out   torch::Tensor [num_seqs, num_heads, v_head_dim]   Attention output for each sequence and head, with dimension v_head_dim (512)
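The block_tables indirection in the contract above implies a simple paged address calculation: token t of sequence s lives in physical block block_tables[s][t / block_size] at in-block offset t % block_size. A minimal sketch, assuming the [num_blocks, block_size, head_dim] kv_cache layout stated above (kv_token_offset is a hypothetical helper, not part of the file):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int BLOCK_SIZE = 16;  // DeepSeek-V3 configuration
constexpr int HEAD_DIM = 576;

// Flat element offset of token `t` of sequence `s` within a kv_cache laid
// out as [num_blocks, BLOCK_SIZE, HEAD_DIM], via the per-sequence block table.
int64_t kv_token_offset(const std::vector<std::vector<int>>& block_tables,
                        int s, int t) {
  int physical_block = block_tables[s][t / BLOCK_SIZE];  // logical -> physical
  int block_offset = t % BLOCK_SIZE;                     // slot inside block
  return (int64_t(physical_block) * BLOCK_SIZE + block_offset) * HEAD_DIM;
}
```

seq_lens bounds t, so only blocks covering valid tokens are ever dereferenced.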

Usage Examples

// MLA decode for DeepSeek-V3 on CPU
torch::Tensor out = torch::empty({num_seqs, num_heads, 512}, query.options());
mla_decode_kvcache(
    out,
    query,        // [num_seqs, num_heads, 576]
    kv_cache,     // [num_blocks, 16, 576]
    1.0 / std::sqrt(576.0),
    block_tables, // [num_seqs, max_blocks]
    seq_lens);    // [num_seqs]
