Implementation:Vllm project Vllm SM100 MLA Device
| Knowledge Sources | |
|---|---|
| Domains | Attention, CUDA, GPU_Kernels |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides a device-level interface for launching CUTLASS 3.x-style Multi-head Latent Attention (MLA) kernels optimized for the NVIDIA SM100 (Blackwell) architecture.
Description
The cutlass::fmha::device::MLA template class wraps the SM100 FMHA MLA kernel and its associated reduction kernel, exposing a CUTLASS 3.x-compatible API for launching attention computations on SM100 GPUs. It handles workspace management, split-KV scheduling with wave-aware heuristics, kernel launch via the cluster launch API, and an optional reduction pass when the split-KV count is greater than one. The code was adapted from SGLang PR #6929 by Alcanderian (JieXin Liang).
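To make the idea of a wave-aware split-KV heuristic concrete, here is a minimal host-side sketch. It is not the code in sm100_mla.hpp: the function name `pick_split_kv`, the helper `ceil_div`, and the `max_splits` cap are assumptions introduced for illustration. The sketch first limits the split count to the SMs available per batch entry, then rebalances it so that the KV work divides evenly across the resulting passes.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper (not from sm100_mla.hpp): integer ceiling division.
inline int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Illustrative wave-aware choice of the split-KV count.
//   batch      : number of sequences in the batch
//   sm_count   : number of SMs on the device
//   max_splits : assumed upper bound on useful splits for this sequence length
inline int pick_split_kv(int batch, int sm_count, int max_splits) {
  assert(batch > 0 && sm_count > 0 && max_splits > 0);
  // Never use more splits than there are SMs available per batch entry.
  int sms_per_batch = std::max(1, sm_count / batch);
  int split_heur = std::min(max_splits, sms_per_batch);
  // Number of passes needed to cover max_splits chunks with split_heur splits.
  int k_waves = ceil_div(max_splits, split_heur);
  // Rebalance so every pass carries about the same number of chunks.
  return ceil_div(max_splits, k_waves);
}
```

For example, with 148 SMs, a batch of 25, and a cap of 8, the first estimate is 5 splits, which would need two uneven passes; the rebalancing step returns 4 instead.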
Usage
This header is compiled when building the vLLM CUTLASS MLA attention backend targeting NVIDIA Blackwell (SM100) GPUs. It is included by the MLA attention dispatch code and is used to instantiate and run fused multi-head latent attention kernels during inference.
Code Reference
Source Location
- Repository: vllm
- File: csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
- Lines: 1-385
Signature
namespace cutlass::fmha::device {

template<class Kernel_>
class MLA {
public:
  using Kernel = Kernel_;
  using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<...>;
  using Arguments = typename Kernel::Arguments;

  static void set_split_kv(KernelArguments& args);
  static Status can_implement(Arguments const& args);
  static size_t get_workspace_size(Arguments const& args);
  static int maximum_active_blocks(int smem_capacity = -1);

  Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr);
  Status update(Arguments const& args, void* workspace = nullptr);
  static Status run(Params& params, cudaStream_t stream = nullptr);
  Status run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr);
  Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr);
};

} // namespace cutlass::fmha::device
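The relationship between `initialize`, the two `run` overloads, and `operator()` can be illustrated with a small mock. This is a sketch under assumptions, not the vLLM implementation: every type here (`Status`, `Arguments`, `Params`, `MockMlaOp`) is a stand-in defined in the block, and the composition shown (the argument-taking `run` calls `initialize` and then the static `run`) follows the pattern of CUTLASS 3.x device adapters and is assumed rather than confirmed for this header. Kernel launches are simulated by a counter.

```cpp
#include <cassert>

// All names below are illustrative stand-ins, not CUTLASS or vLLM types.
enum class Status { kSuccess, kErrorWorkspaceNull, kErrorInvalidProblem };

struct Arguments {
  int split_kv = -1;  // -1 means "not yet resolved"
};

struct Params {
  int split_kv = 1;
  int launches = 0;  // counts simulated kernel launches
};

class MockMlaOp {
 public:
  // Validate arguments and translate them into launch parameters.
  Status initialize(const Arguments& args, void* workspace = nullptr) {
    if (args.split_kv < 1) return Status::kErrorInvalidProblem;
    if (args.split_kv > 1 && workspace == nullptr) {
      return Status::kErrorWorkspaceNull;  // partial results need a buffer
    }
    params_.split_kv = args.split_kv;
    params_.launches = 0;
    return Status::kSuccess;
  }

  // Launch from prepared parameters: main kernel, then optional reduction.
  static Status run(Params& params) {
    params.launches += 1;                           // main attention kernel
    if (params.split_kv > 1) params.launches += 1;  // reduction kernel
    return Status::kSuccess;
  }

  // Convenience overload: initialize, then launch.
  Status run(const Arguments& args, void* workspace = nullptr) {
    Status status = initialize(args, workspace);
    if (status != Status::kSuccess) return status;
    return run(params_);
  }

  // Function-call syntax forwards to run(args, workspace).
  Status operator()(const Arguments& args, void* workspace = nullptr) {
    return run(args, workspace);
  }

  int launches() const { return params_.launches; }

 private:
  Params params_;
};
```

Separating `initialize` from the static `run` lets a caller prepare parameters once and relaunch without repeating validation, which is the usual motivation for this split in device-level wrappers.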
Import
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| args | Arguments (KernelArguments) | Yes | Contains problem_shape (H, K, D, B), mainloop arguments (Q/K pointers, softmax scale, page table), epilogue arguments (output/LSE pointers), and hardware info |
| workspace | void* | No | GPU workspace for intermediate accumulation buffers; required when split_kv > 1 |
| stream | cudaStream_t | No | CUDA stream for kernel launch; defaults to null (default stream) |
| split_kv | int | No | Number of KV splits; auto-computed by set_split_kv if set to -1 |
Outputs
| Name | Type | Description |
|---|---|---|
| ptr_o | ElementOut* | Output attention tensor written to the epilogue output pointer |
| ptr_lse | ElementLSE* | Log-sum-exp values for numerical stability, written per head |
| Status | cutlass::Status | Return code indicating success or failure of the kernel launch |
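When the KV sequence is split, each split produces a partial output and a partial log-sum-exp, and the two must be merged into the final `ptr_o` and `ptr_lse` values. The following host-side reference shows the standard way to merge softmax-normalized partials for a single head and query row. It is a sketch of the math only, not the reduction kernel: `combine_splits` and `CombinedRow` are hypothetical names, and natural-log LSE values are assumed.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result type for one (head, query) row.
struct CombinedRow {
  std::vector<float> o;  // merged attention output
  float lse;             // merged log-sum-exp (natural log assumed)
};

// Merge per-split partial outputs, each already normalized by its own
// softmax denominator, using their log-sum-exp values as weights.
inline CombinedRow combine_splits(
    const std::vector<std::vector<float>>& partial_o,
    const std::vector<float>& partial_lse) {
  assert(!partial_o.empty() && partial_o.size() == partial_lse.size());

  // Subtract the maximum LSE before exponentiating to avoid overflow.
  float lse_max = *std::max_element(partial_lse.begin(), partial_lse.end());
  float sum = 0.0f;
  for (float l : partial_lse) sum += std::exp(l - lse_max);

  CombinedRow out;
  out.o.assign(partial_o[0].size(), 0.0f);
  out.lse = lse_max + std::log(sum);

  // Each split contributes in proportion to exp(lse_i - lse_total).
  for (std::size_t i = 0; i < partial_o.size(); ++i) {
    assert(partial_o[i].size() == out.o.size());
    float w = std::exp(partial_lse[i] - lse_max) / sum;
    for (std::size_t d = 0; d < out.o.size(); ++d) {
      out.o[d] += w * partial_o[i][d];
    }
  }
  return out;
}
```

With a single split the weights reduce to 1 and the input passes through unchanged, which is consistent with the reduction pass being needed only when split_kv > 1.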
Usage Examples
// Instantiate the MLA device operator with a specific kernel type
using MLADevice = cutlass::fmha::device::MLA<MyMLAKernel>;

// Prepare arguments; problem_shape is ordered (H, K, D, B)
typename MLADevice::Arguments args;
args.problem_shape = {num_heads, seq_len, make_shape(d_latent, d_rope), batch_size};
args.mainloop.softmax_scale = 1.0f / sqrt(head_dim);
args.mainloop.ptr_q_latent = q_latent_ptr;
args.mainloop.ptr_c_latent = c_latent_ptr;
args.epilogue.ptr_o = output_ptr;
args.hw_info.sm_count = sm_count;

// Auto-compute split_kv
MLADevice::set_split_kv(args);

// Check feasibility before allocating anything
if (MLADevice::can_implement(args) != cutlass::Status::kSuccess) {
  // Report the unsupported configuration and bail out
}

// Allocate workspace (allocate_gpu stands in for the caller's device allocator)
size_t workspace_size = MLADevice::get_workspace_size(args);
void* workspace = allocate_gpu(workspace_size);

// Launch and check the returned status
MLADevice mla_op;
auto status = mla_op.run(args, workspace, stream);
if (status != cutlass::Status::kSuccess) {
  // Handle the launch failure
}