

Implementation:Vllm project Vllm CPU Attn Impl

From Leeroopedia


Knowledge Sources
Domains: Attention, CPU_Inference, Scheduling
Last Updated: 2026-02-08 00:00 GMT

Overview

Defines the foundational interface, work scheduling structures, and memory management classes for CPU-based attention computation across multiple ISA backends.

Description

This header defines the AttentionImpl class template, parameterized by ISA (AMX, VEC, VEC16, NEON), scalar type, and head dimension. It also defines AttentionWorkItemGroup and ReductionWorkItemGroup, which describe how query tokens and KV splits are distributed across threads. The AttentionMetadata struct coordinates thread scheduling through an atomic counter, per-thread prefix sums of work items, and the selected ISA. The AttentionScratchPad class manages per-thread and per-KV-head scratchpad memory for Q buffers, logits, partial outputs, and reduction buffers. The remaining classes are AttentionScheduler, which partitions work; AttentionMainLoop, which runs the core attention loop; and Attention, the top-level class that orchestrates the full pipeline.
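The prefix-sum work distribution can be illustrated with a small standalone sketch. It assumes, hypothetically, that work item groups are split as evenly as possible among threads and that `cu_workitem_num_per_thread` holds prefix sums, so thread `t` owns groups `[cu[t], cu[t + 1])`. The helper `build_cu_workitem_num` is illustrative only; the header's real partitioning policy may differ.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: divide `workitem_group_num` groups among `thread_num`
// threads and return prefix sums of length thread_num + 1.
// Thread t processes groups in the half-open range [cu[t], cu[t + 1]).
std::vector<int32_t> build_cu_workitem_num(int32_t workitem_group_num,
                                           int32_t thread_num) {
  assert(thread_num > 0 && workitem_group_num >= 0);
  std::vector<int32_t> cu(thread_num + 1, 0);
  const int32_t base = workitem_group_num / thread_num;
  const int32_t rem = workitem_group_num % thread_num;
  for (int32_t t = 0; t < thread_num; ++t) {
    // The first `rem` threads take one extra group each.
    const int32_t count = base + (t < rem ? 1 : 0);
    cu[t + 1] = cu[t] + count;
  }
  return cu;
}
```

For example, 10 groups on 4 threads give per-thread counts 3, 3, 2, 2. With a fixed array such as `int32_t cu_workitem_num_per_thread[1025]`, N threads need N + 1 prefix entries, so that size would fit up to 1024 threads; this reading of the array size is an inference.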

Usage

This header is the central include for all CPU attention implementations. It is compiled on both x86 and ARM platforms and serves as the base for ISA-specific specializations in cpu_attn_amx.hpp, cpu_attn_neon.hpp, and cpu_attn_neon_bfmmla.hpp.
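The pattern described here, an empty primary template plus per-ISA specializations, can be sketched as follows. The `ISA` enum mirrors the Signature section. The `name()` member, the `dispatch` helper, and the `sketch` namespace are hypothetical and only show how a runtime `ISA` value can select a compile-time specialization.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

namespace sketch {

enum class ISA { AMX, VEC, VEC16, NEON };

// Primary template is intentionally empty; only specializations are usable.
template <ISA isa, typename scalar_t, int64_t head_dim>
class AttentionImpl {};

// Hypothetical specializations standing in for the per-ISA headers.
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
 public:
  static std::string name() { return "amx"; }
};

template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
 public:
  static std::string name() { return "neon"; }
};

// Map a runtime ISA value onto a compile-time specialization.
template <typename scalar_t, int64_t head_dim>
std::string dispatch(ISA isa) {
  switch (isa) {
    case ISA::AMX:
      return AttentionImpl<ISA::AMX, scalar_t, head_dim>::name();
    case ISA::NEON:
      return AttentionImpl<ISA::NEON, scalar_t, head_dim>::name();
    default:
      return "unsupported in this sketch";
  }
}

}  // namespace sketch
```

Because the primary template has no members, instantiating an ISA with no specialization fails at compile time as soon as a member is used. This catches missing backends early.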

Code Reference

Source Location

Signature

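#include <atomic>
#include <cstdint>

namespace cpu_attention {

// Supported instruction set backends.
enum class ISA { AMX, VEC, VEC16, NEON };

// Primary template is empty; ISA-specific headers provide specializations.
template <ISA isa, typename scalar_t, int64_t head_dim>
class AttentionImpl {};

// Field comments below are inferred from the field names and this page's description.
struct AttentionWorkItemGroup {
  int32_t req_id;              // request index
  int32_t q_token_id_start;    // first query token in the group
  int32_t q_token_num;         // number of query tokens in the group
  int32_t kv_split_pos_start;  // start of the KV range covered
  int32_t kv_split_pos_end;    // end of the KV range covered
  int64_t total_kv_len;        // full KV length of the request
  int32_t split_id;            // split index across all requests
  int32_t local_split_id;      // split index within this request
};

// Merges the partial results of split_num KV splits for a range of query tokens.
struct ReductionWorkItemGroup {
  int32_t req_id;
  int32_t q_token_id_start;
  int32_t q_token_id_num;      // number of query tokens to reduce
  int32_t split_start_id;      // first split to merge
  int32_t split_num;           // number of splits to merge
};

struct AttentionMetadata {
  std::atomic_int64_t counter;  // shared atomic counter for thread scheduling
  ISA isa;                      // selected backend
  int32_t workitem_group_num;
  int32_t reduction_item_num;
  int32_t thread_num;
  AttentionWorkItemGroup* workitem_groups_ptr;
  ReductionWorkItemGroup* reduction_items_ptr;
  int32_t cu_workitem_num_per_thread[1025];  // per-thread prefix sums of work items
  // Constructor and any remaining members elided (see Usage Examples).
};

// Class bodies { ... } elided on this page.
class AttentionScratchPad;
class AttentionScheduler;
class AttentionMainLoop;
class Attention;

}  // namespace cpu_attention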
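The `counter` field suggests dynamic scheduling, in which threads claim work item groups from a shared atomic index instead of relying only on a fixed partition. The sketch below shows that general technique with `std::thread` and `fetch_add`. This page does not say whether `AttentionMetadata::counter` is used exactly this way, so treat `run_dynamic` as hypothetical.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <thread>
#include <vector>

// Hypothetical: process `group_num` work item groups with `thread_num` threads.
// Each thread repeatedly claims the next unclaimed group index via fetch_add.
// Returns owner[g] = id of the thread that processed group g.
std::vector<int32_t> run_dynamic(int32_t group_num, int32_t thread_num) {
  assert(group_num >= 0 && thread_num > 0);
  std::atomic_int64_t counter{0};
  // Each index is claimed by exactly one thread, so these writes never race.
  std::vector<int32_t> owner(group_num, -1);

  auto worker = [&](int32_t tid) {
    while (true) {
      const int64_t g = counter.fetch_add(1, std::memory_order_relaxed);
      if (g >= group_num) break;  // no work left
      owner[g] = tid;  // stand-in for running the attention main loop on group g
    }
  };

  std::vector<std::thread> threads;
  for (int32_t t = 0; t < thread_num; ++t) threads.emplace_back(worker, t);
  for (auto& th : threads) th.join();  // join() makes all writes visible here
  return owner;
}
```

Relaxed ordering is enough for the counter itself because it only hands out unique indices. The `join()` calls provide the synchronization that makes results visible to the caller.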

Import

#include "cpu_attn_impl.hpp"

I/O Contract

Inputs

Name | Type | Required | Description
--- | --- | --- | ---
isa | ISA | Yes | Instruction set architecture to use (AMX, VEC, VEC16, NEON)
workitem_group_num | int32_t | Yes | Total number of attention work item groups to schedule
reduction_item_num | int32_t | Yes | Number of reduction work items for split-KV merging
reduction_split_num | int32_t | Yes | Total number of KV splits across all requests
split_kv_q_token_num_threshold | int32_t | Yes | Threshold for enabling split-KV based on query token count
scratchpad_ptr | void* | Yes | Pointer to pre-allocated scratchpad memory for thread-local and reduction buffers
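The split-KV inputs can be made concrete with a sketch that turns one request into work items. It mirrors the field names of `AttentionWorkItemGroup` and `ReductionWorkItemGroup` in simplified local structs. The policy and parameters here are assumptions, not the header's documented behavior: splitting only when the request has fewer query tokens than `split_kv_q_token_num_threshold`, fixed-size chunks, the `kv_split_len` parameter, and `split_id = -1` for unsplit requests.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified local mirrors of the header's structs (same field names).
struct WorkItem {
  int32_t req_id, q_token_id_start, q_token_num;
  int32_t kv_split_pos_start, kv_split_pos_end;
  int64_t total_kv_len;
  int32_t split_id, local_split_id;
};
struct ReductionItem {
  int32_t req_id, q_token_id_start, q_token_id_num, split_start_id, split_num;
};

// Hypothetical policy: split the KV range into kv_split_len chunks only when
// the request has fewer query tokens than the threshold (e.g. decode steps).
void schedule_request(int32_t req_id, int32_t q_start, int32_t q_num,
                      int32_t kv_len, int32_t kv_split_len,
                      int32_t q_token_threshold, std::vector<WorkItem>& items,
                      std::vector<ReductionItem>& reductions,
                      int32_t& next_split_id) {
  assert(kv_len > 0 && kv_split_len > 0);
  const bool split = q_num < q_token_threshold && kv_len > kv_split_len;
  if (!split) {
    // One item covers the whole KV range; -1 marks "no split" in this sketch.
    items.push_back({req_id, q_start, q_num, 0, kv_len, kv_len, -1, 0});
    return;
  }
  const int32_t first_split = next_split_id;
  int32_t local = 0;
  for (int32_t pos = 0; pos < kv_len; pos += kv_split_len, ++local) {
    const int32_t end = std::min(pos + kv_split_len, kv_len);
    items.push_back(
        {req_id, q_start, q_num, pos, end, kv_len, next_split_id++, local});
  }
  // One reduction item merges all partial results of this request.
  reductions.push_back({req_id, q_start, q_num, first_split, local});
}
```

Under this policy, a single decode token over 1000 KV positions with `kv_split_len = 256` yields four work items and one reduction item. A 64-token prefill chunk stays unsplit when the threshold is 8.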

Outputs

Name | Type | Description
--- | --- | ---
output_buffer | float* | Final attention output after main loop and optional reduction
max_buffer | float* | Per-head softmax max values from partial computations
sum_buffer | float* | Per-head softmax sum values from partial computations
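The max and sum buffers are what allow split-KV results to be merged exactly. Each split records its own softmax max and sum, and the partial outputs can be combined by rescaling every split to a common max, as in the standard log-sum-exp merge used by split-KV attention. The sketch below shows this for one query token and one head. It assumes each partial output is already divided by its local sum, which may not match the header's buffer convention; `merge_splits` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Merge split-KV partial results for one query token and one head.
// partial_out[i]: head_dim values for split i, normalized by sum_vals[i].
// max_vals[i]:    max logit seen by split i.
// sum_vals[i]:    sum_j exp(logit_j - max_vals[i]) over split i.
std::vector<float> merge_splits(
    const std::vector<std::vector<float>>& partial_out,
    const std::vector<float>& max_vals, const std::vector<float>& sum_vals) {
  assert(!partial_out.empty());
  assert(partial_out.size() == max_vals.size() &&
         max_vals.size() == sum_vals.size());
  const std::size_t head_dim = partial_out[0].size();
  const float global_max = *std::max_element(max_vals.begin(), max_vals.end());

  std::vector<float> out(head_dim, 0.0f);
  float total = 0.0f;
  for (std::size_t i = 0; i < partial_out.size(); ++i) {
    // Rescale split i's softmax sum from its local max to the global max.
    const float w = sum_vals[i] * std::exp(max_vals[i] - global_max);
    total += w;
    for (std::size_t d = 0; d < head_dim; ++d) out[d] += w * partial_out[i][d];
  }
  for (float& v : out) v /= total;  // renormalize by the merged softmax sum
  return out;
}
```

Subtracting the global max before exponentiating keeps every weight at most `sum_vals[i]`, which avoids overflow no matter how large the raw logits are.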

Usage Examples

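#include <new>      // placement new
#include <omp.h>    // omp_get_thread_num
#include "cpu_attn_impl.hpp"

// Assumed to be provided by the caller: aligned_buffer (suitably aligned
// storage for AttentionMetadata), scratchpad_ptr, and the count/size
// variables passed below.

// Create attention metadata for scheduling (placement new into caller storage)
auto* metadata = new (aligned_buffer) cpu_attention::AttentionMetadata(
    cpu_attention::ISA::AMX,
    workitem_count,
    reduction_count,
    split_num,
    q_token_threshold
);

// omp_get_thread_num() is only meaningful inside a parallel region;
// use one worker per thread slot recorded in the metadata.
#pragma omp parallel num_threads(metadata->thread_num)
{
  // Each thread accesses its own scratchpad
  int thread_id = omp_get_thread_num();
  cpu_attention::AttentionScratchPad pad(thread_id, *metadata, scratchpad_ptr);

  // Update scratchpad offsets for the current work item
  pad.update(head_dim, q_buf_elem_size, logits_elem_size,
             output_elem_size, max_q_per_iter, q_tile_size, kv_tile_size);

  // Access this thread's buffers
  float* logits = pad.get_logits_buffer();
  float* output = pad.get_output_buffer();
}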
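The scratchpad's role of carving each thread's region into Q, logits, and partial-output buffers can be illustrated with a simplified layout calculator. Everything in it is hypothetical: the `ThreadPadLayout` struct, the 64-byte alignment, the sizing formulas, and the `carve` and `thread_base` functions. It does not reproduce `AttentionScratchPad::update`, whose exact offsets are not documented on this page.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical per-thread layout: [q buffer | logits | partial output],
// with each region rounded up to a 64-byte boundary.
struct ThreadPadLayout {
  std::size_t q_offset, logits_offset, output_offset, total_bytes;
};

constexpr std::size_t kAlign = 64;

constexpr std::size_t round_up(std::size_t n, std::size_t a) {
  return (n + a - 1) / a * a;
}

ThreadPadLayout carve(std::size_t head_dim, std::size_t q_elem_size,
                      std::size_t logits_elem_size, std::size_t output_elem_size,
                      std::size_t max_q_per_iter, std::size_t kv_tile_size) {
  ThreadPadLayout l{};
  l.q_offset = 0;
  // Q tile: max_q_per_iter rows of head_dim elements.
  const std::size_t q_bytes =
      round_up(max_q_per_iter * head_dim * q_elem_size, kAlign);
  l.logits_offset = l.q_offset + q_bytes;
  // Logits: one score per (query row, KV position in the tile).
  const std::size_t logits_bytes =
      round_up(max_q_per_iter * kv_tile_size * logits_elem_size, kAlign);
  l.output_offset = l.logits_offset + logits_bytes;
  // Partial output accumulator: max_q_per_iter rows of head_dim elements.
  const std::size_t out_bytes =
      round_up(max_q_per_iter * head_dim * output_elem_size, kAlign);
  l.total_bytes = l.output_offset + out_bytes;
  return l;
}

// Thread t's region starts at base + t * per_thread_bytes. per_thread_bytes
// must be a multiple of kAlign so that every thread's region stays aligned.
inline std::uint8_t* thread_base(std::uint8_t* base, int thread_id,
                                 std::size_t per_thread_bytes) {
  assert(per_thread_bytes % kAlign == 0);
  return base + static_cast<std::size_t>(thread_id) * per_thread_bytes;
}
```

Giving each thread its own fixed-size slice avoids allocation in the hot loop and prevents threads from writing to each other's buffers.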
