Implementation:Sgl project Sglang SM100 MLA Reduction
| Knowledge Sources | |
|---|---|
| Domains | GPU Kernel, Attention |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Reduction kernel for combining split-K partial results from the MLA (Multi-head Latent Attention) attention kernel on NVIDIA SM100 GPUs.
Description
sm100_fmha_mla_reduction.hpp implements the Sm100FmhaMlaReductionKernel template within the cutlass::fmha::kernel namespace. This kernel reduces partial attention outputs across KV splits when using split-K parallelism, which divides the KV sequence across multiple thread blocks for higher GPU utilization on long sequences.
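To make the split-K idea concrete, here is one way a KV sequence could be divided into per-split ranges. This is an illustrative sketch, not the kernel's code. It assumes the sequence is cut into tiles of tile_shape_s positions (the field of that name in the Arguments struct) and that each split takes a contiguous run of tiles. The helpers ceil_div, KvRange, and split_range are hypothetical.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper: integer ceiling division for positive operands.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Half-open KV range [begin, end) covered by one split.
struct KvRange {
  int begin;
  int end;
};

// Sketch (assumption): cut seq_len positions into tiles of tile_shape_s and give
// each split a contiguous run of ceil(num_tiles / split_kv) tiles.
// Trailing splits may receive an empty range.
KvRange split_range(int seq_len, int tile_shape_s, int split_kv, int split_idx) {
  assert(seq_len > 0 && tile_shape_s > 0 && split_kv > 0);
  assert(split_idx >= 0 && split_idx < split_kv);
  const int num_tiles = ceil_div(seq_len, tile_shape_s);
  const int tiles_per_split = ceil_div(num_tiles, split_kv);
  const int begin = std::min(split_idx * tiles_per_split * tile_shape_s, seq_len);
  const int end = std::min((split_idx + 1) * tiles_per_split * tile_shape_s, seq_len);
  return {begin, end};
}
```

With seq_len = 1000, tile_shape_s = 128, and split_kv = 4, the 8 tiles are assigned two per split: [0, 256), [256, 512), [512, 768), and [768, 1000).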
The kernel combines the splits with a numerically stable log-sum-exp (LSE) reduction:
- Reading the per-split LSE values that the attention kernel wrote to ptr_lseaccum
- Finding the maximum LSE across all splits using warp-level reductions (__shfl_xor_sync)
- Computing the global LSE as max + log(Σ exp(local_lse - max)), which keeps every exponent non-positive and so avoids overflow
- Computing each split's scaling factor as exp(local_lse - global_lse); these factors sum to 1
- Accumulating the partial outputs weighted by these scaling factors
- Writing the combined attention output to ptr_o and the global LSE to ptr_lse
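The steps above can be checked against a plain host-side reference. The following minimal C++ sketch handles a single (batch, head) pair and is not the kernel's code. It assumes each split stores a normalized partial output of length dim (row-major by split) together with its LSE, and the function name combine_splits is hypothetical. Sequential loops stand in for the warp shuffles and the per-thread accumulation.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Reference combine for one (batch, head). Illustrative layout (assumption):
// o_partial[s * dim + d] is split s's normalized partial output and
// lse_partial[s] is that split's log-sum-exp. Writes the combined output to
// o_out and returns the global LSE.
float combine_splits(const std::vector<float>& o_partial,
                     const std::vector<float>& lse_partial,
                     int dim, std::vector<float>& o_out) {
  const std::size_t num_splits = lse_partial.size();
  assert(num_splits > 0 && o_partial.size() == num_splits * static_cast<std::size_t>(dim));
  o_out.assign(dim, 0.0f);

  // Maximum LSE across splits (the kernel uses __shfl_xor_sync for this).
  float lse_max = -std::numeric_limits<float>::infinity();
  for (float l : lse_partial) lse_max = std::max(lse_max, l);
  if (lse_max == -std::numeric_limits<float>::infinity()) {
    return lse_max;  // every split is empty: this sketch returns -inf and a zero output
  }

  // Global LSE = max + log(sum_s exp(lse_s - max)); every exponent is <= 0.
  float sum = 0.0f;
  for (float l : lse_partial) sum += std::exp(l - lse_max);
  const float global_lse = lse_max + std::log(sum);

  // Weight split s by exp(lse_s - global_lse) and accumulate its partial output.
  for (std::size_t s = 0; s < num_splits; ++s) {
    const float w = std::exp(lse_partial[s] - global_lse);
    for (int d = 0; d < dim; ++d) o_out[d] += w * o_partial[s * dim + d];
  }
  return global_lse;
}
```

Each weight exp(lse_s - global_lse) equals Z_s / Z, the split's softmax denominator divided by the total denominator. The weighted sum therefore reproduces a single softmax over the full KV sequence.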
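Template parameters are ElementOut (output type), ElementAcc (accumulation type, typically float), ElementScale (type of the softmax scale), kNumHeads (number of attention heads), kHeadDimLatent (latent head dimension), and kMaxSplits (maximum number of KV splits). The kernel uses 128 threads per block (MaxThreadsPerBlock = 128) and requires kHeadDimLatent to be divisible by the thread count, so each thread handles kHeadDimLatent / 128 elements of the output vector.
The Arguments struct holds pointers to the accumulated partial output (ptr_oaccum), the final output (ptr_o), the accumulated per-split LSE (ptr_lseaccum), and the final LSE (ptr_lse). It also holds the softmax scale, the batch count (num_batches), and the split configuration (split_kv, dim_k, ptr_seq, ptr_split_kv, tile_shape_s).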
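The divisibility requirement lets the latent vector split evenly across the block's threads. The sketch below assumes a strided mapping in which thread t owns elements t, t + 128, t + 256, and so on. This mapping is an illustrative assumption, and elements_of_thread is a hypothetical helper. A strided layout is a natural choice because adjacent threads then touch adjacent addresses, which yields coalesced global memory accesses.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr int kThreads = 128;  // MaxThreadsPerBlock

// Sketch (assumption): thread tid owns latent indices tid, tid + 128, tid + 256, ...
// so consecutive threads read consecutive addresses on every iteration.
template <std::size_t kHeadDimLatent>
std::array<int, kHeadDimLatent / kThreads> elements_of_thread(int tid) {
  static_assert(kHeadDimLatent % kThreads == 0,
                "kHeadDimLatent must be divisible by the thread count");
  assert(tid >= 0 && tid < kThreads);
  std::array<int, kHeadDimLatent / kThreads> idx{};
  for (std::size_t i = 0; i < idx.size(); ++i) {
    idx[i] = tid + kThreads * static_cast<int>(i);
  }
  return idx;
}
```

For kHeadDimLatent = 512, each thread owns 4 elements. Thread 3, for example, owns indices 3, 131, 259, and 387.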
Usage
Use this kernel as the final reduction step after split-K MLA attention, in which multiple thread blocks each compute partial attention over a subset of the KV sequence.
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp
- Lines: 1-198
Signature
```cpp
namespace cutlass::fmha::kernel {

template<
  class ElementOut,
  class ElementAcc,
  class ElementScale,
  size_t kNumHeads,
  size_t kHeadDimLatent,
  int kMaxSplits
>
struct Sm100FmhaMlaReductionKernel {
  static const int SharedStorageSize = 0;
  static const int MaxThreadsPerBlock = 128;
  static const int MinBlocksPerMultiprocessor = 1;
  using ArchTag = cutlass::arch::Sm100;

  struct Arguments {
    ElementAcc* ptr_oaccum = nullptr;
    ElementOut* ptr_o = nullptr;
    ElementAcc* ptr_lseaccum = nullptr;
    ElementAcc* ptr_lse = nullptr;
    ElementScale scale = 1.f;
    int num_batches = 0;
    int split_kv = -1;
    int dim_k = -1;
    int* ptr_seq = nullptr;
    int* ptr_split_kv = nullptr;
    int tile_shape_s = 128;
  };

  // Params (the kernel parameter type) is declared in the source file; its definition is not shown here.
  static Params to_underlying_arguments(Arguments const& args, void* workspace);
  static size_t get_workspace_size(Arguments const& args);
  static dim3 get_grid_shape(Params const& params);
  static dim3 get_block_shape();
  static bool can_implement(Arguments const& args);
  CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw);
};

}  // namespace cutlass::fmha::kernel
```
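can_implement is the host-side hook for rejecting unsupported arguments. The function below is a hypothetical stand-in, not the library's implementation. It encodes only the constraints stated on this page (kHeadDimLatent divisible by 128 threads, a positive batch count) plus the assumption that split_kv must lie in [1, kMaxSplits].

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical host-side check mirroring the constraints described on this page.
// This is NOT the library's can_implement().
template <std::size_t kHeadDimLatent, int kMaxSplits>
bool reduction_args_ok(int num_batches, int split_kv) {
  constexpr int kThreads = 128;                       // MaxThreadsPerBlock
  if (kHeadDimLatent % kThreads != 0) return false;   // latent dim must split evenly
  if (num_batches <= 0) return false;                 // need at least one batch
  if (split_kv <= 0 || split_kv > kMaxSplits) return false;  // assumed bound on splits
  return true;
}
```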
Import
```cpp
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cute/tensor.hpp"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| ptr_oaccum | ElementAcc* | Yes | Pointer to accumulated partial output from split-K attention blocks |
| ptr_lseaccum | ElementAcc* | Yes | Pointer to accumulated LSE values from each split |
| scale | ElementScale | No (default: 1.0f) | Softmax scaling factor |
| num_batches | int | Yes | Number of batches in the input |
| split_kv | int | Yes | Number of KV splits used |
| dim_k | int | Yes, unless ptr_seq is set | KV sequence length applied to every batch when ptr_seq is nullptr |
| ptr_seq | int* | No | Per-batch KV sequence lengths (nullptr: use dim_k for all batches) |
| ptr_split_kv | int* | No | Per-batch split counts (nullptr: use split_kv for all batches) |
| tile_shape_s | int | No (default: 128) | Tile shape along sequence dimension |
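The per-batch arrays in this table fall back to uniform values when they are nullptr. The sketch below resolves one batch's KV length and split count in that way. It additionally assumes that splits own whole tiles of tile_shape_s positions, so a requested split count that would leave some splits without a tile shrinks to the number of non-empty splits. resolve_batch and BatchSplitConfig are hypothetical names.

```cpp
#include <cassert>

// Hypothetical helper: integer ceiling division for positive operands.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

struct BatchSplitConfig {
  int seq_len;   // KV length for this batch
  int split_kv;  // number of non-empty splits for this batch
};

BatchSplitConfig resolve_batch(int batch, int dim_k, int split_kv, int tile_shape_s,
                               const int* ptr_seq, const int* ptr_split_kv) {
  // nullptr per-batch arrays fall back to the uniform values.
  const int seq_len = (ptr_seq != nullptr) ? ptr_seq[batch] : dim_k;
  const int splits = (ptr_split_kv != nullptr) ? ptr_split_kv[batch] : split_kv;
  assert(seq_len > 0 && splits > 0 && tile_shape_s > 0);
  // Assumption: splits own whole tiles, so splits that would receive no tile are dropped.
  const int num_tiles = ceil_div(seq_len, tile_shape_s);
  const int tiles_per_split = ceil_div(num_tiles, splits);
  return {seq_len, ceil_div(num_tiles, tiles_per_split)};
}
```

For example, dim_k = 1000 with tile_shape_s = 128 gives 8 tiles. A request for 5 splits assigns 2 tiles to each split, so only 4 splits carry data.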
Outputs
| Name | Type | Description |
|---|---|---|
| ptr_o | ElementOut* | Final reduced attention output tensor |
| ptr_lse | ElementAcc* | Final global log-sum-exp values |
Usage Examples
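```cpp
// Typically invoked as part of the MLA attention pipeline.
// partial_output_ptr, final_output_ptr, partial_lse_ptr, final_lse_ptr,
// batch_size, num_splits, seq_len_kv and stream are placeholders from the caller.
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
    cutlass::half_t,  // ElementOut
    float,            // ElementAcc
    float,            // ElementScale
    64,               // kNumHeads
    512,              // kHeadDimLatent (divisible by 128 threads)
    32                // kMaxSplits
>;

ReductionKernel::Arguments args;
args.ptr_oaccum   = partial_output_ptr;  // float*: per-split partial outputs
args.ptr_o        = final_output_ptr;    // cutlass::half_t*: combined output
args.ptr_lseaccum = partial_lse_ptr;     // float*: per-split LSE
args.ptr_lse      = final_lse_ptr;       // float*: combined LSE
args.num_batches  = batch_size;
args.split_kv     = num_splits;          // should not exceed kMaxSplits
args.dim_k        = seq_len_kv;          // uniform KV length (ptr_seq left as nullptr)

if (!ReductionKernel::can_implement(args)) { /* handle unsupported configuration */ }
// get_workspace_size(args) reports any workspace needed; nullptr assumes none.
auto params = ReductionKernel::to_underlying_arguments(args, /*workspace=*/nullptr);
auto grid = ReductionKernel::get_grid_shape(params);
auto block = ReductionKernel::get_block_shape();
// Launch, e.g. with CUTLASS's generic wrapper from "cutlass/device_kernel.h":
cutlass::device_kernel<ReductionKernel>
    <<<grid, block, ReductionKernel::SharedStorageSize, stream>>>(params);
```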