Implementation:Vllm project Vllm CPU Attn AMX
| Knowledge Sources | |
|---|---|
| Domains | Attention, CPU_Inference, SIMD, AMX |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements Intel AMX (Advanced Matrix Extensions) tile-based GEMM operations for the QK and PV phases of attention computation on Intel Sapphire Rapids and newer CPUs.
Description
This header provides two tile GEMM patterns: TileGemm224 (2-2-4 pattern, for 16 < m <= 32) and TileGemm122 (1-2-2 pattern, for m <= 16). Both use AMX tile instructions (_tile_dpbf16ps, _tile_loadd, _tile_stored, _tile_stream_loadd) to perform BF16 matrix multiplication on tiles of 16 rows by 64 bytes (1 KiB each). Template specializations handle the QK and PV attention phases, which use different memory layouts for the A operand: Q buffers are prepacked, whereas logits buffers are row-major. The K and V caches are prepacked in both phases. The file also provides the AttentionImpl<ISA::AMX, ...> specialization, which integrates these tile GEMM kernels into the CPU attention framework.
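To make the tile arithmetic concrete, the following is a minimal scalar sketch of what a single BF16 tile dot product computes, assuming full 64-byte rows. It is a portable model, not code from the header: the names `bf16_t`, `bf16_to_float`, `float_to_bf16`, and `tile_dpbf16ps_ref` are hypothetical, and the sketch ignores hardware details such as denormal handling and rounding.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// BF16 stored as a raw 16-bit pattern (the upper half of an IEEE-754 float).
using bf16_t = uint16_t;

inline float bf16_to_float(bf16_t v) {
  uint32_t bits = static_cast<uint32_t>(v) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Truncating conversion; exact for values with at most 8 significant bits.
inline bf16_t float_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<bf16_t>(bits >> 16);
}

constexpr int kTileRows = 16;                   // rows per tile
constexpr int kTileRowBytes = 64;               // bytes per tile row
constexpr int kBf16PerRow = kTileRowBytes / 2;  // 32 BF16 values per row
constexpr int kF32PerRow = kTileRowBytes / 4;   // 16 FP32 values per row

// Scalar model of one tile multiply-accumulate:
//   c[m][n] += a[m][2k] * b[k][2n] + a[m][2k+1] * b[k][2n+1]   for all k
// The B tile stores consecutive reduction-dimension elements as adjacent pairs.
void tile_dpbf16ps_ref(float c[kTileRows][kF32PerRow],
                       const bf16_t a[kTileRows][kBf16PerRow],
                       const bf16_t b[kTileRows][kBf16PerRow],
                       int m_rows) {
  assert(m_rows >= 0 && m_rows <= kTileRows);
  for (int m = 0; m < m_rows; ++m) {
    for (int k = 0; k < kTileRows; ++k) {  // 16 pair-rows of B cover K = 32
      for (int n = 0; n < kF32PerRow; ++n) {
        c[m][n] += bf16_to_float(a[m][2 * k]) * bf16_to_float(b[k][2 * n]);
        c[m][n] += bf16_to_float(a[m][2 * k + 1]) * bf16_to_float(b[k][2 * n + 1]);
      }
    }
  }
}
```

In this model one tile operation multiplies an m x 32 BF16 block by a 32 x 16 BF16 block and accumulates into an m x 16 FP32 block, which is why the accumulators are FP32 even though both inputs are BF16.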
Usage
This header is compiled on x86 platforms with AMX support (Intel Sapphire Rapids or newer). It is included by cpu_attn_impl.hpp and activated when the runtime ISA detection selects AMX as the preferred instruction set for attention computation.
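As an illustration of what a runtime AMX check can look like, the sketch below reads the AMX feature bits from CPUID leaf 7. It is not the detection logic vLLM uses; `AmxCpuFeatures`, `decode_amx_bits`, and `query_amx_features` are hypothetical names, and the code assumes GCC or Clang for `<cpuid.h>`. On Linux a process additionally has to request permission to use tile state from the kernel before executing tile instructions; that step is not shown.

```cpp
#include <cstdint>
#if defined(__x86_64__) || defined(__i386__)
#include <cpuid.h>  // GCC/Clang helper for the CPUID instruction
#endif

struct AmxCpuFeatures {
  bool amx_bf16;  // BF16 tile dot product
  bool amx_tile;  // tile architecture (load/store/config)
  bool amx_int8;  // INT8 tile dot product
};

// Decode the AMX bits of CPUID.(EAX=7,ECX=0):EDX.
constexpr AmxCpuFeatures decode_amx_bits(uint32_t edx) {
  return AmxCpuFeatures{((edx >> 22) & 1u) != 0,   // bit 22: AMX-BF16
                        ((edx >> 24) & 1u) != 0,   // bit 24: AMX-TILE
                        ((edx >> 25) & 1u) != 0};  // bit 25: AMX-INT8
}

// Query the running CPU; reports "no AMX" on non-x86 builds.
inline AmxCpuFeatures query_amx_features() {
#if defined(__x86_64__) || defined(__i386__)
  unsigned eax = 0, ebx = 0, ecx = 0, edx = 0;
  if (__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) {
    return decode_amx_bits(edx);
  }
#endif
  return decode_amx_bits(0);
}
```

A caller would typically require both `amx_tile` and `amx_bf16` before selecting a BF16 tile kernel and fall back to another ISA otherwise.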
Code Reference
Source Location
- Repository: vllm
- File: csrc/cpu/cpu_attn_amx.hpp
- Lines: 1-511
Signature
namespace cpu_attention {
// AMX tile configuration
constexpr static int64_t AMX_TILE_ROW_BYTES = 64;
constexpr static int64_t AMX_TILE_ROW_NUM = 16;
constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;
// 2-2-4 tile pattern for 16 < m <= 32
template <typename kv_cache_t>
class TileGemm224 {
 public:
  template <AttentionGemmPhase phase, int32_t k_size>
  FORCE_INLINE static void gemm(const int32_t m_size, void* a_tile,
                                void* b_tile, float* c_tile,
                                const int64_t lda, const int64_t ldb,
                                const int64_t ldc, const int32_t block_size,
                                const int32_t dynamic_k_size, const bool accum_c);
  FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config);
};
// 1-2-2 tile pattern for m <= 16
template <typename kv_cache_t>
class TileGemm122 { ... };
// AMX attention implementation
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::AMX, scalar_t, head_dim> { ... };
} // namespace cpu_attention
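The tile constants and the two m ranges in the signature imply a simple geometry and a simple dispatch rule. The sketch below spells both out. The constants mirror the ones declared in the header, while `TilePattern` and `select_tile_pattern` are hypothetical helpers introduced only for illustration; how the header actually chooses between the two classes is not shown in this document.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the constants declared in the header.
constexpr int64_t AMX_TILE_ROW_BYTES = 64;
constexpr int64_t AMX_TILE_ROW_NUM = 16;
constexpr int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;

// Derived geometry of a single tile.
static_assert(AMX_TILE_BYTES == 1024, "one tile holds 1 KiB");
static_assert(AMX_TILE_ROW_BYTES / 2 == 32, "32 BF16 values per tile row");
static_assert(AMX_TILE_ROW_BYTES / 4 == 16, "16 FP32 values per tile row");

// Hypothetical tag for the two GEMM patterns.
enum class TilePattern { k122, k224 };

// Hypothetical helper: m <= 16 fits one tile of rows (1-2-2 pattern),
// 16 < m <= 32 needs two tiles of rows (2-2-4 pattern).
inline TilePattern select_tile_pattern(int32_t m_size) {
  assert(m_size > 0 && m_size <= 2 * AMX_TILE_ROW_NUM);
  return m_size <= AMX_TILE_ROW_NUM ? TilePattern::k122 : TilePattern::k224;
}
```

For example, `select_tile_pattern(16)` yields the 1-2-2 pattern and `select_tile_pattern(17)` the 2-2-4 pattern.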
Import
#include "cpu_attn_amx.hpp"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| a_tile | void* (c10::BFloat16 data) | Yes | Pointer to the A matrix (Q buffer for the QK phase, logits for the PV phase) |
| b_tile | void* (c10::BFloat16 data) | Yes | Pointer to the B matrix (K cache for the QK phase, V cache for the PV phase); prepacked layout |
| c_tile | float* | Yes | Pointer to the output C matrix (logits for QK, output for PV); row-major |
| m_size | int32_t | Yes | Number of rows in A and C (number of query heads being processed) |
| lda, ldb, ldc | int64_t | Yes | Leading dimensions of the A, B, and C matrices |
| block_size | int32_t | Yes | KV cache block size, used to determine the B tile layout in the PV phase |
| dynamic_k_size | int32_t | Yes | Runtime extent of the reduction dimension (head_dim for QK and seq_len for PV in the usage example below) |
| accum_c | bool | Yes | Whether to accumulate into C (true) or start from a zero-initialized C (false) |
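The table states that the B operand is prepacked. The exact packing the header uses is not described here, but the BF16 tile dot product in general expects consecutive reduction-dimension elements of B to sit next to each other as pairs. The sketch below shows that generic pair packing for a row-major K x N matrix; `pack_b_pairs` and `bf16_bits` are hypothetical names, and the real K/V cache layout may add further blocking on top of this.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using bf16_bits = uint16_t;  // raw BF16 bit pattern

// Repack a row-major K x N matrix into (K/2) rows of 2*N values so that
// packed[k/2][2*n + (k % 2)] == b[k][n]. Each output row interleaves two
// consecutive input rows element by element.
std::vector<bf16_bits> pack_b_pairs(const std::vector<bf16_bits>& b,
                                    int k_size, int n_size) {
  assert(k_size % 2 == 0);
  assert(b.size() == static_cast<std::size_t>(k_size) * n_size);
  std::vector<bf16_bits> packed(b.size());
  for (int k = 0; k < k_size; ++k) {
    for (int n = 0; n < n_size; ++n) {
      packed[(k / 2) * (2 * n_size) + 2 * n + (k % 2)] = b[k * n_size + n];
    }
  }
  return packed;
}
```

With a 4 x 2 input whose rows are (0, 1), (2, 3), (4, 5), (6, 7), the packed rows are (0, 2, 1, 3) and (4, 6, 5, 7).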
Outputs
| Name | Type | Description |
|---|---|---|
| c_tile | float* | Updated output matrix with accumulated GEMM results (QK scores or PV output) |
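The arithmetic contract of the inputs and output can be summarized with a plain scalar GEMM. The sketch below is a reference model only: `gemm_ref` is a hypothetical function, it uses FP32 and row-major storage for every operand (the real B operand is prepacked BF16), and it treats the leading dimensions as row strides in elements.

```cpp
#include <cstdint>

// Reference model: C = A * B when accum_c is false, C += A * B when true.
// A is m x k, B is k x n, C is m x n, all row-major with the given strides.
void gemm_ref(int32_t m_size, int32_t n_size, int32_t k_size,
              const float* a, int64_t lda,
              const float* b, int64_t ldb,
              float* c, int64_t ldc, bool accum_c) {
  for (int32_t m = 0; m < m_size; ++m) {
    for (int32_t n = 0; n < n_size; ++n) {
      // Start from the existing value only when accumulating.
      float acc = accum_c ? c[m * ldc + n] : 0.0f;
      for (int32_t k = 0; k < k_size; ++k) {
        acc += a[m * lda + k] * b[k * ldb + n];
      }
      c[m * ldc + n] = acc;
    }
  }
}
```

Calling it with `accum_c = false` overwrites whatever was in C, while a second call with `accum_c = true` adds a further partial product, which is the pattern needed when the reduction dimension is processed in several chunks.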
Usage Examples
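#include <cstring>  // memset
#include "cpu_attn_amx.hpp"
using namespace cpu_attention;
// HEAD_DIM and SEQ_TILE are placeholder compile-time constants; the buffers,
// leading dimensions, and sizes are assumed to be set up by the caller.
// Configure AMX tiles for a given M dimension (TileGemm224 covers 16 < m_size <= 32)
__tilecfg config;
memset(&config, 0, sizeof(config));
TileGemm224<c10::BFloat16>::init_tile_config(m_size, config);
_tile_loadconfig(&config);
// Perform QK GEMM: scores = Q @ K^T (accum_c=false starts from a zeroed C)
TileGemm224<c10::BFloat16>::gemm<AttentionGemmPhase::QK, HEAD_DIM>(
    m_size, q_buffer, k_cache_ptr, logits_buffer,
    lda, ldb, ldc, block_size, head_dim, /*accum_c=*/false);
// Softmax over the FP32 scores and their conversion to BF16 are assumed to
// happen between the two calls (not shown).
// Perform PV GEMM: output = softmax(scores) @ V (accum_c=true adds to the existing output)
TileGemm224<c10::BFloat16>::gemm<AttentionGemmPhase::PV, SEQ_TILE>(
    m_size, logits_buffer, v_cache_ptr, output_buffer,
    lda, ldb, ldc, block_size, seq_len, /*accum_c=*/true);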
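The configuration object loaded with `_tile_loadconfig` is a 64-byte block that gives each tile its row count and bytes per row. The sketch below models that block and fills it for a 2-2-4 style arrangement. It is an assumption-heavy illustration: `TileConfig` is a stand-in for `__tilecfg`, `make_config_224` is a hypothetical helper, and the assignment of tile indices to the C, A, and B operands is invented here; the header's `init_tile_config` may assign them differently.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for __tilecfg: the 64-byte tile configuration memory layout.
struct TileConfig {
  uint8_t palette_id;    // byte 0: palette (1 = 8 tiles of 16 rows x 64 bytes)
  uint8_t start_row;     // byte 1: restart row for interrupted loads/stores
  uint8_t reserved[14];  // bytes 2-15: must be zero
  uint16_t colsb[16];    // bytes 16-47: bytes per row for each tile
  uint8_t rows[16];      // bytes 48-63: row count for each tile
};
static_assert(sizeof(TileConfig) == 64, "tile config must be 64 bytes");

// Hypothetical tile assignment: tiles 0-3 hold C, 4-5 hold A, 6-7 hold B.
inline TileConfig make_config_224(int32_t m) {
  assert(m > 16 && m <= 32);
  TileConfig cfg{};  // zero-initialize, including reserved bytes
  cfg.palette_id = 1;
  const uint8_t rows_lo = 16;                           // first 16 rows of A/C
  const uint8_t rows_hi = static_cast<uint8_t>(m - 16); // remaining rows
  const uint8_t rows[8] = {rows_lo, rows_lo, rows_hi, rows_hi,  // C tiles
                           rows_lo, rows_hi,                    // A tiles
                           16, 16};                             // B tiles
  for (int t = 0; t < 8; ++t) {
    cfg.colsb[t] = 64;  // every tile uses full 64-byte rows
    cfg.rows[t] = rows[t];
  }
  return cfg;
}
```

Under this assumed assignment, m = 20 gives the second A tile and its two C tiles 4 rows each, so the hardware skips the 12 unused rows instead of computing on padding.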