
Implementation:Sgl project Sglang CPU Flash Attention

From Leeroopedia


Knowledge Sources
Domains Attention, CPU Compute
Last Updated 2026-02-10 00:00 GMT

Overview

Implements a CPU-optimized FlashAttention kernel for standard (non-paged) multi-head attention, supporting both causal and non-causal modes with grouped query attention (GQA).

Description

flash_attn.cpp provides a standalone FlashAttention implementation for CPU that operates on contiguous Q, K, V tensors without requiring a paged KV cache system. This distinguishes it from decode.cpp and extend.cpp, which are designed for paged attention.

The flash_attn_kernel_impl template function is parameterized by BLOCK_M and BLOCK_N tile sizes and implements the online-softmax FlashAttention algorithm. The kernel parallelizes over a [batches, num_heads, MB] grid, where MB is the number of BLOCK_M tiles along seqlen_q, using a custom parallel_for function with balance211 work partitioning.
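The online-softmax accumulation at the heart of the algorithm can be sketched as follows. This is a simplified scalar version (the real kernel works on BLOCK_M × BLOCK_N tiles with vectorized GEMMs), and the identifier names here are illustrative, not the kernel's own:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Streaming softmax-weighted accumulation over key/value chunks:
// keep a running max (m), running normalizer (l), and a partial output
// (acc) that is rescaled whenever the running max increases, so no full
// score row is ever materialized.
std::vector<float> online_softmax_attend(
    const std::vector<float>& scores,          // q·kᵀ * sm_scale, length N
    const std::vector<std::vector<float>>& v,  // N value rows of size D
    int block_n) {
  const int n = static_cast<int>(scores.size());
  const int d = static_cast<int>(v[0].size());
  float m = -INFINITY, l = 0.0f;
  std::vector<float> acc(d, 0.0f);
  for (int start = 0; start < n; start += block_n) {
    const int end = std::min(start + block_n, n);
    float m_new = m;
    for (int j = start; j < end; ++j) m_new = std::max(m_new, scores[j]);
    const float rescale = std::exp(m - m_new);  // shrink old partials
    l *= rescale;
    for (float& a : acc) a *= rescale;
    for (int j = start; j < end; ++j) {
      const float p = std::exp(scores[j] - m_new);
      l += p;
      for (int t = 0; t < d; ++t) acc[t] += p * v[j][t];
    }
    m = m_new;
  }
  for (float& a : acc) a /= l;  // final normalization by the softmax sum
  return acc;
}
```

Because each chunk only updates (m, l, acc), the result is identical to a full softmax over all N keys regardless of the chunk size chosen.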

Each thread allocates scratch buffers from a pre-allocated buffer pool:

  • s_i (float): Attention score tile of shape [BLOCK_M, BLOCK_N]
  • s_delta (scalar_t): Converted scores for GEMM input, aliased with s_i
  • v_prime (float): Accumulated output values of shape [BLOCK_M, head_size_v]
  • Btmp (scalar_t): Packed key/value tiles of shape [BLOCK_N, max(head_size, head_size_v)]

The kernel supports different head sizes for keys (head_size) and values (head_size_v) to handle architectures like MLA (Multi-head Latent Attention). Grouped query attention is supported by computing the num_groups = num_heads / num_heads_kv ratio.
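The grouped-query mapping implied by that ratio can be sketched in one line: consecutive query heads share a single key/value head. This is the standard GQA scheme; the function name here is ours, not the kernel's:

```cpp
// Map a query head index to its key/value head under grouped query
// attention: with num_groups = num_heads / num_heads_kv, query heads
// [g*num_groups, (g+1)*num_groups) all read KV head g. Assumes num_heads
// is an exact multiple of num_heads_kv, as GQA requires.
int kv_head_for(int q_head, int num_heads, int num_heads_kv) {
  const int num_groups = num_heads / num_heads_kv;
  return q_head / num_groups;
}
```

When num_heads == num_heads_kv this degenerates to ordinary multi-head attention (every query head owns its own KV head).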

Licensed under BSD-3-Clause with copyright from both Codeplay Software and Intel Corporation.

Usage

Use this kernel for standard attention computation where Q, K, V are contiguous tensors (not in a paged cache). This is appropriate for training attention layers, benchmarking, or serving scenarios where paged attention is not needed.
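When causal mode is enabled with seqlen_q != seqlen_k, each query row only attends to a prefix of the keys. The sketch below assumes bottom-right alignment (the common FlashAttention convention, where the last query row sees all keys); verify the kernel's actual convention against flash_attn.cpp before relying on this:

```cpp
// Exclusive upper bound on the key index visible to query row q_row
// under bottom-right-aligned causal masking. Assumption: the final query
// row (q_row == seqlen_q - 1) attends to all seqlen_k keys.
int causal_key_bound(int q_row, int seqlen_q, int seqlen_k) {
  int bound = seqlen_k - seqlen_q + q_row + 1;
  if (bound < 0) bound = 0;
  if (bound > seqlen_k) bound = seqlen_k;
  return bound;
}
```

In the tiled kernel, this bound also lets whole BLOCK_N key tiles beyond it be skipped entirely rather than masked element by element.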

Code Reference

Source Location

Signature

template <typename scalar_t, int BLOCK_M, int BLOCK_N>
void flash_attn_kernel_impl(
    scalar_t* __restrict__ out,
    const scalar_t* __restrict__ q,
    const scalar_t* __restrict__ k,
    const scalar_t* __restrict__ v,
    void* __restrict__ buffer,
    int seqlen_q,
    int seqlen_k,
    int batches,
    int num_heads,
    int num_heads_kv,
    int head_size,
    int head_size_v,
    int q_strideM, int q_strideH,
    int k_strideN, int k_strideH,
    int v_strideN, int v_strideH,
    float sm_scale,
    int buffer_size_per_thread,
    bool causal);

Import

#include "flash_attn.h"
#include "common.h"
#include "gemm.h"

I/O Contract

Inputs

Name Type Required Description
q scalar_t* Yes Query tensor, shape [batches * seqlen_q, num_heads, head_size]
k scalar_t* Yes Key tensor, shape [batches * seqlen_k, num_heads_kv, head_size]
v scalar_t* Yes Value tensor, shape [batches * seqlen_k, num_heads_kv, head_size_v]
buffer void* Yes Pre-allocated scratch buffer for per-thread working memory
seqlen_q int Yes Query sequence length
seqlen_k int Yes Key/value sequence length
batches int Yes Number of batch items
num_heads int Yes Number of query attention heads
num_heads_kv int Yes Number of key/value attention heads (for GQA)
head_size int Yes Key head dimension
head_size_v int Yes Value head dimension (may differ from head_size)
q_strideM, q_strideH int Yes Query strides along the sequence (M) and head dimensions
k_strideN, k_strideH int Yes Key strides along the sequence (N) and head dimensions
v_strideN, v_strideH int Yes Value strides along the sequence (N) and head dimensions
sm_scale float Yes Softmax scale factor (typically 1/sqrt(head_size))
buffer_size_per_thread int Yes Bytes of scratch space available to each thread within buffer
causal bool Yes Whether to apply causal masking

Outputs

Name Type Description
out scalar_t* Attention output, shape [batches * seqlen_q, num_heads, head_size_v]

Usage Examples

Standard Flash Attention Call

flash_attn_kernel_impl<at::BFloat16, /*BLOCK_M=*/64, /*BLOCK_N=*/256>(
    output_ptr, query_ptr, key_ptr, value_ptr,
    scratch_buffer,
    seqlen_q, seqlen_k,
    batches, num_heads, num_heads_kv,
    head_size, head_size_v,
    q_strideM, q_strideH,
    k_strideN, k_strideH,
    v_strideN, v_strideH,
    /*sm_scale=*/1.0f / sqrtf(head_size),
    buffer_size_per_thread,
    /*causal=*/true);
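The call above assumes scratch_buffer and buffer_size_per_thread were sized beforehand. A hedged estimate can be derived from the scratch buffers listed in the Description (s_i/s_delta aliased, v_prime, Btmp); treat this as a sketch to validate against the actual allocation logic in flash_attn.cpp, not its exact formula:

```cpp
#include <algorithm>
#include <cstddef>

// Estimated per-thread scratch bytes: one float score tile (which s_delta
// aliases, so it adds nothing), one float output accumulator, and one
// packed K/V tile in the low-precision scalar type. Alignment padding the
// real implementation may add is ignored here.
size_t estimate_scratch_per_thread(int block_m, int block_n, int head_size,
                                   int head_size_v, size_t scalar_bytes) {
  const size_t s_i = sizeof(float) * block_m * block_n;
  const size_t v_prime = sizeof(float) * block_m * head_size_v;
  const size_t btmp = scalar_bytes * block_n *
                      static_cast<size_t>(std::max(head_size, head_size_v));
  return s_i + v_prime + btmp;
}
```

The total pool would then be this estimate times the number of worker threads, with each thread indexing its own slice by thread id.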
