Implementation: Sgl project Sglang CPU Flash Attention
| Knowledge Sources | |
|---|---|
| Domains | Attention, CPU Compute |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Implements a CPU-optimized FlashAttention kernel for standard (non-paged) multi-head attention, supporting both causal and non-causal modes with grouped query attention (GQA).
Description
flash_attn.cpp provides a standalone FlashAttention implementation for CPU that operates on contiguous Q, K, V tensors without requiring a paged KV cache system. This distinguishes it from decode.cpp and extend.cpp, which are designed for paged attention.
The flash_attn_kernel_impl template function is parameterized by BLOCK_M and BLOCK_N tile sizes and implements the online-softmax FlashAttention algorithm. The kernel parallelizes across [batches, num_heads, MB], where MB is the number of BLOCK_M-sized query tiles (ceil(seqlen_q / BLOCK_M)), using a custom parallel_for function with balance211 work partitioning.
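The online-softmax update at the heart of the kernel can be sketched as follows. This is a minimal single-row scalar version with illustrative names (the real kernel vectorizes this over BLOCK_M x BLOCK_N tiles and routes the P*V product through the GEMM helpers):

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Single-row online softmax attention, processing keys in tiles of
// block_n. Names and signature are illustrative, not from flash_attn.cpp.
std::vector<float> online_softmax_attention(
    const std::vector<float>& scores,          // q . k^T * sm_scale, length seqlen_k
    const std::vector<std::vector<float>>& v,  // [seqlen_k][head_size_v]
    int block_n) {
  const int seqlen_k = static_cast<int>(scores.size());
  const int head_size_v = static_cast<int>(v[0].size());
  float m_i = -std::numeric_limits<float>::infinity();  // running row max
  float l_i = 0.f;                                      // running exp-sum
  std::vector<float> v_prime(head_size_v, 0.f);         // unnormalized output

  for (int n0 = 0; n0 < seqlen_k; n0 += block_n) {
    const int n1 = std::min(n0 + block_n, seqlen_k);
    // 1. update the running max with this tile's scores
    float m_new = m_i;
    for (int n = n0; n < n1; ++n) m_new = std::max(m_new, scores[n]);
    // 2. rescale previous accumulators by exp(m_i - m_new)
    const float alpha = std::exp(m_i - m_new);
    l_i *= alpha;
    for (float& o : v_prime) o *= alpha;
    // 3. accumulate exp(s - m_new) and exp(s - m_new) * V for the tile
    for (int n = n0; n < n1; ++n) {
      const float p = std::exp(scores[n] - m_new);
      l_i += p;
      for (int d = 0; d < head_size_v; ++d) v_prime[d] += p * v[n][d];
    }
    m_i = m_new;
  }
  // final normalization: out = v_prime / l_i
  for (float& o : v_prime) o /= l_i;
  return v_prime;
}
```

Because each tile's contribution is rescaled when a larger max appears, the result matches a plain softmax over the full row while only ever holding one BLOCK_N tile of scores in memory.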
Each thread allocates scratch buffers from a pre-allocated buffer pool:
- s_i (float): Attention score tile of shape [BLOCK_M, BLOCK_N]
- s_delta (scalar_t): Converted scores for GEMM input, aliased with s_i
- v_prime (float): Accumulated output values of shape [BLOCK_M, head_size_v]
- Btmp (scalar_t): Packed key/value tiles of shape [BLOCK_N, max(head_size, head_size_v)]
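A plausible way to carve those four buffers out of the shared pool is sketched below. The struct, offsets, and alignment are an assumption drawn from the buffer list above; the actual layout in flash_attn.cpp may differ:

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical per-thread scratch layout for the buffers listed above;
// offsets and padding in the real kernel may differ.
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
struct ThreadScratch {
  float* s_i;         // [BLOCK_M, BLOCK_N] attention score tile
  scalar_t* s_delta;  // aliases s_i's storage for the down-converted GEMM input
  float* v_prime;     // [BLOCK_M, head_size_v] output accumulator
  scalar_t* Btmp;     // [BLOCK_N, max(head_size, head_size_v)] packed K/V tile
};

template <typename scalar_t, int BLOCK_M, int BLOCK_N>
ThreadScratch<scalar_t, BLOCK_M, BLOCK_N> carve_scratch(
    void* pool, int thread_id, size_t bytes_per_thread, int head_size_v) {
  auto* base = static_cast<uint8_t*>(pool) + thread_id * bytes_per_thread;
  ThreadScratch<scalar_t, BLOCK_M, BLOCK_N> s;
  s.s_i = reinterpret_cast<float*>(base);
  // s_delta reuses s_i's storage: scalar_t is never wider than float
  s.s_delta = reinterpret_cast<scalar_t*>(base);
  base += sizeof(float) * BLOCK_M * BLOCK_N;
  s.v_prime = reinterpret_cast<float*>(base);
  base += sizeof(float) * BLOCK_M * head_size_v;
  s.Btmp = reinterpret_cast<scalar_t*>(base);
  return s;
}
```

Aliasing s_delta onto s_i works because the scores are consumed in float and rewritten in scalar_t in place before the P*V GEMM, so the two never need to coexist.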
The kernel supports different head sizes for keys (head_size) and values (head_size_v) to handle architectures such as MLA (Multi-head Latent Attention). Grouped query attention is supported by computing num_groups = num_heads / num_heads_kv and mapping each group of consecutive query heads onto a shared KV head.
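The GQA head mapping reduces to a single integer division; a sketch with illustrative names:

```cpp
// For GQA, query heads are partitioned into num_groups = num_heads / num_heads_kv
// consecutive heads that all read the same KV head. Illustrative helper,
// assuming num_heads divides evenly by num_heads_kv.
inline int kv_head_for(int query_head, int num_heads, int num_heads_kv) {
  const int num_groups = num_heads / num_heads_kv;
  return query_head / num_groups;
}
```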
Licensed under BSD-3-Clause with copyright from both Codeplay Software and Intel Corporation.
Usage
Use this kernel for standard attention computation where Q, K, V are contiguous tensors (not in a paged cache). This is appropriate for training attention layers, benchmarking, or serving scenarios where paged attention is not needed.
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/cpu/flash_attn.cpp
- Lines: 1-544
Signature
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
void flash_attn_kernel_impl(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ q,
const scalar_t* __restrict__ k,
const scalar_t* __restrict__ v,
void* __restrict__ buffer,
int seqlen_q,
int seqlen_k,
int batches,
int num_heads,
int num_heads_kv,
int head_size,
int head_size_v,
int q_strideM, int q_strideH,
int k_strideN, int k_strideH,
int v_strideN, int v_strideH,
float sm_scale,
int buffer_size_per_thread,
bool causal);
Import
#include "flash_attn.h"
#include "common.h"
#include "gemm.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| q | scalar_t* | Yes | Query tensor, shape [batches * seqlen_q, num_heads, head_size] |
| k | scalar_t* | Yes | Key tensor, shape [batches * seqlen_k, num_heads_kv, head_size] |
| v | scalar_t* | Yes | Value tensor, shape [batches * seqlen_k, num_heads_kv, head_size_v] |
| buffer | void* | Yes | Pre-allocated scratch buffer for per-thread working memory |
| seqlen_q | int | Yes | Query sequence length |
| seqlen_k | int | Yes | Key/value sequence length |
| batches | int | Yes | Number of batch items |
| num_heads | int | Yes | Number of query attention heads |
| num_heads_kv | int | Yes | Number of key/value attention heads (for GQA) |
| head_size | int | Yes | Key head dimension |
| head_size_v | int | Yes | Value head dimension (may differ from head_size) |
| q_strideM / q_strideH | int | Yes | Query strides over sequence position and head index |
| k_strideN / k_strideH | int | Yes | Key strides over sequence position and head index |
| v_strideN / v_strideH | int | Yes | Value strides over sequence position and head index |
| sm_scale | float | Yes | Softmax scale factor (typically 1/sqrt(head_size)) |
| buffer_size_per_thread | int | Yes | Bytes of scratch space reserved per thread within buffer |
| causal | bool | Yes | Whether to apply causal masking |
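When causal is set and seqlen_q differs from seqlen_k, FlashAttention-style kernels conventionally align the causal mask to the bottom-right of the score matrix, so a short query sees the full key prefix. A sketch of that bound, noting that this convention is an assumption about flash_attn.cpp rather than something verified from the source:

```cpp
// Last key index (inclusive) visible to query row m under bottom-right
// aligned causal masking. Assumed convention, common in FlashAttention
// implementations but not confirmed for this kernel.
inline int last_visible_key(int m, int seqlen_q, int seqlen_k) {
  return m + (seqlen_k - seqlen_q);
}
```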
Outputs
| Name | Type | Description |
|---|---|---|
| out | scalar_t* | Attention output, shape [batches * seqlen_q, num_heads, head_size_v] |
Usage Examples
Standard Flash Attention Call
flash_attn_kernel_impl<at::BFloat16, /*BLOCK_M=*/64, /*BLOCK_N=*/256>(
output_ptr, query_ptr, key_ptr, value_ptr,
scratch_buffer,
seqlen_q, seqlen_k,
batches, num_heads, num_heads_kv,
head_size, head_size_v,
q_strideM, q_strideH,
k_strideN, k_strideH,
v_strideN, v_strideH,
/*sm_scale=*/1.0f / sqrtf(head_size),
buffer_size_per_thread,
/*causal=*/true);
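The buffer_size_per_thread argument must cover all four scratch buffers listed in the Description. A conservative way to size it, mirroring that buffer list (the real allocation in flash_attn.cpp may add alignment padding, so treat this as a lower bound):

```cpp
#include <algorithm>
#include <cstddef>

// Minimum per-thread scratch size for the buffers described above.
// s_delta aliases s_i, so it contributes no extra bytes.
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
size_t scratch_bytes_per_thread(int head_size, int head_size_v) {
  const size_t s_i = sizeof(float) * BLOCK_M * BLOCK_N;  // scores (+ aliased s_delta)
  const size_t v_prime = sizeof(float) * BLOCK_M * head_size_v;
  const size_t btmp = sizeof(scalar_t) * BLOCK_N *
                      static_cast<size_t>(std::max(head_size, head_size_v));
  return s_i + v_prime + btmp;
}
```

The total pool passed as buffer would then be this value times the thread count used by parallel_for.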