

Implementation:Vllm project Vllm SM100 FMHA MLA Reduction



Knowledge Sources
Domains: Attention, MLA, GPU_Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Defines the reduction kernel for SM100 (Blackwell) Fused Multi-Head Attention (FMHA) with Multi-head Latent Attention (MLA). The kernel aggregates the partial attention outputs produced by split-KV parallel computation into the final result.
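Writing S for the number of splits, O_i for the partial output of split i (normalized over that split's KV chunk only) and ℓ_i for that split's log-sum-exp, the aggregation can be sketched with the standard log-sum-exp combination below. The notation is introduced here for illustration; subtracting the maximum m keeps the exponentials from overflowing.

```latex
m = \max_{1 \le i \le S} \ell_i, \qquad
\mathrm{LSE} = m + \log \sum_{i=1}^{S} e^{\ell_i - m}, \qquad
O = \sum_{i=1}^{S} e^{\ell_i - \mathrm{LSE}} \, O_i
```

The weights e^{ℓ_i - LSE} sum to 1, so O equals the output that a single unsplit softmax over the whole KV sequence would produce.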

Description

The Sm100FmhaMlaReductionKernel is a CUDA kernel template, built on the CUTLASS framework for NVIDIA SM100 (Blackwell) GPUs, that combines the partial output tensors and log-sum-exp (LSE) values produced for multiple split-KV chunks into the final attention output. It works in two phases: it first computes a global LSE across all splits using warp-level reductions and shared memory, then rescales each split's partial output vector by that split's weight relative to the global LSE and accumulates the rescaled vectors into the final result.
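To make the warp-level part concrete, the following host-side C++ sketch simulates a 32-lane XOR-butterfly reduction, one common warp-reduction pattern and the exchange that CUDA's `__shfl_xor_sync` performs on a real warp, and uses it to compute a global LSE from up to 32 per-split values. It assumes one split per lane and leaves out the shared-memory staging; the function names are hypothetical and are not taken from the kernel source.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <limits>

// Host-side simulation (hypothetical helper) of a 32-lane XOR-butterfly
// reduction: the exchange pattern __shfl_xor_sync(0xffffffff, v, offset)
// performs on a warp.
constexpr int kWarpSize = 32;

template <class Op>
std::array<float, kWarpSize> butterfly_reduce(std::array<float, kWarpSize> lanes, Op op) {
    for (int offset = kWarpSize / 2; offset >= 1; offset /= 2) {
        std::array<float, kWarpSize> next{};
        for (int lane = 0; lane < kWarpSize; ++lane) {
            // Each lane combines its value with the lane whose index differs in one bit.
            next[lane] = op(lanes[lane], lanes[lane ^ offset]);
        }
        lanes = next;  // all lanes read before any writes, as with a hardware shuffle
    }
    return lanes;  // after log2(32) = 5 steps every lane holds the full reduction
}

// Global LSE over up to 32 per-split LSE values, assuming one split per lane.
// Idle lanes hold -inf so they contribute exp(-inf) = 0 to the sum.
float warp_global_lse(const float* split_lse, int num_splits) {
    assert(num_splits >= 0 && num_splits <= kWarpSize);
    const float neg_inf = -std::numeric_limits<float>::infinity();
    std::array<float, kWarpSize> lanes;
    lanes.fill(neg_inf);
    for (int i = 0; i < num_splits; ++i) lanes[i] = split_lse[i];

    // Pass 1: maximum, subtracted below for numerical stability.
    float m = butterfly_reduce(lanes, [](float a, float b) { return std::max(a, b); })[0];
    if (m == neg_inf) return neg_inf;  // no valid splits

    // Pass 2: sum of exp(lse_i - m), then shift back by m.
    std::array<float, kWarpSize> terms;
    for (int lane = 0; lane < kWarpSize; ++lane) terms[lane] = std::exp(lanes[lane] - m);
    float sum = butterfly_reduce(terms, [](float a, float b) { return a + b; })[0];
    return m + std::log(sum);
}
```

A butterfly is used in this sketch rather than a tree reduction because every lane ends up holding the full result, so each lane can go on to compute its own split's rescaling factor.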

Usage

This kernel is compiled as part of the CUTLASS-based MLA attention pipeline for SM100 GPUs. It runs after the main FMHA kernel has finished the split-KV partial computations and performs the final reduction step that turns the per-split partials into a single, correctly normalized attention output. This matters mainly in long-sequence scenarios, where the KV sequence is split.
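The two-phase flow can be checked end to end on the CPU. The sketch below computes attention for a single query and head from precomputed scores, once with a single split and once with the KV sequence cut into contiguous chunks whose partial outputs and LSEs are then reduced; the two results should agree up to rounding. All names are hypothetical, scores are assumed to be already scaled, and the GPU kernels' tiling and memory layout are not modeled.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical CPU reference for one query and one head. scores[j] is the
// (already scaled) score for KV row j; v[j] is that row's value vector.
using Matrix = std::vector<std::vector<float>>;

struct Partial {
    std::vector<float> o;  // output normalized over this split's rows only
    float lse;             // log-sum-exp of this split's scores
};

// Phase 1, one split: softmax-weighted sum of V rows in [begin, end).
Partial attend_chunk(const std::vector<float>& scores, const Matrix& v,
                     std::size_t begin, std::size_t end) {
    Partial p{std::vector<float>(v[0].size(), 0.0f),
              -std::numeric_limits<float>::infinity()};
    if (begin >= end) return p;  // empty split: zero output, lse = -inf
    float m = scores[begin];
    for (std::size_t j = begin; j < end; ++j) m = std::max(m, scores[j]);
    float sum = 0.0f;
    for (std::size_t j = begin; j < end; ++j) {
        float w = std::exp(scores[j] - m);
        sum += w;
        for (std::size_t k = 0; k < p.o.size(); ++k) p.o[k] += w * v[j][k];
    }
    for (float& x : p.o) x /= sum;
    p.lse = m + std::log(sum);
    return p;
}

// Phase 2, the reduction: global LSE, then rescale and accumulate partials.
std::vector<float> reduce_splits(const std::vector<Partial>& parts, float* global_lse) {
    assert(!parts.empty());
    float m = -std::numeric_limits<float>::infinity();
    for (const Partial& p : parts) m = std::max(m, p.lse);
    float sum = 0.0f;
    for (const Partial& p : parts) sum += std::exp(p.lse - m);
    const float lse = m + std::log(sum);

    std::vector<float> out(parts[0].o.size(), 0.0f);
    for (const Partial& p : parts) {
        const float weight = std::exp(p.lse - lse);  // share of the total softmax mass
        for (std::size_t k = 0; k < out.size(); ++k) out[k] += weight * p.o[k];
    }
    if (global_lse != nullptr) *global_lse = lse;
    return out;
}

// Split the KV rows into num_splits contiguous chunks, run phase 1 on each,
// then reduce.
std::vector<float> split_kv_attention(const std::vector<float>& scores, const Matrix& v,
                                      std::size_t num_splits, float* global_lse) {
    assert(!scores.empty() && scores.size() == v.size() && num_splits > 0);
    const std::size_t n = scores.size();
    const std::size_t chunk = (n + num_splits - 1) / num_splits;
    std::vector<Partial> parts;
    for (std::size_t s = 0; s < num_splits; ++s) {
        const std::size_t begin = std::min(n, s * chunk);
        const std::size_t end = std::min(n, begin + chunk);
        parts.push_back(attend_chunk(scores, v, begin, end));
    }
    return reduce_splits(parts, global_lse);
}
```

Empty trailing chunks (when the split count does not divide the sequence evenly) carry an LSE of negative infinity and therefore receive zero weight in the reduction.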

Code Reference

Source Location

Signature

template<
    class ElementOut,
    class ElementAcc,
    class ElementScale,
    size_t kNumHeads,
    size_t kHeadDimLatent,
    int kMaxSplits
>
struct Sm100FmhaMlaReductionKernel {
    struct Arguments {
        ElementAcc* ptr_oaccum;
        ElementOut* ptr_o;
        ElementAcc* ptr_lseaccum;
        ElementAcc* ptr_lse;
        ElementScale scale;
        int num_batches;
        int split_kv;
        int dim_k;
        int* ptr_seq;
        int* ptr_split_kv;
        int tile_shape_s;
    };

    static Params to_underlying_arguments(Arguments const& args, void* workspace);
    static size_t get_workspace_size(Arguments const& args);
    static dim3 get_grid_shape(Params const& params);
    static dim3 get_block_shape();
    static bool can_implement(Arguments const& args);
    CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw);
};

Import

#include "csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp"

I/O Contract

Inputs

Name | Type | Required | Description
ptr_oaccum | ElementAcc* | Yes | Partial (per-split) output tensor from split-KV attention, shape [num_batches, kNumHeads, split_kv, kHeadDimLatent]
ptr_lseaccum | ElementAcc* | Yes | Per-split log-sum-exp values, shape [num_batches, kNumHeads, split_kv]
scale | ElementScale | Yes | Softmax scaling factor (default 1.0f)
num_batches | int | Yes | Number of batches in the attention operation
split_kv | int | Yes | Number of KV splits used during the partial attention computation
dim_k | int | Yes | KV sequence length (applies to every batch when ptr_seq is nullptr)
ptr_seq | int* | No | Per-batch KV sequence lengths (nullptr for a uniform length)
ptr_split_kv | int* | No | Per-batch split counts (nullptr for a uniform split count)
tile_shape_s | int | Yes | Tile size along the KV sequence dimension (default 128)
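One plausible way dim_k (or ptr_seq), split_kv (or ptr_split_kv) and tile_shape_s could interact is sketched below, assuming each batch's KV sequence is cut into tiles of tile_shape_s positions and the tiles are spread evenly over the requested splits, so that trailing splits may end up empty. The helper names and the even-distribution rule are assumptions for illustration, not taken from the kernel source.

```cpp
#include <cassert>

// Hypothetical helper: integer ceiling division for positive operands.
int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Hypothetical helper (not part of the kernel's API): number of non-empty
// splits for one batch, ASSUMING tiles of tile_shape_s KV positions are
// distributed evenly across the requested splits.
int effective_splits(int batch, int dim_k, int split_kv, const int* ptr_seq,
                     const int* ptr_split_kv, int tile_shape_s) {
    assert(tile_shape_s > 0);
    // nullptr means "same value for every batch", as in the I/O contract.
    int seq_len = ptr_seq ? ptr_seq[batch] : dim_k;
    int splits = ptr_split_kv ? ptr_split_kv[batch] : split_kv;
    assert(seq_len > 0 && splits > 0);
    int tiles = ceil_div(seq_len, tile_shape_s);
    int tiles_per_split = ceil_div(tiles, splits);
    return ceil_div(tiles, tiles_per_split);  // trailing splits may hold no tiles
}
```

Under this assumption, requesting 5 splits for a 1000-token sequence with 128-token tiles (8 tiles) gives 2 tiles per split and only 4 non-empty splits.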

Outputs

Name | Type | Description
ptr_o | ElementOut* | Final reduced attention output tensor, shape [num_batches, kNumHeads, kHeadDimLatent]
ptr_lse | ElementAcc* | Final global log-sum-exp values, shape [num_batches, kNumHeads]
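As a sizing aid, the element counts of the four buffers follow directly from the shapes in the tables above. The sketch below multiplies the dimensions out; only the products matter, so it does not depend on the stride order the kernel actually uses. The struct and function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: element counts implied by the tensor shapes in the
// I/O contract. Only the products of the dimensions matter, so the result
// is independent of stride order.
struct MlaReductionBufferSizes {
    std::size_t oaccum;    // [num_batches, kNumHeads, split_kv, kHeadDimLatent]
    std::size_t lseaccum;  // [num_batches, kNumHeads, split_kv]
    std::size_t o;         // [num_batches, kNumHeads, kHeadDimLatent]
    std::size_t lse;       // [num_batches, kNumHeads]
};

MlaReductionBufferSizes buffer_elements(std::size_t num_batches, std::size_t num_heads,
                                        std::size_t split_kv, std::size_t head_dim_latent) {
    assert(split_kv > 0);  // at least one KV split
    MlaReductionBufferSizes s{};
    s.lse = num_batches * num_heads;
    s.o = s.lse * head_dim_latent;
    s.lseaccum = s.lse * split_kv;
    s.oaccum = s.lseaccum * head_dim_latent;
    return s;
}
```

For the configuration in the usage example (kNumHeads = 32, kHeadDimLatent = 512) with a batch of 4 and 8 splits, ptr_oaccum holds 524,288 float elements (2 MiB), while ptr_o holds 65,536 cutlass::half_t elements (128 KiB).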

Usage Examples

using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
    cutlass::half_t,   // ElementOut
    float,             // ElementAcc
    float,             // ElementScale
    32,                // kNumHeads
    512,               // kHeadDimLatent
    8                  // kMaxSplits
>;

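ReductionKernel::Arguments args;
args.ptr_oaccum = oaccum_ptr;      // per-split partial outputs from the FMHA pass
args.ptr_o = output_ptr;           // receives the reduced output
args.ptr_lseaccum = lseaccum_ptr;  // per-split LSE values from the FMHA pass
args.ptr_lse = lse_ptr;            // receives the global LSE
args.scale = 1.0f;
args.num_batches = batch_size;
args.split_kv = num_splits;        // should not exceed kMaxSplits
args.dim_k = seq_len;
args.ptr_seq = nullptr;            // uniform sequence length for all batches
args.ptr_split_kv = nullptr;       // uniform split count for all batches
args.tile_shape_s = 128;           // tile size along the KV sequence dimension

if (!ReductionKernel::can_implement(args)) {
    // Unsupported configuration: fall back or report an error here.
}

// Passing nullptr assumes get_workspace_size(args) returns 0.
auto params = ReductionKernel::to_underlying_arguments(args, nullptr);
dim3 grid = ReductionKernel::get_grid_shape(params);
dim3 block = ReductionKernel::get_block_shape();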
