

Implementation:Sgl project Sglang SM100 MLA Device

From Leeroopedia


Knowledge Sources
Domains: CUDA Kernels, Attention, CUTLASS
Last Updated: 2026-02-10 00:00 GMT

Overview

A device-level CUTLASS 3.x orchestrator that manages kernel launch, workspace allocation, and split-K reduction for the SM100 Multi-head Latent Attention (MLA) kernel on Blackwell GPUs.

Description

sm100_mla.hpp implements the MLA device class within the cutlass::fmha::device namespace. It follows the standard CUTLASS 3.x device API pattern and serves as the top-level entry point for launching MLA decode attention on NVIDIA Blackwell (SM100) GPUs.

Core Architecture:

The class is templated on a Kernel_ type parameter (expected to be Sm100FmhaMlaKernelTmaWarpspecialized) and composes it with a ReductionKernel (Sm100FmhaMlaReductionKernel) for split-K parallelism. The maximum split count is 256.

Key Type Aliases:

  • Kernel: The main TMA warp-specialized FMHA kernel
  • ReductionKernel: Partial result reduction kernel parameterized by output types, tile shapes, and max splits
  • KernelArguments / Arguments: User-facing argument structure (problem shape, pointers, strides)
  • KernelParams / ReductionParams / Params: Internal parameter structures passed to device kernels

Primary Methods:

  • set_split_kv: Heuristically determines split-K factor based on SM count, batch size, and sequence length. Balances wave utilization against excessive splitting.
  • can_implement: Validates arguments against both the main kernel and reduction kernel constraints.
  • get_workspace_size: Computes total workspace bytes needed for both kernels.
  • maximum_active_blocks: Queries CUDA occupancy, configuring dynamic shared memory if the kernel requires more than 48 KB.
  • initialize: Converts user arguments to kernel parameters, allocates workspace for intermediate accumulation buffers (O_acc and LSE_acc when split_kv > 1), and sets dynamic shared memory attributes.
  • update: Lightweight parameter update without reinitializing shared memory configuration.
  • run (static): Launches the main FMHA kernel using cluster launch (for SM >= 90) or standard grid launch, then conditionally launches the reduction kernel when split_kv > 1.
  • run / operator(): Convenience overloads that combine initialize + run.
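
The wave-balancing idea behind set_split_kv can be sketched on the host as follows. This is a minimal illustration, not the library's code: the function name pick_split_kv, the helper ceil_div_int, and the KV tile size of 128 are assumptions made for the example. Only the cap of 256 splits comes from the class description above.

```cpp
#include <algorithm>

// Hypothetical helper: integer ceiling division.
inline int ceil_div_int(int a, int b) { return (a + b - 1) / b; }

// Hypothetical stand-alone heuristic. kv_tile = 128 is an assumed tile size;
// the real value is defined by the kernel's tile shape.
int pick_split_kv(int seq_len, int batch, int sm_count, int kv_tile = 128) {
    // A split needs at least one KV tile, so this bounds the split count.
    int max_splits = std::max(1, ceil_div_int(seq_len, kv_tile));
    // SMs that each batch entry can occupy if work is spread evenly.
    int sms_per_batch = std::max(1, sm_count / batch);
    // First guess: as many splits as there are SMs per batch entry.
    int split = std::min(max_splits, sms_per_batch);
    // KV tiles each CTA walks sequentially under that guess.
    int tiles_per_split = ceil_div_int(max_splits, split);
    // Fewest splits that keep the same per-CTA tile count, capped at 256.
    return std::min(ceil_div_int(max_splits, tiles_per_split), 256);
}
```

With seq_len = 1280, batch = 16, and 96 SMs, there are 10 KV tiles and 6 SMs per batch entry. Six splits would still leave some CTAs with 2 tiles, so the sketch settles on 5 splits of 2 tiles each, which gives the same critical path with less reduction work.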

Split-K Workflow: When split_kv > 1, each CTA processes a subset of KV tiles and writes partial O and LSE results to workspace memory. After the main kernel completes, the reduction kernel combines these partial results into the final output.
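
The combination step can be illustrated with a host-side reference. This is a sketch under stated assumptions, not the reduction kernel itself: it assumes each split stores a softmax-normalized partial output together with the natural-log LSE of its scores, and the names Partial, Combined, and combine_splits are hypothetical. The device kernel may use a different log base or memory layout.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical per-split result for one (head, batch) pair.
struct Partial {
    std::vector<float> o;  // partial output, normalized within the split
    float lse;             // log-sum-exp of the split's attention scores
};

struct Combined {
    std::vector<float> o;
    float lse;
};

// Merges per-split results into the output of a single full softmax.
// Requires at least one split.
Combined combine_splits(const std::vector<Partial>& parts) {
    // Subtract the largest LSE before exponentiating to avoid overflow.
    float lse_max = parts.front().lse;
    for (const Partial& p : parts) lse_max = std::max(lse_max, p.lse);

    float sum = 0.0f;
    for (const Partial& p : parts) sum += std::exp(p.lse - lse_max);
    const float lse = lse_max + std::log(sum);

    Combined out{std::vector<float>(parts.front().o.size(), 0.0f), lse};
    for (const Partial& p : parts) {
        // Fraction of the total softmax mass that belongs to this split.
        const float w = std::exp(p.lse - lse);
        for (std::size_t d = 0; d < out.o.size(); ++d) out.o[d] += w * p.o[d];
    }
    return out;
}
```

For example, a split holding one score of 0 with value 1.0 and a split holding one score of log(3) with value 5.0 have softmax weights 1/4 and 3/4, so the merged output is 4.0 and the merged LSE is log(4).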

Usage

Instantiate the MLA class with the appropriate kernel type, call set_split_kv to resolve the split factor, query get_workspace_size and allocate that many bytes of device memory, then invoke run() with the arguments, workspace pointer, and CUDA stream.

Code Reference

Source Location

Signature

namespace cutlass::fmha::device {

template<class Kernel_>
class MLA {
public:
    using Kernel = Kernel_;
    using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
        typename Kernel::ElementOut,
        typename Kernel::ElementAcc,
        typename Kernel::ElementAcc,
        Kernel::TileShapeH::value,
        Kernel::TileShapeL::value,
        256 /*Max split*/
    >;

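    // Aliases used by the declarations below (see Key Type Aliases).
    using KernelArguments = typename Kernel::Arguments;
    using Arguments = KernelArguments;
    using KernelParams = typename Kernel::Params;
    using ReductionParams = typename ReductionKernel::Params;
    struct Params { KernelParams fmha_params; ReductionParams reduction_params; };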

    static void set_split_kv(KernelArguments& args);
    static Status can_implement(Arguments const& args);
    static size_t get_workspace_size(Arguments const& args);
    static int maximum_active_blocks(int smem_capacity = -1);

    Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr);
    Status update(Arguments const& args, void* workspace = nullptr);
    static Status run(Params& params, cudaStream_t stream = nullptr);
    Status run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr);
    Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr);
};

} // namespace cutlass::fmha::device

Import

#include "cutlass_sm100_mla/device/sm100_mla.hpp"

I/O Contract

Inputs

Name | Type | Required | Description
problem_shape | Shape<H, K, D, B> | Yes | (num_heads=128, seq_len, (d_latent, d_rope), batch)
mainloop.ptr_q_latent | Element* | Yes | Pointer to Q latent tensor
mainloop.ptr_q_rope | Element* | Yes | Pointer to Q rope tensor
mainloop.ptr_c_latent | Element* | Yes | Pointer to compressed KV latent tensor
mainloop.ptr_k_rope | Element* | Yes | Pointer to K rope tensor
mainloop.softmax_scale | float | Yes | Softmax scaling factor
epilogue.ptr_o | ElementOut* | Yes | Pointer to output tensor
epilogue.ptr_lse | ElementLSE* | No | Pointer to log-sum-exp output
hw_info | KernelHardwareInfo | Yes | SM count and hardware configuration
split_kv | int | No | Split-K factor (-1 for auto-tuning)
workspace | void* | No | Device workspace memory (required when split_kv > 1)
stream | cudaStream_t | No | CUDA stream for kernel execution
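
To give a feel for how much workspace a split run needs, the sketch below estimates the scratch size for the O_acc and LSE_acc buffers. It is an assumption-laden approximation, not the formula used by get_workspace_size: it assumes a float accumulator, one row of d_latent values plus one LSE value per (head, batch, split), and no extra space for the reduction kernel. The name mla_workspace_bytes is hypothetical. In real code, always use the value returned by get_workspace_size.

```cpp
#include <cstddef>

// Hypothetical estimate of split-KV scratch space in bytes.
// Assumes float accumulators and one (d_latent + 1)-wide row per
// (head, batch, split); the real layout is defined by the kernel.
std::size_t mla_workspace_bytes(int num_heads, int d_latent, int batch, int split_kv) {
    if (split_kv <= 1) return 0;  // no partial results are staged
    const std::size_t rows =
        static_cast<std::size_t>(num_heads) * batch * split_kv;
    return rows * (static_cast<std::size_t>(d_latent) + 1) * sizeof(float);
}
```

Under these assumptions, 128 heads, d_latent = 512, batch = 2, and split_kv = 4 need 128 * 2 * 4 * 513 * 4 bytes, roughly 2 MiB.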

Outputs

Name | Type | Description
ptr_o | ElementOut* | Output attention tensor (H x D_latent x B)
ptr_lse | ElementLSE* | Log-sum-exp values (H x B) for numerical stability
Status | cutlass::Status | kSuccess, kInvalid, kErrorWorkspaceNull, or kErrorInternal

Usage Examples

Complete MLA Launch Sequence

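// Assumes `using namespace cute;` and that TileScheduler, seq_len, batch_size,
// softmax_scale, device_id, sm_count, stream, and all pointers and strides
// are defined by the caller.
using MlaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
    Shape<_128, _512, Shape<_512, _64>>,
    cutlass::half_t, float, cutlass::half_t, float,
    TileScheduler>;

using MlaDevice = cutlass::fmha::device::MLA<MlaKernel>;

MlaDevice::Arguments args{
    {128, seq_len, {512, 64}, batch_size},   // problem_shape (H, K, D, B)
    {softmax_scale, ptr_q_latent, stride_q, ptr_q_rope, stride_qr,
     ptr_c_latent, stride_c, ptr_k_rope, stride_kr},   // mainloop
    {ptr_o, stride_o, ptr_lse, stride_lse},  // epilogue
    {device_id, sm_count},  // hw_info: KernelHardwareInfo lists device_id before sm_count
    -1 /*auto split*/
};

// Resolve split_kv first: the workspace size depends on it.
MlaDevice::set_split_kv(args);
if (MlaDevice::can_implement(args) != cutlass::Status::kSuccess) {
    /* unsupported problem: handle the error */
}

size_t workspace_size = MlaDevice::get_workspace_size(args);
void* workspace = nullptr;
if (cudaMalloc(&workspace, workspace_size) != cudaSuccess) {
    /* allocation failed: handle the error */
}

MlaDevice mla_op;
cutlass::Status status = mla_op.run(args, workspace, stream);
if (status != cutlass::Status::kSuccess) {
    /* launch failed: handle the error */
}

// Wait for the kernels to finish before releasing the workspace.
cudaStreamSynchronize(stream);
cudaFree(workspace);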
