Implementation:Sgl project Sglang SM100 MLA Reduction

From Leeroopedia


Knowledge Sources
Domains GPU Kernel, Attention
Last Updated 2026-02-10 00:00 GMT

Overview

Reduction kernel that combines split-K partial results from the MLA (Multi-head Latent Attention) kernel on NVIDIA SM100 GPUs.

Description

sm100_fmha_mla_reduction.hpp implements the Sm100FmhaMlaReductionKernel template within the cutlass::fmha::kernel namespace. This kernel reduces partial attention outputs across KV splits when using split-K parallelism, which divides the KV sequence across multiple thread blocks for higher GPU utilization on long sequences.
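The identity that makes such a reduction exact can be written out as follows. This is a sketch under assumptions not stated on this page: each split $s$ covers a set of KV positions $S_s$, $x_j$ is the scaled attention score for position $j$, $v_j$ is the corresponding value row, and each split emits an output $O_s$ already normalised by its own log-sum-exp $\ell_s$ (natural logarithm).

```latex
\begin{aligned}
\ell_s &= \log \sum_{j \in S_s} e^{x_j}, &
O_s &= \sum_{j \in S_s} e^{x_j - \ell_s}\, v_j, \\
\ell &= \log \sum_{s} e^{\ell_s}, &
O &= \sum_{s} e^{\ell_s - \ell}\, O_s
   = \sum_{s} \sum_{j \in S_s} e^{x_j - \ell}\, v_j .
\end{aligned}
```

The last expression is the softmax-weighted sum over the whole KV sequence, so combining the per-split results with weights $e^{\ell_s - \ell}$ reproduces the unsplit attention output.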

The kernel performs a numerically stable online softmax reduction by:

  1. Loading the per-split log-sum-exp (LSE) values that accompany the partial results
  2. Finding the maximum LSE across all splits using warp-level reductions (__shfl_xor_sync)
  3. Computing the global LSE as max_lse + log(sum of exp(local_lse - max_lse) over all splits), then the scaling factor exp(local_lse - global_lse) for each split
  4. Accumulating the weighted partial outputs using these scaling factors
  5. Writing the final combined attention output and the global LSE
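The steps above can be sketched on the host for a single (batch, head) row. This is an illustrative reference, not the kernel's code: the function names, the contiguous split-major layout of `o_accum`, the use of natural `exp`/`log`, and the limit of 32 splits (one per simulated warp lane) are all assumptions, and the case where every split is empty (all LSE values equal to -inf) is not handled.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Host-side stand-in for a warp all-reduce built on __shfl_xor_sync:
// after the five butterfly steps every one of the 32 lanes holds the maximum.
inline void warp_allreduce_max(std::vector<float>& lanes) {
    assert(lanes.size() == 32);
    for (unsigned offset = 16; offset > 0; offset /= 2) {
        const std::vector<float> prev = lanes;  // values before this exchange
        for (unsigned lane = 0; lane < 32; ++lane) {
            lanes[lane] = std::max(prev[lane], prev[lane ^ offset]);
        }
    }
}

// Combines the partial outputs of all splits for one (batch, head) row.
// o_accum holds num_splits rows of head_dim values, split after split
// (assumed layout). Writes the combined row to `out`, returns the global LSE.
inline float reduce_splits(const std::vector<float>& o_accum,
                           const std::vector<float>& lse_accum,
                           std::size_t head_dim,
                           std::vector<float>& out) {
    const std::size_t num_splits = lse_accum.size();
    assert(num_splits > 0 && num_splits <= 32);
    assert(o_accum.size() == num_splits * head_dim);

    // Steps 1-2: one LSE per lane, unused lanes padded with -inf, then max.
    std::vector<float> lanes(32, -std::numeric_limits<float>::infinity());
    std::copy(lse_accum.begin(), lse_accum.end(), lanes.begin());
    warp_allreduce_max(lanes);
    const float lse_max = lanes[0];

    // Global LSE = max + log(sum_s exp(lse_s - max)); subtracting the max
    // keeps every exponent <= 0, which avoids overflow.
    float sum = 0.f;
    for (float lse : lse_accum) sum += std::exp(lse - lse_max);
    const float global_lse = lse_max + std::log(sum);

    // Steps 3-5: scale each split by exp(lse_s - global_lse) and accumulate.
    out.assign(head_dim, 0.f);
    for (std::size_t s = 0; s < num_splits; ++s) {
        const float weight = std::exp(lse_accum[s] - global_lse);
        for (std::size_t d = 0; d < head_dim; ++d) {
            out[d] += weight * o_accum[s * head_dim + d];
        }
    }
    return global_lse;
}
```

With two splits whose LSE values are equal, each weight is 0.5 and the result is the element-wise mean of the two partial rows, which is a quick sanity check for any implementation of this scheme.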

Template parameters include ElementOut (output type), ElementAcc (accumulation type, typically float), ElementScale, kNumHeads, kHeadDimLatent, and kMaxSplits. The kernel uses 128 threads per block (MaxThreadsPerBlock = 128) and requires that kHeadDimLatent is divisible by the thread count.
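The divisibility requirement can be expressed as a compile-time check. The sketch below is hypothetical: `ReductionThreadMap`, `kElemsPerThread`, and `element_index` are illustrative names, and the strided assignment of elements to threads is an assumption rather than the kernel's documented mapping.

```cpp
#include <cstddef>

// Hypothetical helper mirroring the constraint described above.
template <std::size_t kHeadDimLatent, std::size_t MaxThreadsPerBlock = 128>
struct ReductionThreadMap {
    static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0,
                  "kHeadDimLatent must be divisible by the thread count");

    // Number of output elements each thread is responsible for.
    static constexpr std::size_t kElemsPerThread =
        kHeadDimLatent / MaxThreadsPerBlock;

    // Assumed strided mapping: thread t handles t, t + 128, t + 256, ...
    static constexpr std::size_t element_index(std::size_t thread,
                                               std::size_t i) {
        return thread + i * MaxThreadsPerBlock;
    }
};
```

For kHeadDimLatent = 512 and 128 threads, each thread covers 4 elements of the head dimension; a value such as 576 would fail the static_assert.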

The Arguments struct contains pointers to accumulated output (ptr_oaccum), final output (ptr_o), accumulated LSE (ptr_lseaccum), and final LSE (ptr_lse), along with batch dimensions and split configuration.

Usage

Use this kernel as the final reduction step after split-K MLA attention computation, where multiple thread blocks each compute partial attention over a subset of the KV sequence length.

Code Reference

Source Location

Signature

namespace cutlass::fmha::kernel {

template<
    class ElementOut,
    class ElementAcc,
    class ElementScale,
    size_t kNumHeads,
    size_t kHeadDimLatent,
    int kMaxSplits
>
struct Sm100FmhaMlaReductionKernel {
    static const int SharedStorageSize = 0;
    static const int MaxThreadsPerBlock = 128;
    static const int MinBlocksPerMultiprocessor = 1;
    using ArchTag = cutlass::arch::Sm100;

    struct Arguments {
        ElementAcc* ptr_oaccum = nullptr;
        ElementOut* ptr_o = nullptr;
        ElementAcc* ptr_lseaccum = nullptr;
        ElementAcc* ptr_lse = nullptr;
        ElementScale scale = 1.f;
        int num_batches = 0;
        int split_kv = -1;
        int dim_k = -1;
        int* ptr_seq = nullptr;
        int* ptr_split_kv = nullptr;
        int tile_shape_s = 128;
    };

    static Params to_underlying_arguments(Arguments const& args, void* workspace);
    static size_t get_workspace_size(Arguments const& args);
    static dim3 get_grid_shape(Params const& params);
    static dim3 get_block_shape();
    static bool can_implement(Arguments const& args);
    CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw);
};

}  // namespace cutlass::fmha::kernel

Import

#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cute/tensor.hpp"

I/O Contract

Inputs

Name | Type | Required | Description
ptr_oaccum | ElementAcc* | Yes | Pointer to accumulated partial output from split-K attention blocks
ptr_lseaccum | ElementAcc* | Yes | Pointer to accumulated LSE values from each split
scale | ElementScale | No (default: 1.0f) | Softmax scaling factor
num_batches | int | Yes | Number of batches in the input
split_kv | int | Yes | Number of KV splits used
dim_k | int | Yes | KV sequence dimension length
ptr_seq | int* | No | Per-batch sequence length array (nullptr for uniform)
ptr_split_kv | int* | No | Per-batch split count array (nullptr for uniform)
tile_shape_s | int | No (default: 128) | Tile shape along sequence dimension

Outputs

Name | Type | Description
ptr_o | ElementOut* | Final reduced attention output tensor
ptr_lse | ElementAcc* | Final global log-sum-exp values

Usage Examples

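// Typically invoked as part of the MLA attention pipeline:
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
    cutlass::half_t,   // ElementOut
    float,             // ElementAcc
    float,             // ElementScale
    64,                // kNumHeads
    512,               // kHeadDimLatent (must be divisible by the 128 threads per block)
    32                 // kMaxSplits
>;

ReductionKernel::Arguments args;
args.ptr_oaccum = partial_output_ptr;    // partial outputs from each KV split
args.ptr_o = final_output_ptr;           // combined output
args.ptr_lseaccum = partial_lse_ptr;     // per-split LSE values
args.ptr_lse = final_lse_ptr;            // combined LSE
args.num_batches = batch_size;
args.split_kv = num_splits;
args.dim_k = kv_seq_len;                 // placeholder name; dim_k is listed as required above

// Params must be derived from Arguments before the grid shape can be queried.
// workspace_ptr is a placeholder for a buffer of get_workspace_size(args) bytes.
auto params = ReductionKernel::to_underlying_arguments(args, workspace_ptr);
auto grid = ReductionKernel::get_grid_shape(params);
auto block = ReductionKernel::get_block_shape();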

Related Pages
