Implementation:Vllm project Vllm CPU Attn NEON BFMMLA

From Leeroopedia


Knowledge Sources
Domains: Attention, CPU_Inference, SIMD, ARM, BFloat16
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements attention kernels built on the ARM BFMMLA (BFloat16 matrix multiply-accumulate) instruction to accelerate BF16 attention computation on ARMv8.6+ CPUs.

Description

This header provides tile-based GEMM micro-kernels built on the vbfmmlaq_f32 intrinsic. Each BFMMLA step works on a 2x4x2 tile: 2 rows, 4 K-reduction elements, and 2 output columns (one column-pair), accumulated into a 2x2 float32 tile. Four column-pairs make up the 8 output columns of a block. The reshape_Q_2xK_for_bfmmla function packs two rows of Q data into the interleaved BF16 layout that BFMMLA requires, zero-padding the tail when K is not a multiple of 4. The gemm_rowpairs_x8_bfmmla_neon template micro-kernel processes 1, 2, or 4 row-pairs per call with compile-time unrolling and supports both attention phases: QK (token-column B layout) and PV (token-row B layout). Accumulators are moved through 2x2 load/store helpers that are specialized at compile time on the row count.
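
As a rough mental model of the 2x4x2 tile, the sketch below emulates a single BFMMLA step in scalar C++. The names bfmmla_model and bfmmla_steps_per_block are hypothetical and not part of the header. Plain float stands in for bfloat16, so BF16 rounding is ignored, and the operand layout follows the architectural description of the instruction rather than this header's source.

```cpp
#include <array>
#include <cassert>

// Scalar model of one BFMMLA step (hypothetical helper, not from the header).
// a:   2x4 tile, row-major (4 values of row 0, then 4 values of row 1).
// b:   4x2 tile stored as its two columns back to back (column 0, column 1).
// acc: 2x2 tile, row-major {c00, c01, c10, c11}.
// Plain float stands in for bfloat16, so BF16 rounding is not modelled.
inline std::array<float, 4> bfmmla_model(std::array<float, 4> acc,
                                         const std::array<float, 8>& a,
                                         const std::array<float, 8>& b) {
  for (int i = 0; i < 2; ++i) {      // output row
    for (int j = 0; j < 2; ++j) {    // output column
      float sum = 0.0f;
      for (int k = 0; k < 4; ++k) {  // K reduction
        sum += a[i * 4 + k] * b[j * 4 + k];
      }
      acc[i * 2 + j] += sum;
    }
  }
  return acc;
}

// BFMMLA steps for one row-pair against one 8-column block:
// K is consumed 4 elements at a time, and 8 columns are 4 column-pairs.
constexpr int bfmmla_steps_per_block(int k) {
  return ((k + 3) / 4) * (8 / 2);
}
```

Under these assumptions, a row-pair with K = 128 needs 32 K-blocks times 4 column-pairs, that is 128 BFMMLA steps per 8-column block.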

Usage

This header is compiled only on ARM platforms with BF16 hardware support (ARMv8.6+ such as Neoverse V2/N2). cpu_attn_neon.hpp includes it conditionally when the ARM_BF16_SUPPORT macro is defined. It provides the TileGemmNEONBFMMLA and AttentionImplNEONBFMMLA classes for the BFMMLA-accelerated attention path.
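
The conditional inclusion can be pictured with a compile-time guard like the one below. This is only a sketch: __ARM_FEATURE_BF16_VECTOR_ARITHMETIC is the ACLE feature-test macro for BF16 vector instructions, but whether vLLM derives ARM_BF16_SUPPORT from that exact macro is an assumption, and SKETCH_HAS_BFMMLA, sketch_has_bfmmla, and the path labels are hypothetical names.

```cpp
#include <cassert>
#include <cstring>

// ACLE feature-test macro for the BF16 vector instructions. Whether vLLM
// derives ARM_BF16_SUPPORT from this exact macro is an assumption here.
#if defined(__aarch64__) && defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
  #define SKETCH_HAS_BFMMLA 1
#else
  #define SKETCH_HAS_BFMMLA 0
#endif

// True when the BFMMLA code path could be compiled for this target.
constexpr bool sketch_has_bfmmla() { return SKETCH_HAS_BFMMLA != 0; }

// Label of the path a dispatcher might pick (labels are illustrative only).
constexpr const char* sketch_attention_path() {
  return sketch_has_bfmmla() ? "neon-bfmmla" : "generic";
}
```

On a typical x86 build, or an AArch64 build without BF16 enabled in the target flags, the guard evaluates to 0 and the generic label is returned.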

Code Reference

Source Location

Signature

namespace cpu_attention {

// BFMMLA tile constants
constexpr int32_t TILE_ROWS = 2;          // M dimension
constexpr int32_t TILE_K = 4;             // K reduction
constexpr int32_t TILE_COLS = 2;          // N column-pair
constexpr int32_t OUTPUT_COLS_PER_BLOCK = 8;

// Pack Q rows into BFMMLA-friendly interleaved layout
FORCE_INLINE void reshape_Q_2xK_for_bfmmla(
    const c10::BFloat16* r0, const c10::BFloat16* r1,
    c10::BFloat16* dst, int32_t K);

// 2x2 accumulator load/store
template <int32_t m_rows>
FORCE_INLINE float32x4_t load_acc_2x2(float* base, int64_t ldc, int col_off);
template <int32_t m_rows>
FORCE_INLINE void store_acc_2x2(float32x4_t acc, float* base, int64_t ldc, int col_off);

// Micro-kernel: RP row-pairs x 8 columns using BFMMLA
template <int32_t RP, int32_t K_static, AttentionGemmPhase phase>
FORCE_INLINE void gemm_rowpairs_x8_bfmmla_neon(
    const bfloat16_t* const* A_packed_rp,
    const int32_t* m_rows_rp,
    const bfloat16_t* B_blk,
    float* C, int64_t ldc, bool accumulate,
    int64_t b_stride, int32_t K_runtime = 0);

class TileGemmNEONBFMMLA { ... };
class AttentionImplNEONBFMMLA { ... };

} // namespace cpu_attention
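
The interleaved layout targeted by reshape_Q_2xK_for_bfmmla can be pictured with the scalar sketch below. It assumes that each block of 4 K-elements is stored as 4 values of row 0 followed by 4 values of row 1 (one 2x4 BFMMLA operand) and that the tail is zero-filled. The function pack_two_rows_sketch is hypothetical, uses uint16_t as a stand-in for the raw BF16 bits, and may differ from the real implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar packer; uint16_t stands in for raw BF16 bit patterns.
// Assumed layout per block of 4 K-elements: 4 values of row 0, then 4 values
// of row 1. Positions past K stay zero (tail padding).
inline std::vector<uint16_t> pack_two_rows_sketch(const uint16_t* r0,
                                                  const uint16_t* r1,
                                                  int32_t K) {
  assert(K >= 0);
  const int32_t k_blocks = (K + 3) / 4;  // round K up to a multiple of 4
  std::vector<uint16_t> dst(static_cast<std::size_t>(k_blocks) * 8, 0);
  for (int32_t kb = 0; kb < k_blocks; ++kb) {
    for (int32_t i = 0; i < 4; ++i) {
      const int32_t k = kb * 4 + i;
      if (k < K) {
        dst[static_cast<std::size_t>(kb) * 8 + i] = r0[k];      // row 0 half
        dst[static_cast<std::size_t>(kb) * 8 + 4 + i] = r1[k];  // row 1 half
      }
    }
  }
  return dst;
}
```

With K = 6, the second block holds two real values per row followed by two zeros, so the packed buffer has 16 entries.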

Import

#include "cpu_attn_neon_bfmmla.hpp"

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| A_packed_rp | const bfloat16_t* const* | Yes | Array of pointers to pre-packed A matrix row-pairs in BFMMLA interleaved format |
| m_rows_rp | const int32_t* | Yes | Array of active row counts (1 or 2), one per row-pair |
| B_blk | const bfloat16_t* | Yes | Pointer to the B matrix block (K or V cache); layout depends on the attention phase |
| C | float* | Yes | Pointer to the output accumulator matrix; row-major float32 |
| ldc | int64_t | Yes | Leading dimension of the C matrix |
| accumulate | bool | Yes | If true, load C and accumulate into it; if false, start from zero-initialized accumulators |
| b_stride | int64_t | Yes | Stride between B columns; depends on the cache block layout |
| K_runtime | int32_t | No | Runtime K dimension, used when K_static is negative; defaults to 0 |
| RP | int32_t (template) | Yes | Number of row-pairs processed per call (1, 2, or 4) |
| K_static | int32_t (template) | Yes | Compile-time K dimension; set to a negative value for runtime K (PV phase only) |
| phase | AttentionGemmPhase (template) | Yes | QK or PV; selects the B matrix access pattern |

Outputs

| Name | Type | Description |
|---|---|---|
| C | float* | Updated output matrix with BFMMLA-computed GEMM results (attention scores or weighted values) |
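
To make the role of ldc, the per-row-pair row count, and the 2x2 accumulator tile more concrete, here is a scalar sketch of what load/store helpers of this shape might do. The functions load_acc_2x2_sketch and store_acc_2x2_sketch are hypothetical stand-ins that use std::array instead of float32x4_t. The lane order {c00, c01, c10, c11} and the zeroing of the missing row are assumptions, not confirmed details of the header.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using Acc2x2 = std::array<float, 4>;  // assumed lane order {c00, c01, c10, c11}

// Hypothetical scalar counterpart of a 2x2 accumulator load.
// With m_rows == 1 the second row stays zero, so nothing past the tail is read.
template <int32_t m_rows>
Acc2x2 load_acc_2x2_sketch(const float* base, int64_t ldc, int col_off) {
  static_assert(m_rows == 1 || m_rows == 2, "a row-pair holds 1 or 2 rows");
  Acc2x2 acc{0.0f, 0.0f, 0.0f, 0.0f};
  for (int32_t r = 0; r < m_rows; ++r) {
    acc[r * 2 + 0] = base[r * ldc + col_off + 0];
    acc[r * 2 + 1] = base[r * ldc + col_off + 1];
  }
  return acc;
}

// Hypothetical scalar counterpart of a 2x2 accumulator store.
// With m_rows == 1 only the first row of C is written.
template <int32_t m_rows>
void store_acc_2x2_sketch(const Acc2x2& acc, float* base, int64_t ldc,
                          int col_off) {
  static_assert(m_rows == 1 || m_rows == 2, "a row-pair holds 1 or 2 rows");
  for (int32_t r = 0; r < m_rows; ++r) {
    base[r * ldc + col_off + 0] = acc[r * 2 + 0];
    base[r * ldc + col_off + 1] = acc[r * 2 + 1];
  }
}
```

Specializing on the row count at compile time removes the tail branch from the inner loop, which is presumably why the real helpers take m_rows as a template parameter.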

Usage Examples

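#include "cpu_attn_neon_bfmmla.hpp"

using namespace cpu_attention;

// Assumed to be provided by the caller: q_row0, q_row1, k_cache_block,
// v_cache_block, logits_output, output, ldc, b_stride, seq_tile_len.
constexpr int32_t HEAD_DIM = 128;  // example value only; use the model's head size
// K rounded up to a multiple of TILE_K so the zero-padded tail fits
constexpr int32_t PACKED_K = (HEAD_DIM + TILE_K - 1) / TILE_K * TILE_K;

// Pack Q rows for BFMMLA format
c10::BFloat16 packed_q[2 * PACKED_K];
reshape_Q_2xK_for_bfmmla(q_row0, q_row1, packed_q, HEAD_DIM);

// Compute QK attention scores using BFMMLA (one row-pair, both rows active)
const bfloat16_t* a_ptrs[1] = { reinterpret_cast<const bfloat16_t*>(packed_q) };
int32_t m_rows[1] = { 2 };

gemm_rowpairs_x8_bfmmla_neon<1, HEAD_DIM, AttentionGemmPhase::QK>(
    a_ptrs, m_rows,
    reinterpret_cast<const bfloat16_t*>(k_cache_block),
    logits_output, ldc, /*accumulate=*/false, b_stride);

// Compute PV with runtime K dimension (K_static = -1, K passed as K_runtime).
// In a full pipeline the A operand of this phase is the packed probability
// rows rather than packed Q; a_ptrs is reused here only to show the call shape.
gemm_rowpairs_x8_bfmmla_neon<1, -1, AttentionGemmPhase::PV>(
    a_ptrs, m_rows,
    reinterpret_cast<const bfloat16_t*>(v_cache_block),
    output, ldc, /*accumulate=*/true, b_stride, seq_tile_len);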
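
When checking the output of calls like these, a naive scalar reference can be useful. The sketch below is not part of the header: reference_rows_x8 and max_abs_diff are hypothetical helpers, all operands are plain float, and B is assumed to be a dense row-major K x 8 block, which does not reflect the cache layouts the real kernel reads.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Naive reference: C[m_rows x 8] (+)= A[m_rows x K] * B[K x 8].
// A and B are dense row-major float here (an assumption for illustration);
// C uses leading dimension ldc, and `accumulate` mirrors the kernel flag.
inline void reference_rows_x8(const float* A, int32_t m_rows, int32_t K,
                              const float* B, float* C, int64_t ldc,
                              bool accumulate) {
  for (int32_t r = 0; r < m_rows; ++r) {
    for (int32_t n = 0; n < 8; ++n) {
      float sum = accumulate ? C[r * ldc + n] : 0.0f;
      for (int32_t k = 0; k < K; ++k) {
        sum += A[r * K + k] * B[k * 8 + n];
      }
      C[r * ldc + n] = sum;
    }
  }
}

// Largest absolute element-wise difference between two buffers of length n.
inline float max_abs_diff(const float* x, const float* y, int32_t n) {
  float worst = 0.0f;
  for (int32_t i = 0; i < n; ++i) {
    worst = std::max(worst, std::fabs(x[i] - y[i]));
  }
  return worst;
}
```

Because the real kernel consumes BF16 inputs, a comparison against this float reference should use a tolerance rather than exact equality.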
