Implementation:InternLM Lmdeploy AttentionCtaMap

From Leeroopedia


Knowledge Sources
Domains: GPU_Kernels, Attention
Last Updated: 2026-02-07 15:00 GMT

Overview

Defines CTA-to-problem mapping structs that translate CUDA grid block indices to logical (query, head, batch, split) coordinates for attention, decoding, and reduction kernels.

Description

This header provides three mapping structs. AttentionCtaMap maps a 3D grid (query_tiles x batch x split*heads) to prefill attention work items. DecodingCtaMap maps a 3D grid (cta_per_q_group*kv_heads x batch x split) to decoding attention work items. ReduceCtaMap maps a 3D grid (heads x queries x split_tiles) to work items of the split-K reduction kernel.

Each struct provides get_grid_shape() for host-side launch configuration and device-side accessors that decode blockIdx into logical coordinates. All three expose query_idx(), head_idx(), and split_idx(); AttentionCtaMap and DecodingCtaMap additionally expose batch_idx() and split_count().
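The decoding described above can be sketched on the host in plain C++. This is a minimal illustration, not the header's actual code: `BlockIdx` stands in for CUDA's `blockIdx`, `SketchAttentionMap` and `z_covers_each_pair_once` are hypothetical names, and the choice that heads vary fastest inside the z dimension (`head = z % head_tiles`, `split = z / head_tiles`) is an assumption that the text above does not confirm.

```cpp
#include <cassert>
#include <vector>

// Stand-in for CUDA's built-in blockIdx so the sketch runs on the host.
struct BlockIdx {
    int x, y, z;
};

// Hypothetical mirror of the AttentionCtaMap decode step.
// Grid layout from the description: x = query tiles, y = batch, z = split * head tiles.
struct SketchAttentionMap {
    int q_tiles;
    int batch;
    int head_tiles;
    int splits;

    int query_idx(BlockIdx b) const { return b.x; }
    int batch_idx(BlockIdx b) const { return b.y; }
    // ASSUMPTION: heads vary fastest inside z; the real header may order them differently.
    int head_idx(BlockIdx b) const { return b.z % head_tiles; }
    int split_idx(BlockIdx b) const { return b.z / head_tiles; }
};

// Returns true if every (head, split) pair is produced by exactly one z index.
bool z_covers_each_pair_once(const SketchAttentionMap& m) {
    assert(m.head_tiles > 0 && m.splits > 0);
    std::vector<int> hits(m.head_tiles * m.splits, 0);
    for (int z = 0; z < m.head_tiles * m.splits; ++z) {
        BlockIdx b{0, 0, z};
        ++hits[m.split_idx(b) * m.head_tiles + m.head_idx(b)];
    }
    for (int h : hits) {
        if (h != 1) return false;
    }
    return true;
}
```

Under these assumptions, a block with z = 19 in a map with 8 head tiles decodes to head 3 and split 2. Whatever ordering is used, the key property is that each (head, split) pair is assigned to exactly one block.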

Usage

A CTA map is passed as a template parameter or as a runtime argument to AttentionUniversal and the reduction kernel, where it controls how work is distributed across the GPU grid.
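The two ways of supplying a map can be contrasted with a small host-side sketch. All names here (`BlockIdx`, `StatelessMap`, `StatefulMap`, `flat_index_static`, `flat_index_runtime`) are hypothetical, and the axis assignments inside the maps are assumptions chosen for illustration. The point is only that a map with static accessors needs no instance, while a map that stores runtime sizes must be passed as an object.

```cpp
#include <cassert>

// Stand-in for CUDA's built-in blockIdx so the sketch runs on the host.
struct BlockIdx {
    int x, y, z;
};

// Stateless map: every accessor is static, so the type alone is enough.
struct StatelessMap {
    static int head_idx(BlockIdx b) { return b.x; }
    static int query_idx(BlockIdx b) { return b.y; }
};

// Stateful map: stores a runtime size, so an instance has to be passed around.
struct StatefulMap {
    int head_tiles;
    int head_idx(BlockIdx b) const { return b.z % head_tiles; }
    int query_idx(BlockIdx b) const { return b.x; }
};

// Style 1: the map is a template parameter and no object is passed.
template <class CtaMap>
int flat_index_static(BlockIdx b, int heads) {
    assert(heads > 0);
    return CtaMap::query_idx(b) * heads + CtaMap::head_idx(b);
}

// Style 2: the map is a runtime argument.
template <class CtaMap>
int flat_index_runtime(const CtaMap& map, BlockIdx b, int heads) {
    assert(heads > 0);
    return map.query_idx(b) * heads + map.head_idx(b);
}
```

In the signatures listed on this page, ReduceCtaMap has only static accessors and so fits the first style, while AttentionCtaMap has non-static accessors and is constructed from runtime sizes, which fits the second.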

Code Reference

Source Location

Signature

namespace turbomind::attention {

struct AttentionCtaMap {
    __host__ __device__
    AttentionCtaMap(int max_q_len, int batch_size, int head_num, int cta_q, int cta_h, int split_cnt);
    __host__ __device__ void set_split_cnt(int value);
    __host__ dim3 get_grid_shape() const;
    __device__ int query_idx() const;
    __device__ int head_idx() const;
    __device__ int batch_idx() const;
    __device__ int split_idx() const;
    __device__ int split_count() const;
};

struct DecodingCtaMap {
    static __host__ dim3 get_grid_shape(int kv_head_num, int batch_size, int split_count, int cta_per_q_group);
    __device__ int query_idx() const;
    __device__ int head_idx() const;
    __device__ int batch_idx() const;
    __device__ int split_idx() const;
    __device__ int split_count() const;
};

struct ReduceCtaMap {
    static __host__ dim3 get_grid_shape(int query_num, int head_num, int max_split_cnt, int cta_k);
    static __device__ int query_idx();
    static __device__ int head_idx();
    static __device__ int split_idx();
};

} // namespace turbomind::attention
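The two static get_grid_shape() overloads can be approximated from the grid orders given in the Description. The following is a host-only sketch under stated assumptions: `Dim3` stands in for CUDA's `dim3`, the function names are hypothetical, and the rule that split_tiles equals the ceiling of max_split_cnt / cta_k is an assumption inferred from the parameter names rather than taken from the header.

```cpp
#include <cassert>

// Stand-in for CUDA's dim3 so the sketch compiles without the CUDA toolkit.
struct Dim3 {
    unsigned x, y, z;
};

// Ceiling division for non-negative a and positive b.
inline int ceil_div(int a, int b) {
    assert(a >= 0 && b > 0);
    return (a + b - 1) / b;
}

// Grid order from the description: (cta_per_q_group * kv_heads, batch, split).
inline Dim3 decoding_grid(int kv_head_num, int batch_size, int split_count, int cta_per_q_group) {
    return Dim3{static_cast<unsigned>(cta_per_q_group * kv_head_num),
                static_cast<unsigned>(batch_size),
                static_cast<unsigned>(split_count)};
}

// Grid order from the description: (heads, queries, split_tiles).
// ASSUMPTION: split_tiles = ceil(max_split_cnt / cta_k).
inline Dim3 reduce_grid(int query_num, int head_num, int max_split_cnt, int cta_k) {
    return Dim3{static_cast<unsigned>(head_num),
                static_cast<unsigned>(query_num),
                static_cast<unsigned>(ceil_div(max_split_cnt, cta_k))};
}
```

For example, with 8 KV heads, a batch of 4, 2 splits, and 3 CTAs per query group, the decoding sketch yields a 24 x 4 x 2 grid.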

Import

#include "src/turbomind/kernels/attention/cta_map.h"

I/O Contract

Inputs

Name | Type | Required | Description
max_q_len | int | Yes | Maximum query sequence length in the batch
batch_size | int | Yes | Number of sequences in the batch
head_num | int | Yes | Number of attention heads (or KV heads for decoding)
cta_q | int | Yes | CTA tile size along the query dimension
cta_h | int | Yes | CTA tile size along the head dimension
split_cnt | int | Yes | Number of split-K partitions
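How these inputs turn into launch dimensions for AttentionCtaMap can be worked through in a short host-side sketch. It is an illustration only: `Dim3` stands in for CUDA's `dim3`, `attention_grid` is a hypothetical name, and both the ceiling division over the query length and the requirement that head_num be divisible by cta_h are assumptions, not documented behavior.

```cpp
#include <cassert>

// Stand-in for CUDA's dim3 so the sketch compiles without the CUDA toolkit.
struct Dim3 {
    unsigned x, y, z;
};

// Hypothetical host-side grid computation: (query_tiles, batch, split * head_tiles).
inline Dim3 attention_grid(int max_q_len, int batch_size, int head_num,
                           int cta_q, int cta_h, int split_cnt) {
    assert(cta_q > 0 && cta_h > 0 && split_cnt > 0);
    assert(head_num % cta_h == 0);  // ASSUMPTION: heads divide evenly into CTA tiles

    // ASSUMPTION: a partial query tile at the end still needs its own CTA.
    const int q_tiles    = (max_q_len + cta_q - 1) / cta_q;
    const int head_tiles = head_num / cta_h;

    return Dim3{static_cast<unsigned>(q_tiles),
                static_cast<unsigned>(batch_size),
                static_cast<unsigned>(split_cnt * head_tiles)};
}
```

With max_q_len = 130, cta_q = 64, 32 heads, cta_h = 1, a batch of 4, and 2 splits, this sketch gives 3 query tiles and a 3 x 4 x 64 grid.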

Outputs

Name | Type | Description
get_grid_shape() | dim3 | CUDA grid dimensions for kernel launch
query_idx() | int | Logical query tile index for the current CTA
head_idx() | int | Logical head index for the current CTA
batch_idx() | int | Logical batch index for the current CTA
split_idx() | int | Split-K partition index for the current CTA

Usage Examples

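// Start with a single split to obtain the baseline grid.
AttentionCtaMap cta_map{max_q_len, batch_size, num_heads, CTA_Q, CTA_H, 1};
dim3 grid = cta_map.get_grid_shape();
// Once split_cnt has been chosen, update the map and recompute the grid.
cta_map.set_split_cnt(split_cnt);
grid = cta_map.get_grid_shape();
// The map is passed to the kernel so device code can decode blockIdx.
kernel<<<grid, block, smem, stream>>>(params, cache_iter, cta_map, ...);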
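The example above computes the grid twice: once with a single split and again after set_split_cnt(). One plausible reason is that the split count is chosen from the size of the baseline grid. The sketch below shows such a choice purely as a hypothetical heuristic; `pick_split_cnt`, its parameters, and the fill-the-SMs rule are all assumptions for illustration and are not taken from the library.

```cpp
#include <algorithm>
#include <cassert>

// HYPOTHETICAL heuristic, not the library's: pick enough splits that the
// total CTA count reaches the number of SMs, capped at max_split_cnt.
//   base_ctas     : CTAs in the grid when split_cnt == 1
//   sm_count      : number of streaming multiprocessors on the device
//   max_split_cnt : upper bound on split-K partitions
inline int pick_split_cnt(int base_ctas, int sm_count, int max_split_cnt) {
    assert(base_ctas > 0 && sm_count > 0 && max_split_cnt > 0);
    const int wanted = (sm_count + base_ctas - 1) / base_ctas;  // ceil(sm_count / base_ctas)
    return std::max(1, std::min(wanted, max_split_cnt));
}
```

With this rule, a baseline grid that already exceeds the SM count keeps a single split, while a small grid is split further so that more of the device has work.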
