Implementation: Sgl project Sglang CPU Extend Attention
| Knowledge Sources | |
|---|---|
| Domains | Attention, CPU Compute |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Implements the CPU extend (prefill) attention kernel, which computes attention for multiple new tokens being added to an existing prefix cache during the prefill phase.
Description
extend.cpp provides the CPU-optimized extend attention kernel for the prefill phase of LLM inference. The extend_attention_kernel_impl template function computes attention over both the prefix KV cache and newly extended tokens.
Key design features:
- BLOCK_M and BLOCK_N: Compile-time tile sizes along the query (M) and key/value (N) dimensions, tuned for different sequence lengths
- Non-contiguous support: Handles non-contiguous k_extend and v_extend tensors
- Separate prefix/extend: Computes attention for prefix cache and new extend tokens independently
- Prefix skip optimization: When is_prefix_skipped is true, bypasses prefix attention entirely
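Because attention is computed first over the prefix cache and then over the extend tokens, the kernel must merge partial softmax results across tiles. The sketch below is a hypothetical illustration of that streaming (online) softmax rescale, not code taken from extend.cpp; `RowState` and `online_softmax_update` are illustrative names.

```cpp
#include <algorithm>
#include <cmath>

// Hypothetical sketch: streaming softmax state for one query row when
// scores arrive tile by tile (prefix tiles first, then extend tiles).
struct RowState {
  float m = -INFINITY;  // running max of all scores seen so far
  float d = 0.0f;       // running softmax denominator
};

// Merge one tile of raw scores into the running state. Returns the factor
// by which previously accumulated output values (v_prime) must be rescaled
// so earlier tiles stay consistent with the new running max.
inline float online_softmax_update(RowState& st, const float* scores, int n,
                                   float* probs) {
  float tile_max = st.m;
  for (int i = 0; i < n; ++i) tile_max = std::max(tile_max, scores[i]);
  float rescale = std::exp(st.m - tile_max);  // shrink old contributions
  float tile_sum = 0.0f;
  for (int i = 0; i < n; ++i) {
    probs[i] = std::exp(scores[i] - tile_max);
    tile_sum += probs[i];
  }
  st.d = st.d * rescale + tile_sum;
  st.m = tile_max;
  return rescale;
}
```

After all tiles are processed, dividing the accumulated values by `st.d` yields the same result as a single full-row softmax, which is what lets the prefix and extend phases be computed independently.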
The kernel parallelizes across [batches, num_heads, BM] using at::parallel_for, where BM is the number of blocks along the M (query) dimension. Each thread receives dedicated scratch buffers:
- s_i (float): Attention score matrix of shape [BLOCK_M, BLOCK_N]
- s_delta (scalar_t): Scores converted to the compute dtype for the GEMM input; aliases the s_i storage
- v_prime (float): Accumulated output values of shape [BLOCK_M, head_size_v]
- Btmp (scalar_t): Packed key/value tile buffer of shape [BLOCK_N, max(head_size, head_size_v)]
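The per-thread buffer sizes can be derived directly from the shapes listed above. The helper below is a hypothetical sketch (the name `scratch_bytes_per_thread` is not from the source) of how `buffer_size_per_thread` could be computed when carving the shared scratch allocation:

```cpp
#include <algorithm>
#include <cstddef>

// Hypothetical sketch: total scratch bytes per thread, mirroring the
// buffer shapes documented above (not the actual sizing code in extend.cpp).
template <int BLOCK_M, int BLOCK_N>
size_t scratch_bytes_per_thread(int head_size, int head_size_v,
                                size_t scalar_size) {
  size_t s_i     = size_t(BLOCK_M) * BLOCK_N * sizeof(float);       // scores
  size_t v_prime = size_t(BLOCK_M) * head_size_v * sizeof(float);   // output acc
  size_t btmp    = size_t(BLOCK_N) * std::max(head_size, head_size_v)
                   * scalar_size;                                   // packed K/V tile
  return s_i + v_prime + btmp;  // s_delta aliases s_i, so no extra space
}
```

Since `s_delta` reuses the `s_i` storage, only three distinct regions are needed per thread.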
The kernel maps requests to token positions via req_to_token and req_pool_indices arrays, supporting the paged memory layout used by the serving system.
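The two-level lookup described above can be sketched as follows; this is an illustrative helper (the name `token_slot` is hypothetical), assuming `req_to_token` is a row-major [max_num_reqs, max_context_len] table:

```cpp
#include <cstdint>

// Hypothetical sketch of the paged lookup: map a (batch, position) pair
// to a physical token slot in k_buffer/v_buffer.
template <typename index_t>
inline index_t token_slot(const index_t* req_to_token,
                          const int64_t* req_pool_indices,
                          int64_t max_context_len, int batch, int pos) {
  int64_t req = req_pool_indices[batch];             // request pool ID
  return req_to_token[req * max_context_len + pos];  // physical cache slot
}
```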
Usage
This kernel is invoked during the prefill phase when processing multi-token inputs (prompts). It handles the case where new tokens must attend to both previously cached prefix tokens and the newly computed KV entries.
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/cpu/extend.cpp
- Lines: 1-432
Signature
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
void extend_attention_kernel_impl(
scalar_t* __restrict__ o_extend,
const scalar_t* __restrict__ q_extend,
const scalar_t* __restrict__ k_extend,
const scalar_t* __restrict__ v_extend,
const scalar_t* __restrict__ k_buffer,
const scalar_t* __restrict__ v_buffer,
const index_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices,
const int64_t* __restrict__ seq_lens,
const index_t* __restrict__ extend_seq_lens,
const index_t* __restrict__ extend_start_loc,
const void* __restrict__ buffer,
int batches, int num_heads, int num_heads_kv,
int head_size, int head_size_v,
int q_strideM, int q_strideH,
int ke_strideN, int ke_strideH,
int ve_strideN, int ve_strideH,
int k_strideN, int k_strideH,
int v_strideN, int v_strideH,
float sm_scale,
int max_num_reqs, int max_context_len,
int max_total_num_tokens, int max_len_extend,
int buffer_size_per_thread,
bool is_prefix_skipped);
Import
#include "common.h"
#include "flash_attn.h"
#include "gemm.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| q_extend | scalar_t* | Yes | Query tensor for new extend tokens |
| k_extend | scalar_t* | Yes | Key tensor for new extend tokens (may be non-contiguous) |
| v_extend | scalar_t* | Yes | Value tensor for new extend tokens (may be non-contiguous) |
| k_buffer | scalar_t* | Yes | Paged key cache buffer for prefix tokens |
| v_buffer | scalar_t* | Yes | Paged value cache buffer for prefix tokens |
| req_to_token | index_t* | Yes | Maps (request, position) to physical token index in cache |
| req_pool_indices | int64_t* | Yes | Maps batch index to request pool ID |
| seq_lens | int64_t* | Yes | Total sequence lengths (prefix + extend) per batch item |
| extend_seq_lens | index_t* | Yes | Number of new extend tokens per batch item |
| extend_start_loc | index_t* | Yes | Start location of extend tokens in the packed Q/K/V tensors |
| sm_scale | float | Yes | Softmax scale factor (typically 1/sqrt(head_size)) |
| is_prefix_skipped | bool | Yes | If true, skips prefix attention (prefix length must be 0) |
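As noted in the table, `sm_scale` conventionally equals 1/sqrt(head_size). A minimal helper matching that convention (illustrative, not code from extend.cpp):

```cpp
#include <cmath>

// Conventional softmax scale for scaled dot-product attention;
// illustrative helper, not taken from the source file.
inline float softmax_scale(int head_size) {
  return 1.0f / std::sqrt(static_cast<float>(head_size));
}
```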
Outputs
| Name | Type | Description |
|---|---|---|
| o_extend | scalar_t* | Output tensor with attention results, shape [total_extend_tokens, num_heads, head_size_v] |
Usage Examples
Extend Attention Invocation
extend_attention_kernel_impl<at::BFloat16, int32_t, /*BLOCK_M=*/64, /*BLOCK_N=*/256>(
output_ptr, query_ptr, key_extend_ptr, value_extend_ptr,
key_buffer_ptr, value_buffer_ptr,
req_to_token_ptr, req_pool_indices_ptr,
seq_lens_ptr, extend_seq_lens_ptr, extend_start_loc_ptr,
scratch_buffer,
batches, num_heads, num_heads_kv,
head_size, head_size_v,
q_strideM, q_strideH,
ke_strideN, ke_strideH,
ve_strideN, ve_strideH,
k_strideN, k_strideH,
v_strideN, v_strideH,
sm_scale,
max_num_reqs, max_context_len,
max_total_num_tokens, max_len_extend,
buffer_size_per_thread,
/*is_prefix_skipped=*/false);
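Since BLOCK_M and BLOCK_N are template parameters, the caller must pick them at compile time. The sketch below is a hypothetical caller-side heuristic (the name `pick_tiles` and the threshold are assumptions, not from the source) for selecting a tile pair before instantiating the template; the larger pair matches the invocation example above.

```cpp
#include <utility>

// Hypothetical dispatch heuristic: choose (BLOCK_M, BLOCK_N) from the
// runtime extend length. Thresholds here are illustrative only.
inline std::pair<int, int> pick_tiles(int max_len_extend) {
  if (max_len_extend <= 64) return {32, 128};  // small prompts: smaller tiles
  return {64, 256};                            // matches the example above
}
```

A real caller would then branch on the returned pair to select the matching `extend_attention_kernel_impl<...>` instantiation.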