Implementation:Vllm project Vllm CPU Attn NEON BFMMLA
| Knowledge Sources | |
|---|---|
| Domains | Attention, CPU_Inference, SIMD, ARM, BFloat16 |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements attention kernels built on the Arm BFMMLA (BFloat16 Matrix Multiply-Accumulate) instruction. The kernels accelerate BF16 attention on CPUs with the BF16 extension (Armv8.6-A and later).
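To make the instruction's semantics concrete, the following portable scalar model computes what one vbfmmlaq_f32 step does. It is a sketch, not the vLLM code. A 2x4 BF16 tile of A is multiplied by a 4x2 BF16 tile of B, and the product is added into a 2x2 float32 accumulator. The register layout in the comments follows the Arm definition of BFMMLA. The function name bfmmla_model is hypothetical, and BF16 values are modeled as float, so results agree with hardware only up to rounding.

```cpp
#include <array>

// Scalar model of one BFMMLA step (the operation behind vbfmmlaq_f32).
// a:   2x4 tile, row-major (a[0..3] = row 0, a[4..7] = row 1).
// b:   4x2 tile stored column by column (b[0..3] = column 0, b[4..7] = column 1).
// acc: 2x2 float32 accumulator, row-major {c00, c01, c10, c11}.
// BF16 inputs are modeled as float, so this matches hardware only up to rounding.
std::array<float, 4> bfmmla_model(std::array<float, 4> acc,
                                  const std::array<float, 8>& a,
                                  const std::array<float, 8>& b) {
  for (int i = 0; i < 2; ++i) {        // output row
    for (int j = 0; j < 2; ++j) {      // output column
      float sum = 0.0f;
      for (int k = 0; k < 4; ++k) {    // K reduction of width 4
        sum += a[4 * i + k] * b[4 * j + k];
      }
      acc[2 * i + j] += sum;           // accumulate into the 2x2 result
    }
  }
  return acc;
}
```

With b selecting the first two K positions (b = {1,0,0,0, 0,1,0,0}), the result is simply columns 0 and 1 of a, which makes the layout easy to check.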
Description
This header provides tile-based GEMM micro-kernels built on the vbfmmlaq_f32 intrinsic. Each BFMMLA step works on a 2x4x2 tile. It multiplies a 2x4 BF16 tile of A (2 rows, 4 K reduction elements) by a 4x2 BF16 tile of B (one column-pair) and accumulates the 2x2 result in float32. Four column-pair steps per row-pair produce 8 output columns per block.

The reshape_Q_2xK_for_bfmmla function packs two rows of Q into the interleaved BF16 layout that BFMMLA expects. It zero-pads the K tail up to a multiple of 4. The gemm_rowpairs_x8_bfmmla_neon template micro-kernel processes 1, 2, or 4 row-pairs per call with compile-time unrolling. It supports both attention phases: QK (token-column B layout) and PV (token-row B layout). Accumulators are loaded and stored through 2x2 helpers that are specialized at compile time on the number of active rows (1 or 2).
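The packing and the 8-column block can be modeled in portable C++. This is a hypothetical float stand-in, not the header's code, and it rests on two assumptions:

- pack_2xK assumes that, for each group of 4 K elements, the interleaved layout stores row 0's four values followed by row 1's four values.
- block_2x8 assumes B is given as 8 contiguous columns of length K, which is the QK token-column case.

For each K group, the block performs four 2x2 updates, one per column-pair.

```cpp
#include <array>
#include <cstdint>
#include <vector>

constexpr int32_t kTileK = 4;  // K elements consumed per BFMMLA step
constexpr int32_t kCols = 8;   // output columns per block (4 column-pairs)

// Hypothetical float stand-in for reshape_Q_2xK_for_bfmmla: for every group of
// 4 K elements, emit row 0's 4 values, then row 1's 4 values; zero-pad the K tail.
std::vector<float> pack_2xK(const float* r0, const float* r1, int32_t K) {
  const int32_t K_pad = (K + kTileK - 1) / kTileK * kTileK;
  std::vector<float> dst(2 * K_pad, 0.0f);
  for (int32_t kb = 0; kb < K_pad; kb += kTileK) {
    for (int32_t t = 0; t < kTileK; ++t) {
      const int32_t k = kb + t;
      dst[2 * kb + t] = k < K ? r0[k] : 0.0f;           // row 0 of this 2x4 tile
      dst[2 * kb + kTileK + t] = k < K ? r1[k] : 0.0f;  // row 1 of this 2x4 tile
    }
  }
  return dst;
}

// C (2x8, row-major) = A * B, with B given as 8 columns of length K
// (column n starts at B_cols + n * K). Each K group issues four 2x2 updates.
std::array<float, 2 * kCols> block_2x8(const std::vector<float>& A_packed,
                                       const float* B_cols, int32_t K) {
  std::array<std::array<float, 4>, kCols / 2> acc{};  // one 2x2 accumulator per pair
  const int32_t K_pad = static_cast<int32_t>(A_packed.size()) / 2;
  for (int32_t kb = 0; kb < K_pad; kb += kTileK) {
    const float* a = &A_packed[2 * kb];  // the 2x4 A tile for this K group
    for (int32_t p = 0; p < kCols / 2; ++p) {
      for (int32_t i = 0; i < 2; ++i) {
        for (int32_t j = 0; j < 2; ++j) {
          const float* col = B_cols + (2 * p + j) * K;
          float sum = 0.0f;
          for (int32_t t = 0; t < kTileK; ++t) {
            const int32_t k = kb + t;
            // Past K, B is treated as zero, matching the zero-padded A tail.
            sum += a[kTileK * i + t] * (k < K ? col[k] : 0.0f);
          }
          acc[p][2 * i + j] += sum;
        }
      }
    }
  }
  // Scatter the four 2x2 accumulators into the row-major 2x8 result.
  std::array<float, 2 * kCols> C{};
  for (int32_t p = 0; p < kCols / 2; ++p) {
    for (int32_t i = 0; i < 2; ++i) {
      for (int32_t j = 0; j < 2; ++j) {
        C[i * kCols + 2 * p + j] = acc[p][2 * i + j];
      }
    }
  }
  return C;
}
```

With K = 5, the packed buffer holds 2 x 8 values, and the last K group carries one real element per row plus three zeros.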
Usage
This header is compiled only on Arm platforms with BF16 hardware support: the BF16 extension, mandatory from Armv8.6-A and implemented by cores such as Neoverse N2 and V2. cpu_attn_neon.hpp includes it conditionally when the ARM_BF16_SUPPORT macro is defined. It provides the TileGemmNEONBFMMLA and AttentionImplNEONBFMMLA classes, which implement the BFMMLA-accelerated attention path.
Code Reference
Source Location
- Repository: vllm
- File: csrc/cpu/cpu_attn_neon_bfmmla.hpp
- Lines: 1-682
Signature
```cpp
namespace cpu_attention {
// BFMMLA tile constants
constexpr int32_t TILE_ROWS = 2; // M dimension
constexpr int32_t TILE_K = 4; // K reduction
constexpr int32_t TILE_COLS = 2; // N column-pair
constexpr int32_t OUTPUT_COLS_PER_BLOCK = 8;
// Pack Q rows into BFMMLA-friendly interleaved layout
FORCE_INLINE void reshape_Q_2xK_for_bfmmla(
    const c10::BFloat16* r0, const c10::BFloat16* r1,
    c10::BFloat16* dst, int32_t K);
// 2x2 accumulator load/store
template <int32_t m_rows>
FORCE_INLINE float32x4_t load_acc_2x2(float* base, int64_t ldc, int col_off);
template <int32_t m_rows>
FORCE_INLINE void store_acc_2x2(float32x4_t acc, float* base, int64_t ldc, int col_off);
// Micro-kernel: RP row-pairs x 8 columns using BFMMLA
template <int32_t RP, int32_t K_static, AttentionGemmPhase phase>
FORCE_INLINE void gemm_rowpairs_x8_bfmmla_neon(
    const bfloat16_t* const* A_packed_rp,
    const int32_t* m_rows_rp,
    const bfloat16_t* B_blk,
    float* C, int64_t ldc, bool accumulate,
    int64_t b_stride, int32_t K_runtime = 0);
class TileGemmNEONBFMMLA { ... };
class AttentionImplNEONBFMMLA { ... };
} // namespace cpu_attention
```
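The 2x2 load/store helpers can be pictured with the scalar stand-ins below. They rest on two hypothetical assumptions:

- The accumulator register holds C[0][c], C[0][c+1], C[1][c], C[1][c+1] in that order, which is the BFMMLA result layout.
- The m_rows = 1 specialization skips the second row entirely.

The _model names and the use of std::array in place of float32x4_t are illustrative only.

```cpp
#include <array>
#include <cstdint>

// Hypothetical scalar stand-in for load_acc_2x2: gathers a 2x2 block of C
// (rows 0..1, columns col_off..col_off+1) into row-major {c00, c01, c10, c11}.
// With m_rows == 1 the second row is not read and stays zero.
template <int32_t m_rows>
std::array<float, 4> load_acc_2x2_model(const float* base, int64_t ldc, int col_off) {
  static_assert(m_rows == 1 || m_rows == 2, "a row-pair has 1 or 2 active rows");
  std::array<float, 4> acc{};  // zero-initialized
  acc[0] = base[col_off];
  acc[1] = base[col_off + 1];
  if constexpr (m_rows == 2) {
    acc[2] = base[ldc + col_off];
    acc[3] = base[ldc + col_off + 1];
  }
  return acc;
}

// Hypothetical scalar stand-in for store_acc_2x2: writes the 2x2 block back.
// With m_rows == 1 only the first row is written, so memory past the last
// valid row of C is never touched.
template <int32_t m_rows>
void store_acc_2x2_model(const std::array<float, 4>& acc, float* base, int64_t ldc, int col_off) {
  static_assert(m_rows == 1 || m_rows == 2, "a row-pair has 1 or 2 active rows");
  base[col_off] = acc[0];
  base[col_off + 1] = acc[1];
  if constexpr (m_rows == 2) {
    base[ldc + col_off] = acc[2];
    base[ldc + col_off + 1] = acc[3];
  }
}
```

Specializing on m_rows at compile time removes the row check from the inner loop. This is one plausible reason for the template parameter.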
Import
```cpp
#include "cpu_attn_neon_bfmmla.hpp"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| RP | int32_t (template) | Yes | Number of row-pairs processed per call: 1, 2, or 4 |
| A_packed_rp | const bfloat16_t* const* | Yes | Array of pointers to pre-packed A row-pairs in the BFMMLA interleaved layout, one per row-pair |
| m_rows_rp | const int32_t* | Yes | Array of active row counts (1 or 2), one per row-pair |
| B_blk | const bfloat16_t* | Yes | Pointer to the B matrix block (K or V cache); its layout depends on the attention phase |
| C | float* | Yes | Pointer to the output accumulator matrix; row-major float32 |
| ldc | int64_t | Yes | Leading dimension of C |
| accumulate | bool | Yes | If true, load C and accumulate into it; if false, zero-initialize the accumulators |
| b_stride | int64_t | Yes | Stride between B columns; depends on the cache block layout |
| K_runtime | int32_t | No (default 0) | Runtime K dimension, used when K_static is negative (PV phase) |
| K_static | int32_t (template) | Yes | Compile-time K dimension; set to a negative value to use K_runtime instead (PV phase only) |
| phase | AttentionGemmPhase (template) | Yes | QK or PV; selects the B matrix access pattern |
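The phase-dependent B access can be sketched with a reference GEMM that reads B through a per-phase accessor. The indexing below is an assumption for illustration and does not reproduce vLLM's actual cache block layout:

- In QK, each token's head_dim values form one contiguous B column.
- In PV, each token is one B row.
- In both phases, b_stride is taken as the distance between consecutive tokens.

Phase, load_b, and reference_gemm are hypothetical names.

```cpp
#include <cstdint>

enum class Phase { QK, PV };  // hypothetical stand-in for AttentionGemmPhase

// Read logical B(k, n) of a K x N GEMM under the assumed per-phase layout.
//   QK: column n is token n; its K = head_dim values are contiguous.
//   PV: row k is token k; its N = head_dim values are contiguous.
template <Phase phase>
float load_b(const float* B, int64_t b_stride, int32_t k, int32_t n) {
  if constexpr (phase == Phase::QK) {
    return B[n * b_stride + k];
  } else {
    return B[k * b_stride + n];
  }
}

// Reference C (M x N, row-major, leading dimension ldc) = A (M x K) * B.
template <Phase phase>
void reference_gemm(const float* A, int32_t M, int32_t K, const float* B,
                    int64_t b_stride, float* C, int32_t N, int64_t ldc) {
  for (int32_t m = 0; m < M; ++m) {
    for (int32_t n = 0; n < N; ++n) {
      float sum = 0.0f;
      for (int32_t k = 0; k < K; ++k) {
        sum += A[m * K + k] * load_b<phase>(B, b_stride, k, n);
      }
      C[m * ldc + n] = sum;
    }
  }
}
```

Under this assumption, a token-major cache can serve as K^T in the QK phase and as V in the PV phase with the same stride. Only the accessor changes.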
Outputs
| Name | Type | Description |
|---|---|---|
| C | float* | Updated output matrix holding the BFMMLA-computed GEMM results (attention scores or weighted values) |
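The accumulate flag lets the PV phase process a long sequence in tiles, as in the runtime-K call in the usage example. The scalar sketch below uses the hypothetical names gemm_row and gemm_row_tiled. It shows that zero-initializing on the first tile and accumulating on later ones gives the same result as a single pass over all tokens. It ignores any softmax rescaling a caller may apply between tiles.

```cpp
#include <cstdint>

// Scalar model of the accumulate flag for one output row:
// C = (accumulate ? C : 0) + A * B, with A of length K and B row-major K x N.
void gemm_row(const float* A, const float* B, int32_t K, int32_t N,
              float* C, bool accumulate) {
  for (int32_t n = 0; n < N; ++n) {
    float acc = accumulate ? C[n] : 0.0f;  // load existing C or start from zero
    for (int32_t k = 0; k < K; ++k) {
      acc += A[k] * B[k * N + n];
    }
    C[n] = acc;
  }
}

// Walk the K (token) dimension in tiles of at most tile_len tokens:
// the first tile zero-initializes C, later tiles accumulate into it.
void gemm_row_tiled(const float* A, const float* B, int32_t K, int32_t N,
                    float* C, int32_t tile_len) {
  for (int32_t k0 = 0; k0 < K; k0 += tile_len) {
    const int32_t len = (K - k0 < tile_len) ? (K - k0) : tile_len;
    gemm_row(A + k0, B + k0 * N, len, N, C, /*accumulate=*/k0 > 0);
  }
}
```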
Usage Examples
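```cpp
#include "cpu_attn_neon_bfmmla.hpp"
using namespace cpu_attention;

// Example head size (assumed); it must be a compile-time constant to be used as K_static.
constexpr int32_t HEAD_DIM = 128;
// Packed Q holds 2 rows of HEAD_DIM rounded up to a multiple of TILE_K (zero-padded tail).
constexpr int32_t HEAD_DIM_PADDED = (HEAD_DIM + TILE_K - 1) / TILE_K * TILE_K;

// Inside a function where q_row0, q_row1, k_cache_block, v_cache_block, packed_p,
// logits_output, output, ldc, b_stride and seq_tile_len (per-phase values) are in scope:
alignas(16) c10::BFloat16 packed_q[2 * HEAD_DIM_PADDED];
reshape_Q_2xK_for_bfmmla(q_row0, q_row1, packed_q, HEAD_DIM);

// QK phase: one row-pair (RP = 1) with both rows active; K = HEAD_DIM at compile time.
const bfloat16_t* q_ptrs[1] = {reinterpret_cast<const bfloat16_t*>(packed_q)};
int32_t m_rows[1] = {2};
gemm_rowpairs_x8_bfmmla_neon<1, HEAD_DIM, AttentionGemmPhase::QK>(
    q_ptrs, m_rows,
    reinterpret_cast<const bfloat16_t*>(k_cache_block),
    logits_output, ldc, /*accumulate=*/false, b_stride);

// PV phase: A is the packed softmax probabilities P (packed_p, assumed to use the same
// 2-row interleaved layout), and K = seq_tile_len is passed at runtime (K_static = -1).
const bfloat16_t* p_ptrs[1] = {reinterpret_cast<const bfloat16_t*>(packed_p)};
gemm_rowpairs_x8_bfmmla_neon<1, -1, AttentionGemmPhase::PV>(
    p_ptrs, m_rows,
    reinterpret_cast<const bfloat16_t*>(v_cache_block),
    output, ldc, /*accumulate=*/true, b_stride, seq_tile_len);
```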
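The micro-kernel handles 1, 2, or 4 row-pairs per call, so a caller has to split the query rows into row-pairs and group them. The planner below is a hypothetical sketch of one way to do that:

1. Form ceil(M / 2) row-pairs.
2. If M is odd, mark the last row-pair with m_rows = 1.
3. Cover the row-pairs greedily with RP = 4, then RP = 2, then RP = 1.

The names RowPairBatch and plan_row_pairs are not part of the header.

```cpp
#include <cstdint>
#include <initializer_list>
#include <vector>

struct RowPairBatch {
  int32_t first_pair;  // index of the first row-pair in this batch
  int32_t rp;          // 4, 2 or 1: the RP template argument to use
};

// Hypothetical planner: split M rows into ceil(M / 2) row-pairs (filling
// m_rows_rp with 2, or 1 for a trailing odd row) and cover them with as few
// kernel calls as possible, preferring RP = 4, then 2, then 1.
std::vector<RowPairBatch> plan_row_pairs(int32_t M, std::vector<int32_t>& m_rows_rp) {
  const int32_t pairs = (M + 1) / 2;
  m_rows_rp.assign(pairs, 2);
  if (M % 2 != 0) {
    m_rows_rp.back() = 1;  // last row-pair has only one active row
  }
  std::vector<RowPairBatch> batches;
  int32_t p = 0;
  for (int32_t rp : {4, 2, 1}) {
    while (pairs - p >= rp) {
      batches.push_back({p, rp});
      p += rp;
    }
  }
  return batches;
}
```

Each batch would then correspond to one gemm_rowpairs_x8_bfmmla_neon<rp, ...> call, with A_packed_rp and m_rows_rp offset by first_pair.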