Implementation:InternLM Lmdeploy AttentionCtaMap
| Knowledge Sources | |
|---|---|
| Domains | GPU_Kernels, Attention |
| Last Updated | 2026-02-07 15:00 GMT |
Overview
Defines CTA-to-problem mapping structs that translate CUDA grid block indices to logical (query, head, batch, split) coordinates for attention, decoding, and reduction kernels.
Description
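This header provides three mapping structs:
- AttentionCtaMap maps a 3D grid of (query_tiles x batch x split*heads) to prefill attention work items.
- DecodingCtaMap maps a grid of (cta_per_q_group*kv_heads x batch x split) to decoding attention work items.
- ReduceCtaMap maps a grid of (heads x queries x split_tiles) to work items of the split-K reduction kernel.
Each struct provides get_grid_shape() for host-side launch configuration and device-side accessors that decode blockIdx into logical coordinates. AttentionCtaMap and DecodingCtaMap expose query_idx(), head_idx(), batch_idx(), split_idx(), and split_count(); ReduceCtaMap exposes only query_idx(), head_idx(), and split_idx(). AttentionCtaMap is constructed from the problem dimensions and holds state, whereas get_grid_shape() on DecodingCtaMap and ReduceCtaMap is static and takes the dimensions as arguments.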
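The prefill mapping can be sketched on the host without a GPU. The block below is a minimal emulation, not the header's actual code: `Dim3` and `HostAttentionCtaMap` are hypothetical names, the block index is passed explicitly instead of being read from blockIdx, and the packing of heads and splits inside grid.z (head group varying fastest) is an assumption. The query-tile count is also assumed to use ceiling division.

```cpp
#include <cassert>

// Hypothetical host-only stand-in for CUDA's dim3 / blockIdx.
struct Dim3 {
    int x, y, z;
};

// Sketch of the prefill mapping. ASSUMPTIONS (not confirmed by this page):
//  - query tiles are counted with ceiling division by cta_q,
//  - heads are grouped cta_h per CTA, with head_num divisible by cta_h,
//  - inside grid.z the head-group index varies fastest, the split index slowest.
struct HostAttentionCtaMap {
    int q_tiles     = 0;
    int head_groups = 0;
    int batch_size  = 0;
    int split_cnt   = 0;

    HostAttentionCtaMap(int max_q_len, int batch, int head_num, int cta_q, int cta_h, int splits)
    {
        assert(cta_q > 0 && cta_h > 0 && head_num % cta_h == 0);
        q_tiles     = (max_q_len + cta_q - 1) / cta_q;
        head_groups = head_num / cta_h;
        batch_size  = batch;
        split_cnt   = splits;
    }

    void set_split_cnt(int value) { split_cnt = value; }

    // (query_tiles, batch, split * heads), matching the layout in the description.
    Dim3 get_grid_shape() const { return Dim3{q_tiles, batch_size, split_cnt * head_groups}; }

    // Accessors take the block index as an argument instead of reading blockIdx.
    int query_idx(Dim3 b) const { return b.x; }
    int batch_idx(Dim3 b) const { return b.y; }
    int head_idx(Dim3 b) const { return b.z % head_groups; }
    int split_idx(Dim3 b) const { return b.z / head_groups; }
};
```

Under these assumptions, a map built with max_q_len = 100, cta_q = 64, 8 heads, cta_h = 1, batch size 2, and 3 splits yields a grid of 2 x 2 x 24 blocks.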
Usage
A CTA map is passed as a template parameter or runtime argument to AttentionUniversal and the reduction kernel, where it controls how work is distributed across the GPU grid.
Code Reference
Source Location
- Repository: InternLM_Lmdeploy
- File: src/turbomind/kernels/attention/cta_map.h
- Lines: 1-149
Signature
namespace turbomind::attention {

struct AttentionCtaMap {
    __host__ __device__
    AttentionCtaMap(int max_q_len, int batch_size, int head_num, int cta_q, int cta_h, int split_cnt);
    __host__ __device__ void set_split_cnt(int value);
    __host__ dim3 get_grid_shape() const;
    __device__ int query_idx() const;
    __device__ int head_idx() const;
    __device__ int batch_idx() const;
    __device__ int split_idx() const;
    __device__ int split_count() const;
};

struct DecodingCtaMap {
    static __host__ dim3 get_grid_shape(int kv_head_num, int batch_size, int split_count, int cta_per_q_group);
    __device__ int query_idx() const;
    __device__ int head_idx() const;
    __device__ int batch_idx() const;
    __device__ int split_idx() const;
    __device__ int split_count() const;
};

struct ReduceCtaMap {
    static __host__ dim3 get_grid_shape(int query_num, int head_num, int max_split_cnt, int cta_k);
    static __device__ int query_idx();
    static __device__ int head_idx();
    static __device__ int split_idx();
};

}  // namespace turbomind::attention
Import
#include "src/turbomind/kernels/attention/cta_map.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| max_q_len | int | Yes | Maximum query sequence length in the batch |
| batch_size | int | Yes | Number of sequences in the batch |
| head_num | int | Yes | Number of attention heads (or KV heads for decoding) |
| cta_q | int | Yes | CTA tile size along the query dimension |
| cta_h | int | Yes | CTA tile size along the head dimension |
| split_cnt | int | Yes | Number of split-K partitions |
Outputs
| Name | Type | Description |
|---|---|---|
| get_grid_shape() | dim3 | CUDA grid dimensions for kernel launch |
| query_idx() | int | Logical query tile index for the current CTA |
| head_idx() | int | Logical head index for the current CTA |
| batch_idx() | int | Logical batch index for the current CTA |
| split_idx() | int | Split-K partition index for the current CTA |
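For the two static variants, the grid shapes named in the Description can be written out as plain host functions. This is a sketch only: `GridDim`, `decoding_grid_shape`, and `reduce_grid_shape` are hypothetical names, and the rule that split_tiles equals max_split_cnt divided by cta_k, rounded up, is an assumption rather than something this page states.

```cpp
#include <cassert>

// Hypothetical host-only stand-in for CUDA's dim3.
struct GridDim {
    int x, y, z;
};

// Decoding grid: (cta_per_q_group * kv_heads, batch, split), per the description.
inline GridDim decoding_grid_shape(int kv_head_num, int batch_size, int split_count, int cta_per_q_group)
{
    assert(kv_head_num > 0 && batch_size > 0 && split_count > 0 && cta_per_q_group > 0);
    return GridDim{cta_per_q_group * kv_head_num, batch_size, split_count};
}

// Reduction grid: (heads, queries, split_tiles).
// ASSUMPTION: split_tiles = ceil(max_split_cnt / cta_k).
inline GridDim reduce_grid_shape(int query_num, int head_num, int max_split_cnt, int cta_k)
{
    assert(query_num > 0 && head_num > 0 && max_split_cnt > 0 && cta_k > 0);
    return GridDim{head_num, query_num, (max_split_cnt + cta_k - 1) / cta_k};
}
```

With 8 KV heads, 4 CTAs per query group, batch size 4, and 2 splits, the decoding grid would be 32 x 4 x 2. With 32 heads, 16 queries, at most 10 splits, and cta_k = 4, the reduction grid would be 32 x 16 x 3 under the rounding assumption.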
Usage Examples
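// Build the map with a single split to obtain the baseline grid.
AttentionCtaMap cta_map{max_q_len, batch_size, num_heads, CTA_Q, CTA_H, 1};
dim3 grid = cta_map.get_grid_shape();
// After the split count is chosen, update the map and re-query the grid,
// because the split count scales the grid's z extent (split*heads).
cta_map.set_split_cnt(split_cnt);
grid = cta_map.get_grid_shape();
kernel<<<grid, block, smem, stream>>>(params, cache_iter, cta_map, ...);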
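The example re-queries get_grid_shape() after set_split_cnt() because the z extent of the grid is split*heads. The host-side check below illustrates why a stale grid would be a problem: it verifies that decoding every z index yields each (head, split) pair exactly once. The function name `z_axis_covers_all_pairs` is hypothetical, and the head-fastest packing of grid.z is an assumption about the header, not something confirmed here.

```cpp
#include <cassert>
#include <set>
#include <utility>

// ASSUMPTION: grid.z packs (split, head_group) with the head group varying fastest,
// so head = z % head_groups and split = z / head_groups.
inline bool z_axis_covers_all_pairs(int head_groups, int split_cnt)
{
    assert(head_groups > 0 && split_cnt > 0);
    std::set<std::pair<int, int>> seen;
    for (int z = 0; z < split_cnt * head_groups; ++z) {
        const int head  = z % head_groups;
        const int split = z / head_groups;
        if (split >= split_cnt) {
            return false;  // decoded split falls outside the configured range
        }
        if (!seen.insert({head, split}).second) {
            return false;  // the same (head, split) pair was produced twice
        }
    }
    // Every pair must have been visited exactly once.
    return static_cast<int>(seen.size()) == head_groups * split_cnt;
}
```

If the kernel were launched with a grid computed for one split but a map configured for several, the z range would only reach split index 0 and the remaining partitions would never be scheduled.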