Implementation:Sgl project Sglang SM100 MLA Device
| Knowledge Sources | |
|---|---|
| Domains | CUDA Kernels, Attention, CUTLASS |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
A device-level CUTLASS 3.x orchestrator that manages kernel launch, workspace allocation, and split-K reduction for the SM100 Multi-head Latent Attention (MLA) kernel on Blackwell GPUs.
Description
sm100_mla.hpp implements the MLA device class within the cutlass::fmha::device namespace. It follows the standard CUTLASS 3.x device API pattern and serves as the top-level entry point for launching MLA decode attention on NVIDIA Blackwell (SM100) GPUs.
Core Architecture:
The class is templated on a Kernel_ type parameter (expected to be Sm100FmhaMlaKernelTmaWarpspecialized). It composes that kernel with a ReductionKernel (Sm100FmhaMlaReductionKernel) to support split-KV (split-K) parallelism. The reduction kernel is instantiated with a maximum of 256 splits.
Key Type Aliases:
- Kernel: The main TMA warp-specialized FMHA kernel
- ReductionKernel: Partial result reduction kernel parameterized by output types, tile shapes, and max splits
- KernelArguments / Arguments: User-facing argument structure (problem shape, pointers, strides)
- KernelParams / ReductionParams / Params: Internal parameter structures passed to device kernels
Primary Methods:
- set_split_kv: Heuristically chooses the split-KV factor from the SM count, batch size, and sequence length. It balances wave utilization against the overhead of excessive splitting.
- can_implement: Validates arguments against both the main kernel and reduction kernel constraints.
- get_workspace_size: Computes total workspace bytes needed for both kernels.
- maximum_active_blocks: Queries CUDA occupancy, configuring dynamic shared memory if the kernel requires more than 48 KB.
- initialize: Converts user arguments to kernel parameters, allocates workspace for intermediate accumulation buffers (O_acc and LSE_acc when split_kv > 1), and sets dynamic shared memory attributes.
- update: Lightweight parameter update without reinitializing shared memory configuration.
- run (static): Launches the main FMHA kernel with a cluster launch (for SM >= 90) or a standard grid launch. When split_kv > 1, it then launches the reduction kernel.
- run / operator(): Convenience overloads that combine initialize + run.
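The following host-side sketch illustrates the kind of wave-aware heuristic that set_split_kv describes. It is not the verbatim sm100_mla.hpp code. The 128-token KV tile, the cap at the reduction kernel's 256-split limit, and the names choose_split_kv and ceil_div are assumptions made for illustration.

```cpp
#include <algorithm>

// Hypothetical helper: integer ceiling division for positive operands.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Hypothetical sketch of a wave-aware split-KV heuristic (assumed, not the
// shipped rule). kv_tile = 128 tokens per KV tile is an assumption.
int choose_split_kv(int requested_split, int sm_count, int batch, int seq_len,
                    int kv_tile = 128, int max_split = 256) {
  if (requested_split >= 1) return requested_split;  // caller fixed the factor

  // Number of KV tiles; never split finer than one tile per split.
  int max_splits = std::min(ceil_div(seq_len, kv_tile), max_split);
  max_splits = std::max(max_splits, 1);

  // Give each batch entry roughly an equal share of the SMs.
  int sms_per_batch = std::max(1, sm_count / std::max(batch, 1));
  int split_heur = std::min(max_splits, sms_per_batch);

  // KV tiles each split must process at that split count.
  int tiles_per_split = ceil_div(max_splits, split_heur);

  // Shrink the split count so every split handles the same number of tiles.
  return ceil_div(max_splits, tiles_per_split);
}
```

For example, with an assumed 148 SMs, a single 8,192-token request spreads over 64 splits. A batch of 8 such requests settles on 16 splits, so each split processes exactly four KV tiles instead of leaving one straggler split with less work.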
Split-K Workflow:
When split_kv > 1, each CTA processes a subset of KV tiles and writes partial O and LSE results to workspace memory. After the main kernel completes, the reduction kernel combines these partial results into the final output.
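The combine step can be illustrated with a scalar, host-side sketch. It assumes that each split stores an output row already normalized by its own softmax sum, together with that split's natural-log LSE. The real reduction kernel's workspace layout, log base, and vectorization may differ, and the names Combined and combine_splits are hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Hypothetical result of combining all splits for one (head, batch) row.
struct Combined {
  std::vector<float> o;  // final output row
  float lse;             // global log-sum-exp
};

// partial_o[s][d]: output of split s, normalized by split s's own softmax sum.
// partial_lse[s]:  natural-log log-sum-exp of split s's scores.
// Assumes at least one split and at least one finite partial LSE.
Combined combine_splits(const std::vector<std::vector<float>>& partial_o,
                        const std::vector<float>& partial_lse) {
  const size_t dim = partial_o.front().size();

  // Subtract the largest LSE before exponentiating to avoid overflow.
  float max_lse = -std::numeric_limits<float>::infinity();
  for (float l : partial_lse) max_lse = std::max(max_lse, l);
  float sum = 0.f;
  for (float l : partial_lse) sum += std::exp(l - max_lse);
  const float lse = max_lse + std::log(sum);

  // Weight each split by its share of the total softmax mass.
  std::vector<float> o(dim, 0.f);
  for (size_t s = 0; s < partial_o.size(); ++s) {
    const float w = std::exp(partial_lse[s] - lse);
    for (size_t d = 0; d < dim; ++d) o[d] += w * partial_o[s][d];
  }
  return {o, lse};
}
```

Because the weights exp(LSE_s - LSE) sum to one, the combined row equals the softmax-weighted average that a single unsplit pass over all KV tiles would have produced.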
Usage
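Instantiate the MLA class with the appropriate kernel type and call set_split_kv to determine parallelism. Next, validate the arguments with can_implement and allocate get_workspace_size bytes of workspace. Finally, invoke run() with the arguments, workspace pointer, and CUDA stream.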
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/attention/cutlass_sm100_mla/device/sm100_mla.hpp
- Lines: 1-359
Signature
```cpp
namespace cutlass::fmha::device {

template <class Kernel_>
class MLA {
 public:
  using Kernel = Kernel_;
  using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
      typename Kernel::ElementOut,
      typename Kernel::ElementAcc,
      typename Kernel::ElementAcc,
      Kernel::TileShapeH::value,
      Kernel::TileShapeL::value,
      256 /*Max split*/
  >;
  using Arguments = typename Kernel::Arguments;
  // KernelArguments, KernelParams and ReductionParams aliases omitted (see Key Type Aliases).
  struct Params { KernelParams fmha_params; ReductionParams reduction_params; };

  static void set_split_kv(KernelArguments& args);
  static Status can_implement(Arguments const& args);
  static size_t get_workspace_size(Arguments const& args);
  static int maximum_active_blocks(int smem_capacity = -1);

  Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr);
  Status update(Arguments const& args, void* workspace = nullptr);

  static Status run(Params& params, cudaStream_t stream = nullptr);
  Status run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr);
  Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr);
};

}  // namespace cutlass::fmha::device
```
Import
```cpp
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| problem_shape | Shape<H, K, D, B> | Yes | (num_heads=128, seq_len, (d_latent, d_rope), batch) |
| mainloop.ptr_q_latent | Element* | Yes | Pointer to Q latent tensor |
| mainloop.ptr_q_rope | Element* | Yes | Pointer to Q rope tensor |
| mainloop.ptr_c_latent | Element* | Yes | Pointer to compressed KV latent tensor |
| mainloop.ptr_k_rope | Element* | Yes | Pointer to K rope tensor |
| mainloop.softmax_scale | float | Yes | Softmax scaling factor |
| epilogue.ptr_o | ElementOut* | Yes | Pointer to output tensor |
| epilogue.ptr_lse | ElementLSE* | No | Pointer to log-sum-exp output |
| hw_info | KernelHardwareInfo | Yes | SM count and hardware configuration |
| split_kv | int | No | Split-KV factor (-1 lets set_split_kv choose it heuristically) |
| workspace | void* | No | Device workspace memory (required when split_kv > 1) |
| stream | cudaStream_t | No | CUDA stream for kernel execution |
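The following rough estimate shows how the split-KV workspace grows with its inputs. It assumes that each (split, batch, head) entry stores one partial output row of d_latent float accumulators plus one float partial LSE, and that no workspace is needed when split_kv <= 1. The real get_workspace_size may size or align the buffers differently, and estimate_split_workspace is a hypothetical name.

```cpp
#include <cstddef>

// Hypothetical estimate of split-KV workspace bytes (O_acc + LSE_acc).
// Element types default to float accumulators; this is an assumption.
template <class ElementAcc = float, class ElementLse = float>
std::size_t estimate_split_workspace(int num_heads, int d_latent, int batch,
                                     int split_kv) {
  if (split_kv <= 1) return 0;  // single split writes ptr_o directly
  // One partial row per (split, batch, head).
  const std::size_t rows =
      static_cast<std::size_t>(split_kv) * batch * num_heads;
  return rows * (sizeof(ElementAcc) * d_latent + sizeof(ElementLse));
}
```

Under these assumptions, 128 heads, d_latent = 512, a batch of 8, and 16 splits need about 32 MiB of workspace. Because the workspace scales linearly with split_kv, the split heuristic avoids splitting more than necessary.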
Outputs
| Name | Type | Description |
|---|---|---|
| ptr_o | ElementOut* | Output attention tensor (H x D_latent x B) |
| ptr_lse | ElementLSE* | Log-sum-exp of the attention scores for each head and batch entry (H x B) |
| Status | cutlass::Status | kSuccess, kInvalid, kErrorWorkspaceNull, or kErrorInternal |
Usage Examples
Complete MLA Launch Sequence
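```cpp
// Placeholders assumed to be defined by the caller: seq_len, batch_size, softmax_scale,
// the ptr_* device pointers, the stride_* values, sm_count, stream, and TileScheduler.
using MlaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
    Shape<_128, _512, Shape<_512, _64>>,             // tile shape: (H, KV, (D_latent, D_rope))
    cutlass::half_t, float, cutlass::half_t, float,  // input, accumulator, output, LSE element types
    TileScheduler>;
using MlaDevice = cutlass::fmha::device::MLA<MlaKernel>;

MlaDevice::Arguments args{
    {128, seq_len, {512, 64}, batch_size},             // problem shape (H, K, (D_latent, D_rope), B)
    {softmax_scale, ptr_q_latent, stride_q, ptr_q_rope, stride_qr,
     ptr_c_latent, stride_c, ptr_k_rope, stride_kr},   // mainloop
    {ptr_o, stride_o, ptr_lse, stride_lse},            // epilogue
    {/*device_id=*/0, sm_count},                       // hw_info: device_id precedes sm_count
    -1 /*auto split*/
};

MlaDevice::set_split_kv(args);  // resolves split_kv = -1 to a concrete factor
if (MlaDevice::can_implement(args) != cutlass::Status::kSuccess) {
  // Unsupported configuration: do not launch.
}

size_t workspace_size = MlaDevice::get_workspace_size(args);
void* workspace = nullptr;
if (workspace_size > 0) {
  cudaMalloc(&workspace, workspace_size);  // holds O_acc / LSE_acc when split_kv > 1
}

MlaDevice mla_op;
cutlass::Status status = mla_op.run(args, workspace, stream);  // initialize + launch
// Check status before reading ptr_o / ptr_lse.
cudaStreamSynchronize(stream);
cudaFree(workspace);
```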