
Implementation:Sgl project Sglang CPU Extend Attention

From Leeroopedia


Knowledge Sources
Domains Attention, CPU Compute
Last Updated 2026-02-10 00:00 GMT

Overview

Implements the CPU extend (prefill) attention kernel, which computes attention for multiple new tokens being added to an existing prefix cache during the prefill phase.

Description

extend.cpp provides the CPU-optimized extend attention kernel for the prefill phase of LLM inference. The extend_attention_kernel_impl template function computes attention over both the prefix KV cache and newly extended tokens.

Key design features:

  • BLOCK_M and BLOCK_N: Compile-time tile sizes along the query (M) and key/value (N) dimensions, tuned for various sequence lengths
  • Non-contiguous support: Handles non-contiguous k_extend and v_extend tensors
  • Separate prefix/extend: Computes attention for prefix cache and new extend tokens independently
  • Prefix skip optimization: When is_prefix_skipped is true, bypasses prefix attention entirely
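
The two-phase structure above (prefix attention, then causally masked extend attention, sharing one softmax accumulator) can be illustrated with a minimal scalar sketch. This is a hypothetical single-query, single-value-dimension illustration assuming precomputed dot-product scores, not the tiled SIMD kernel itself; `extend_attention_1d` and its parameters are invented for this example.

```cpp
#include <vector>
#include <cmath>
#include <cassert>

// Hypothetical scalar sketch (not the real SIMD kernel): one query attends
// first to the prefix KV cache, then to the new extend tokens under a causal
// mask, using a streaming (online) softmax so both phases share one
// accumulator. Scores are assumed to be pre-scaled dot products.
float extend_attention_1d(const std::vector<float>& prefix_scores,
                          const std::vector<float>& prefix_values,
                          const std::vector<float>& extend_scores,
                          const std::vector<float>& extend_values,
                          int query_pos,          // position within the extend block
                          bool is_prefix_skipped) {
    float m = -INFINITY;  // running max of scores
    float l = 0.0f;       // running sum of exp(score - m)
    float acc = 0.0f;     // running weighted sum of values

    auto accumulate = [&](float score, float value) {
        float m_new = std::max(m, score);
        float correction = std::exp(m - m_new);  // rescale old accumulator
        l = l * correction + std::exp(score - m_new);
        acc = acc * correction + std::exp(score - m_new) * value;
        m = m_new;
    };

    if (!is_prefix_skipped) {
        // Prefix tokens are fully visible to every extend query.
        for (size_t i = 0; i < prefix_scores.size(); ++i)
            accumulate(prefix_scores[i], prefix_values[i]);
    }
    // Causal mask: an extend query at position p sees extend keys 0..p only.
    for (int j = 0; j <= query_pos && j < (int)extend_scores.size(); ++j)
        accumulate(extend_scores[j], extend_values[j]);

    return acc / l;
}
```

With `is_prefix_skipped` set, the prefix loop is bypassed entirely, matching the prefix-skip optimization described above.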

The kernel parallelizes across [batches, num_heads, BM] using at::parallel_for, where BM is the number of blocks along the M (query) dimension. Each thread receives dedicated scratch buffers:

  • s_i (float): Attention score matrix of shape [BLOCK_M, BLOCK_N]
  • s_delta (scalar_t): Converted scores for GEMM input, aliased with s_i
  • v_prime (float): Accumulated output values of shape [BLOCK_M, head_size_v]
  • Btmp (scalar_t): Packed key/value tile buffer of shape [BLOCK_N, max(head_size, head_size_v)]
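
One plausible way the opaque `buffer` argument could be carved into these per-thread regions is sketched below. The offsets and the absence of alignment padding are assumptions for illustration; `ThreadScratch` and `partition_scratch` are hypothetical names, not the kernel's actual layout code.

```cpp
#include <cstdint>
#include <cstddef>

// Hypothetical layout sketch for the per-thread scratch regions listed above.
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
struct ThreadScratch {
    float*    s_i;      // [BLOCK_M, BLOCK_N] attention scores
    scalar_t* s_delta;  // aliases s_i: scores converted to scalar_t for GEMM
    float*    v_prime;  // [BLOCK_M, head_size_v] output accumulator
    scalar_t* Btmp;     // [BLOCK_N, max(head_size, head_size_v)] packed K/V tile
};

template <typename scalar_t, int BLOCK_M, int BLOCK_N>
ThreadScratch<scalar_t, BLOCK_M, BLOCK_N> partition_scratch(
    void* buffer, int thread_id, int head_size, int head_size_v,
    size_t buffer_size_per_thread) {
    // Each thread gets a disjoint slice of the shared scratch allocation.
    char* base = static_cast<char*>(buffer) + thread_id * buffer_size_per_thread;
    ThreadScratch<scalar_t, BLOCK_M, BLOCK_N> s;
    s.s_i     = reinterpret_cast<float*>(base);
    s.s_delta = reinterpret_cast<scalar_t*>(s.s_i);  // alias: same storage as s_i
    base += size_t(BLOCK_M) * BLOCK_N * sizeof(float);
    s.v_prime = reinterpret_cast<float*>(base);
    base += size_t(BLOCK_M) * head_size_v * sizeof(float);
    s.Btmp    = reinterpret_cast<scalar_t*>(base);
    return s;
}
```

Aliasing `s_delta` onto `s_i` is safe here because the scores are consumed (converted to `scalar_t`) before the float storage is reused.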

The kernel maps requests to token positions via req_to_token and req_pool_indices arrays, supporting the paged memory layout used by the serving system.
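
The lookup implied by this mapping can be sketched as follows, assuming `req_to_token` is a row-major table of `max_context_len` slots per request; `physical_token_index` is a hypothetical helper name.

```cpp
#include <cstdint>
#include <cassert>

// Hypothetical helper: resolve a (batch, position) pair to a physical slot in
// k_buffer/v_buffer via req_pool_indices and a row-major req_to_token table.
template <typename index_t>
inline index_t physical_token_index(const index_t* req_to_token,
                                    const int64_t* req_pool_indices,
                                    int64_t max_context_len,
                                    int batch_idx, int64_t position) {
    int64_t req = req_pool_indices[batch_idx];        // request pool row
    return req_to_token[req * max_context_len + position];
}
```

This indirection is what lets the prefix KV entries live anywhere in the paged cache rather than in a contiguous per-request region.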

Usage

This kernel is invoked during the prefill phase when processing multi-token inputs (prompts). It handles the case where new tokens must attend to both previously cached prefix tokens and the newly computed KV entries.

Code Reference

Source Location

Signature

template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
void extend_attention_kernel_impl(
    scalar_t* __restrict__ o_extend,
    const scalar_t* __restrict__ q_extend,
    const scalar_t* __restrict__ k_extend,
    const scalar_t* __restrict__ v_extend,
    const scalar_t* __restrict__ k_buffer,
    const scalar_t* __restrict__ v_buffer,
    const index_t* __restrict__ req_to_token,
    const int64_t* __restrict__ req_pool_indices,
    const int64_t* __restrict__ seq_lens,
    const index_t* __restrict__ extend_seq_lens,
    const index_t* __restrict__ extend_start_loc,
    const void* __restrict__ buffer,
    int batches, int num_heads, int num_heads_kv,
    int head_size, int head_size_v,
    int q_strideM, int q_strideH,
    int ke_strideN, int ke_strideH,
    int ve_strideN, int ve_strideH,
    int k_strideN, int k_strideH,
    int v_strideN, int v_strideH,
    float sm_scale,
    int max_num_reqs, int max_context_len,
    int max_total_num_tokens, int max_len_extend,
    int buffer_size_per_thread,
    bool is_prefix_skipped);
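
The separate num_heads and num_heads_kv parameters suggest support for grouped-query attention, where several query heads share one KV head's cache. A plausible mapping, assuming num_heads is an integer multiple of num_heads_kv (this helper is an illustration, not taken from the source):

```cpp
#include <cassert>

// Hypothetical grouped-query-attention head mapping: query head `head` reads
// the cache of KV head `head / group_size`.
inline int kv_head_for(int head, int num_heads, int num_heads_kv) {
    int group_size = num_heads / num_heads_kv;  // query heads per KV head
    return head / group_size;
}
```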

Import

#include "common.h"
#include "flash_attn.h"
#include "gemm.h"

I/O Contract

Inputs

Name Type Required Description
q_extend scalar_t* Yes Query tensor for new extend tokens
k_extend scalar_t* Yes Key tensor for new extend tokens (may be non-contiguous)
v_extend scalar_t* Yes Value tensor for new extend tokens (may be non-contiguous)
k_buffer scalar_t* Yes Paged key cache buffer for prefix tokens
v_buffer scalar_t* Yes Paged value cache buffer for prefix tokens
req_to_token index_t* Yes Maps (request, position) to physical token index in cache
req_pool_indices int64_t* Yes Maps batch index to request pool ID
seq_lens int64_t* Yes Total sequence lengths (prefix + extend) per batch item
extend_seq_lens index_t* Yes Number of new extend tokens per batch item
extend_start_loc index_t* Yes Start location of extend tokens in the packed Q/K/V tensors
sm_scale float Yes Softmax scale factor (typically 1/sqrt(head_size))
is_prefix_skipped bool Yes If true, skips prefix attention (prefix length must be 0)

Outputs

Name Type Description
o_extend scalar_t* Output tensor with attention results, shape [total_extend_tokens, num_heads, head_size_v]

Usage Examples

Extend Attention Invocation

extend_attention_kernel_impl<at::BFloat16, int32_t, /*BLOCK_M=*/64, /*BLOCK_N=*/256>(
    output_ptr, query_ptr, key_extend_ptr, value_extend_ptr,
    key_buffer_ptr, value_buffer_ptr,
    req_to_token_ptr, req_pool_indices_ptr,
    seq_lens_ptr, extend_seq_lens_ptr, extend_start_loc_ptr,
    scratch_buffer,
    batches, num_heads, num_heads_kv,
    head_size, head_size_v,
    q_strideM, q_strideH,
    ke_strideN, ke_strideH,
    ve_strideN, ve_strideH,
    k_strideN, k_strideH,
    v_strideN, v_strideH,
    sm_scale,
    max_num_reqs, max_context_len,
    max_total_num_tokens, max_len_extend,
    buffer_size_per_thread,
    /*is_prefix_skipped=*/false);
