Implementation:Vllm project Vllm SM100 FMHA MLA Reduction
| Knowledge Sources | |
|---|---|
| Domains | Attention, MLA, GPU_Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Defines the reduction kernel for SM100 (Blackwell) Fused Multi-Head Attention (FMHA) with Multi-head Latent Attention (MLA). The kernel combines the partial attention outputs produced by split-KV parallel computation into the final result.
Description
The Sm100FmhaMlaReductionKernel is a CUDA kernel template that merges the per-split partial outputs and log-sum-exp (LSE) values from split-KV attention into the final attention output. It targets the NVIDIA SM100 (Blackwell) architecture and is built on the CUTLASS framework. The kernel works in two phases: it first computes a global LSE across all splits using warp-level reductions and shared memory, then rescales each partial output vector relative to that global LSE and accumulates the rescaled vectors into the final result.
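The two phases can be mirrored on the host to check results for a single (batch, head) pair. The following is a minimal reference sketch, not the kernel's code: `MergedRow`, `merge_splits`, and the use of `std::vector` are hypothetical names and choices made for illustration. It assumes each split's partial output is already normalized within that split, so split i is weighted by exp(lse_i - lse_global), where lse_global = log(sum_i exp(lse_i)).

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical host-side reference for a single (batch, head) pair.
struct MergedRow {
  std::vector<float> o;  // merged output, one value per latent dimension
  float lse;             // global log-sum-exp across splits
};

// lse[i] is the LSE of split i; o_partial[i] is that split's output row.
// lse and o_partial must have the same length.
// ASSUMPTION: each o_partial[i] is already normalized within its own split.
inline MergedRow merge_splits(const std::vector<float>& lse,
                              const std::vector<std::vector<float>>& o_partial) {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  const std::size_t head_dim = o_partial.empty() ? 0 : o_partial[0].size();
  MergedRow out{std::vector<float>(head_dim, 0.0f), neg_inf};

  // Phase 1: global LSE, shifted by the maximum for numerical stability.
  float lse_max = neg_inf;
  for (float v : lse) lse_max = std::max(lse_max, v);
  if (lse_max == neg_inf) return out;  // no split contributed anything

  float sum = 0.0f;
  for (float v : lse) sum += std::exp(v - lse_max);
  out.lse = std::log(sum) + lse_max;

  // Phase 2: rescale each partial row by exp(lse_i - lse_global), accumulate.
  for (std::size_t i = 0; i < o_partial.size(); ++i) {
    const float w = std::exp(lse[i] - out.lse);
    for (std::size_t d = 0; d < head_dim; ++d) out.o[d] += w * o_partial[i][d];
  }
  return out;
}
```

With two splits of equal LSE, each split receives weight 0.5, so the merged row is the plain average of the two partial rows. Subtracting the maximum before exponentiating keeps the sum from overflowing when LSE values are large.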
Usage
This kernel is compiled as part of the CUTLASS-based MLA attention pipeline for SM100 GPUs. It runs after the main FMHA attention kernel has finished its split-KV partial computations and performs the final reduction step, which is needed when the KV sequence is long enough to be divided across multiple splits.
Code Reference
Source Location
- Repository: vllm
- File: csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp
- Lines: 1-203
Signature
template<
  class ElementOut,
  class ElementAcc,
  class ElementScale,
  size_t kNumHeads,
  size_t kHeadDimLatent,
  int kMaxSplits
>
struct Sm100FmhaMlaReductionKernel {
  struct Arguments {
    ElementAcc* ptr_oaccum;
    ElementOut* ptr_o;
    ElementAcc* ptr_lseaccum;
    ElementAcc* ptr_lse;
    ElementScale scale;
    int num_batches;
    int split_kv;
    int dim_k;
    int* ptr_seq;
    int* ptr_split_kv;
    int tile_shape_s;
  };

  static Params to_underlying_arguments(Arguments const& args, void* workspace);
  static size_t get_workspace_size(Arguments const& args);
  static dim3 get_grid_shape(Params const& params);
  static dim3 get_block_shape();
  static bool can_implement(Arguments const& args);

  CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw);
};
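The static members above suggest a host-side call order: validate the arguments, query the workspace size, convert `Arguments` to `Params`, then query the launch dimensions. The sketch below walks that order generically. `Dim3`, `LaunchPlan`, `plan_launch`, and `MockKernel` are hypothetical stand-ins so the example compiles without CUDA or CUTLASS, and the mock's grid and block values are placeholders rather than the real kernel's.

```cpp
#include <cstddef>

struct Dim3 { unsigned x, y, z; };  // stand-in for CUDA's dim3

template <class Params>
struct LaunchPlan {
  Params params;
  Dim3 grid;
  Dim3 block;
  std::size_t workspace_bytes;
};

// Walks the static interface in the order a host launcher might use.
// Returns false (leaving `out` untouched) if the kernel rejects the arguments.
template <class Kernel>
bool plan_launch(typename Kernel::Arguments const& args, void* workspace,
                 LaunchPlan<typename Kernel::Params>& out) {
  if (!Kernel::can_implement(args)) return false;
  out.workspace_bytes = Kernel::get_workspace_size(args);
  out.params = Kernel::to_underlying_arguments(args, workspace);
  out.grid = Kernel::get_grid_shape(out.params);
  out.block = Kernel::get_block_shape();
  return true;
}

// Mock exposing the same static interface; every value is a placeholder.
struct MockKernel {
  struct Arguments {
    int num_batches = 0;
    int split_kv = 0;
  };
  using Params = Arguments;

  static bool can_implement(Arguments const& a) {
    return a.num_batches > 0 && a.split_kv > 0;
  }
  static std::size_t get_workspace_size(Arguments const&) { return 0; }
  static Params to_underlying_arguments(Arguments const& a, void*) { return a; }
  static Dim3 get_grid_shape(Params const& p) {
    return Dim3{1u, 1u, static_cast<unsigned>(p.num_batches)};
  }
  static Dim3 get_block_shape() { return Dim3{128u, 1u, 1u}; }
};
```

Writing the helper against the static interface rather than a concrete type means the same call sequence could be reused for any kernel struct with these members, which is one reason this pattern is convenient.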
Import
#include "csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| ptr_oaccum | ElementAcc* | Yes | Pointer to accumulated partial output tensor from split-KV attention [num_batches, kNumHeads, split_kv, kHeadDimLatent] |
| ptr_lseaccum | ElementAcc* | Yes | Pointer to accumulated log-sum-exp values per split [num_batches, kNumHeads, split_kv] |
| scale | ElementScale | Yes | Softmax scaling factor (default 1.0f) |
| num_batches | int | Yes | Number of batches in the attention operation |
| split_kv | int | Yes | Number of KV splits used during partial attention computation |
| dim_k | int | Yes | Sequence length dimension for KV |
| ptr_seq | int* | No | Per-batch sequence lengths (nullptr for uniform length) |
| ptr_split_kv | int* | No | Per-batch split counts (nullptr for uniform splits) |
| tile_shape_s | int | Yes | Tile shape along the sequence dimension (default 128) |
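
The relationship between `dim_k`, `tile_shape_s`, and the split count can be illustrated with a small host-side helper. This is a sketch under a stated assumption, not a description of the kernel's exact logic: it assumes KV tiles are distributed over splits by ceiling division, so a short sequence may populate fewer splits than `split_kv` requests. The names `ceil_div_int` and `effective_splits` are hypothetical.

```cpp
// Ceiling division for non-negative a and positive b.
constexpr int ceil_div_int(int a, int b) { return (a + b - 1) / b; }

// Hypothetical helper: number of splits that actually receive KV tiles,
// ASSUMING tiles are dealt out evenly with ceiling division.
constexpr int effective_splits(int dim_k, int tile_shape_s, int split_kv) {
  return ceil_div_int(dim_k, tile_shape_s) == 0
             ? 0  // empty sequence: no tiles, so no populated splits
             : ceil_div_int(
                   ceil_div_int(dim_k, tile_shape_s),
                   ceil_div_int(ceil_div_int(dim_k, tile_shape_s), split_kv));
}
```

For example, with `dim_k = 1152`, `tile_shape_s = 128`, and `split_kv = 4`, there are nine tiles at three tiles per split, so under this assumption only three splits hold data. A reduction that reads all four requested splits in that case would need the unused split's LSE to contribute zero weight.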
Outputs
| Name | Type | Description |
|---|---|---|
| ptr_o | ElementOut* | Final reduced attention output tensor [num_batches, kNumHeads, kHeadDimLatent] |
| ptr_lse | ElementAcc* | Final global log-sum-exp values [num_batches, kNumHeads] |
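
The shapes listed in the tables above determine how many elements each buffer must hold. The following sketch computes those counts. The struct and function names are hypothetical, and it assumes densely packed tensors with no padding, which the tables do not state explicitly.

```cpp
#include <cstddef>

// Hypothetical helper type: element counts for the four buffers.
struct ReductionBufferCounts {
  std::size_t oaccum;    // [num_batches, kNumHeads, split_kv, kHeadDimLatent]
  std::size_t lseaccum;  // [num_batches, kNumHeads, split_kv]
  std::size_t o;         // [num_batches, kNumHeads, kHeadDimLatent]
  std::size_t lse;       // [num_batches, kNumHeads]
};

// ASSUMPTION: tensors are densely packed, so the element count is the
// product of the listed dimensions.
inline ReductionBufferCounts reduction_buffer_counts(
    std::size_t num_batches, std::size_t num_heads,
    std::size_t split_kv, std::size_t head_dim_latent) {
  const std::size_t rows = num_batches * num_heads;
  ReductionBufferCounts c;
  c.oaccum = rows * split_kv * head_dim_latent;
  c.lseaccum = rows * split_kv;
  c.o = rows * head_dim_latent;
  c.lse = rows;
  return c;
}
```

Multiplying `oaccum`, `lseaccum`, and `lse` by `sizeof(ElementAcc)` and `o` by `sizeof(ElementOut)` gives the byte sizes to allocate. With 2 batches, 32 heads, 8 splits, and a latent dimension of 512, the partial-output buffer holds 262144 elements while the final output holds 32768.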
Usage Examples
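using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
  cutlass::half_t,  // ElementOut
  float,            // ElementAcc
  float,            // ElementScale
  32,               // kNumHeads
  512,              // kHeadDimLatent
  8                 // kMaxSplits
>;

// oaccum_ptr, output_ptr, lseaccum_ptr, and lse_ptr are device pointers that
// the caller has already allocated with the shapes given in the I/O Contract.
ReductionKernel::Arguments args;
args.ptr_oaccum = oaccum_ptr;
args.ptr_o = output_ptr;
args.ptr_lseaccum = lseaccum_ptr;
args.ptr_lse = lse_ptr;
args.scale = 1.0f;
args.num_batches = batch_size;
args.split_kv = num_splits;
args.dim_k = seq_len;
args.ptr_seq = nullptr;       // uniform sequence length across batches
args.ptr_split_kv = nullptr;  // uniform split count across batches
args.tile_shape_s = 128;      // documented default

// Convert the arguments to kernel parameters and query the launch shape.
auto params = ReductionKernel::to_underlying_arguments(args, nullptr);
dim3 grid = ReductionKernel::get_grid_shape(params);
dim3 block = ReductionKernel::get_block_shape();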