Implementation:Vllm project Vllm SM100 FMHA MLA TMA Warpspecialized
| Knowledge Sources | |
|---|---|
| Domains | Attention, CUDA, GPU_Kernels, CUTLASS |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements the warp-specialized SM100 FMHA Multi-Latent Attention kernel using TMA (Tensor Memory Accelerator) for efficient asynchronous memory access on NVIDIA Blackwell GPUs.
Description
The Sm100FmhaMlaKernelTmaWarpspecialized template struct defines the core compute kernel for multi-latent attention on the SM100 architecture. It runs in 2-SM mode (kIs2Sm = true) and uses warp specialization, assigning warps to MMA, Load, Compute, and LoadPageTable roles. Operands move through TMA-based asynchronous copy pipelines, and the QK and PV matrix multiplications are constructed with the CUTLASS collective builder APIs. The kernel supports queries and keys split into latent and RoPE head dimensions, paged attention via page tables, split-KV for long sequences, and staging of the K and V operands through shared memory.
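A minimal C++ sketch of the warp-specialization idea: each warp looks up its role once and then runs only that role's loop. The role names come from the description above and the 256-thread block matches MaxThreadsPerBlock in the signature. The specific warp-to-role mapping, the four Compute warps, the idle eighth warp, and the helper names (role_of_thread, role_task, threads_with_role) are assumptions for illustration, not the kernel's actual assignment.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical role set; the real kernel's enum and values may differ.
enum class WarpRole : uint8_t { Compute, Mma, Load, LoadPageTable, Empty };

constexpr int kThreadsPerBlock = 256;  // matches MaxThreadsPerBlock
constexpr int kWarpSize = 32;
constexpr int kNumWarps = kThreadsPerBlock / kWarpSize;  // 8 warps

// Assumed assignment: warps 0-3 compute, then one warp each for MMA issue,
// loads, and page-table loads; the eighth warp is left idle in this sketch.
constexpr std::array<WarpRole, kNumWarps> kWarpAssignment = {
    WarpRole::Compute, WarpRole::Compute, WarpRole::Compute, WarpRole::Compute,
    WarpRole::Mma,     WarpRole::Load,    WarpRole::LoadPageTable, WarpRole::Empty};

// Map a thread index to the role of the warp it belongs to.
WarpRole role_of_thread(int thread_idx) {
  assert(thread_idx >= 0 && thread_idx < kThreadsPerBlock);
  return kWarpAssignment[thread_idx / kWarpSize];
}

// In a kernel, each warp would branch once on its role (illustrative labels).
const char* role_task(WarpRole role) {
  switch (role) {
    case WarpRole::Compute:       return "softmax and epilogue math";
    case WarpRole::Mma:           return "issue QK and PV MMAs";
    case WarpRole::Load:          return "asynchronous loads of Q, K, V tiles";
    case WarpRole::LoadPageTable: return "fetch page-table entries";
    case WarpRole::Empty:         return "idle";
  }
  return "unknown";
}

// Count how many threads of the block take a given role.
int threads_with_role(WarpRole role) {
  int n = 0;
  for (int t = 0; t < kThreadsPerBlock; ++t) {
    if (role_of_thread(t) == role) ++n;
  }
  return n;
}
```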
Usage
This header is compiled as part of the vLLM CUTLASS MLA attention backend for SM100 GPUs. It is instantiated by the device-level MLA class and executed as a CUDA kernel during inference when Blackwell hardware is detected.
Code Reference
Source Location
- Repository: vllm
- File: csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp
- Lines: 1-2023
Signature
```cpp
namespace cutlass::fmha::kernel {
template<
  class TileShape,
  class Element_,
  class ElementAcc_,
  class ElementOut_,
  class ElementLSE_,
  class TileScheduler,
  bool kIsCpAsync = false
>
struct Sm100FmhaMlaKernelTmaWarpspecialized {
  using Element = Element_;
  using ElementAcc = ElementAcc_;
  using ElementOut = ElementOut_;
  using ElementLSE = ElementLSE_;
  static const bool kIs2Sm = true;
  static const int MaxThreadsPerBlock = 256;
  // TileShapeH / TileShapeD: the head mode and the (d_latent, d_rope) mode of TileShape
  using ProblemShape = Shape<TileShapeH, int, TileShapeD, int>;  // (H, S, D, B)
  struct MainloopArguments { ... };
  struct EpilogueArguments { ... };
  struct Arguments { ... };
  struct Params { ... };
  static Status can_implement(Arguments const& args);
  static size_t get_workspace_size(Arguments const& args);
  static Params to_underlying_arguments(Arguments const& args, void* workspace);
};
} // namespace cutlass::fmha::kernel
```
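The shape arithmetic behind ProblemShape and split-KV can be sketched in plain C++. The sizes mirror the TileShape in the usage example on this page (a 128-token KV tile and (d_latent, d_rope) = (512, 64)). The even, contiguous partition in split_kv_ranges is a hypothetical scheme, not necessarily how the kernel or its tile scheduler divides work.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Sizes from the example TileShape: 128-token KV tile, (d_latent, d_rope) = (512, 64).
constexpr int kTileS = 128;
constexpr int kDLatent = 512;
constexpr int kDRope = 64;
constexpr int kDQK = kDLatent + kDRope;  // inner dimension of the QK product
constexpr int kDV = kDLatent;            // PV output width: V is the latent cache

// Number of KV tiles needed to cover seqlen tokens (ceiling division).
int num_kv_tiles(int seqlen) { return (seqlen + kTileS - 1) / kTileS; }

// Hypothetical split-KV partition: divide one sequence's KV tiles into at most
// split_kv contiguous [begin, end) ranges of ceil(tiles / split_kv) tiles each.
std::vector<std::pair<int, int>> split_kv_ranges(int seqlen, int split_kv) {
  assert(split_kv > 0);
  const int tiles = num_kv_tiles(seqlen);
  const int per_split = (tiles + split_kv - 1) / split_kv;
  std::vector<std::pair<int, int>> ranges;
  for (int begin = 0; begin < tiles; begin += per_split) {
    const int end = begin + per_split < tiles ? begin + per_split : tiles;
    ranges.emplace_back(begin, end);
  }
  return ranges;
}
```

With 5 splits over 8 tiles, each split gets 2 tiles, so only 4 ranges are produced; a real scheduler may choose a different balance.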
Import
```cpp
#include "cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| ptr_q_latent | Element* | Yes | Pointer to the query latent tensor, indexed as (num_heads, d_latent, batch) and contiguous (stride 1) along d_latent |
| ptr_q_rope | Element* | Yes | Pointer to the query RoPE tensor (rotary position encoding part of the query) |
| ptr_c_latent | Element* | Yes | Pointer to the compressed KV latent cache; serves as the latent part of K and as V |
| ptr_k_rope | Element* | Yes | Pointer to the key RoPE tensor (rotary position encoding part of the key) |
| softmax_scale | ElementAcc | Yes | Scale applied to the QK scores before the softmax (typically 1/sqrt(head_dim)) |
| ptr_seq | int* | No | Pointer to the per-batch sequence lengths; nullptr for fixed-length sequences |
| ptr_page_table | int* | No | Pointer to the page table for paged attention; nullptr for contiguous KV |
| problem_shape | ProblemShape | Yes | Tuple (num_heads=128, seqlen, (d_latent, d_rope), batch_count); with paged attention, seqlen is the maximum sequence length |
| split_kv | int | No | Number of splits along the KV sequence, used to parallelize long sequences |
Outputs
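| Name | Type | Description |
|---|---|---|
| ptr_o | ElementOut* | Output attention tensor, laid out like the query latent tensor (num_heads, d_latent, batch) |
| ptr_lse | ElementLSE* | Log-sum-exp of the scaled attention scores per query head and batch entry; allows partial results from different KV splits to be merged |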
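A scalar reference helps pin down what the inputs and outputs above mean for a single query head and batch entry: each score combines the latent and RoPE dot products, the softmax is computed tile by tile with a running maximum (the usual flash-attention online softmax), V is the latent cache, and the LSE is the log-sum-exp of the scaled scores. This is a hypothetical host-side sketch for checking semantics, not the kernel's algorithm; the natural-log LSE, HeadResult, and mla_head_reference are assumptions.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

struct HeadResult {
  std::vector<float> o;  // d_latent outputs (V is the latent cache)
  float lse;             // natural-log log-sum-exp of the scaled scores
};

// c_latent: seqlen rows of d_latent; k_rope: seqlen rows of d_rope.
HeadResult mla_head_reference(const std::vector<float>& q_latent,
                              const std::vector<float>& q_rope,
                              const std::vector<std::vector<float>>& c_latent,
                              const std::vector<std::vector<float>>& k_rope,
                              float softmax_scale, int tile_s) {
  assert(!c_latent.empty() && c_latent.size() == k_rope.size() && tile_s > 0);
  const size_t d_latent = q_latent.size();
  const int seqlen = static_cast<int>(c_latent.size());
  float running_max = -std::numeric_limits<float>::infinity();
  float running_sum = 0.f;                // sum of exp(s - running_max)
  std::vector<float> acc(d_latent, 0.f);  // unnormalized P * V

  for (int t0 = 0; t0 < seqlen; t0 += tile_s) {  // one KV tile at a time
    const int t1 = std::min(seqlen, t0 + tile_s);
    std::vector<float> s(t1 - t0);
    float tile_max = running_max;
    for (int j = t0; j < t1; ++j) {
      float dot = 0.f;
      for (size_t d = 0; d < d_latent; ++d) dot += q_latent[d] * c_latent[j][d];
      for (size_t d = 0; d < q_rope.size(); ++d) dot += q_rope[d] * k_rope[j][d];
      s[j - t0] = softmax_scale * dot;
      tile_max = std::max(tile_max, s[j - t0]);
    }
    // Rescale the previous partial sums to the new running maximum.
    const float correction = std::exp(running_max - tile_max);
    running_sum *= correction;
    for (float& a : acc) a *= correction;
    for (int j = t0; j < t1; ++j) {
      const float p = std::exp(s[j - t0] - tile_max);
      running_sum += p;
      for (size_t d = 0; d < d_latent; ++d) acc[d] += p * c_latent[j][d];
    }
    running_max = tile_max;
  }
  for (float& a : acc) a /= running_sum;
  return {acc, running_max + std::log(running_sum)};
}
```

The tile size only changes the order of accumulation, so any tile_s should give the same o and lse up to rounding.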
Usage Examples
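```cpp
using namespace cute;  // Shape, _128, _512, _64, make_shape

// Define kernel type with tile shape (heads, KV tile, (d_latent, d_rope)) and element types
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
using MLAKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
    TileShape,
    cutlass::half_t,  // Element
    float,            // ElementAcc
    cutlass::half_t,  // ElementOut
    float,            // ElementLSE
    MyTileScheduler   // placeholder: substitute a concrete tile scheduler type
>;

// Set up arguments (seq_len, batch, scale, the device pointers, workspace and
// stream are assumed to be defined by the caller)
MLAKernel::Arguments args;
args.problem_shape = {_128{}, seq_len, make_shape(_512{}, _64{}), batch};
args.mainloop.softmax_scale = scale;
args.mainloop.ptr_q_latent = q_latent;
args.mainloop.ptr_c_latent = c_latent;
args.mainloop.ptr_q_rope = q_rope;
args.mainloop.ptr_k_rope = k_rope;
// The matching stride fields, and ptr_seq / ptr_page_table for paged
// attention, must also be set; they are omitted here.
args.epilogue.ptr_o = output;
args.epilogue.ptr_lse = lse;  // per-head log-sum-exp output

// Check support and size the workspace before launching
if (MLAKernel::can_implement(args) != cutlass::Status::kSuccess) {
  // handle unsupported configuration
}
size_t workspace_size = MLAKernel::get_workspace_size(args);
// ... allocate `workspace` with at least workspace_size bytes ...

// Launch via device-level MLA wrapper
cutlass::fmha::device::MLA<MLAKernel> mla;
mla.run(args, workspace, stream);
```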
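With split_kv > 1, each split covers only part of the KV sequence, so its output and LSE are partial. A standard way to merge them, sketched below under the assumption that each split's o_i is already normalized over its own KV range, weights each partial output by exp(lse_i - lse), where lse is the log-sum-exp of all lse_i. How the device-level wrapper actually performs this reduction is not described on this page; Partial and combine_splits are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// One split's result for a single head: normalized output and its LSE.
struct Partial {
  std::vector<float> o;
  float lse;
};

// Merge split-KV partials:
//   lse = log(sum_i exp(lse_i)),  o = sum_i exp(lse_i - lse) * o_i
Partial combine_splits(const std::vector<Partial>& parts) {
  assert(!parts.empty());
  float max_lse = parts[0].lse;
  for (const Partial& p : parts) max_lse = std::max(max_lse, p.lse);
  float sum = 0.f;
  for (const Partial& p : parts) sum += std::exp(p.lse - max_lse);
  const float lse = max_lse + std::log(sum);  // numerically stable log-sum-exp
  std::vector<float> o(parts[0].o.size(), 0.f);
  for (const Partial& p : parts) {
    const float w = std::exp(p.lse - lse);  // share of the total softmax mass
    for (size_t d = 0; d < o.size(); ++d) o[d] += w * p.o[d];
  }
  return {o, lse};
}
```

Merging a single split returns it unchanged, and merging all splits reproduces the result of one pass over the whole sequence, which is what makes a per-head LSE output sufficient for the reduction.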