

Implementation:Vllm project Vllm SM100 FMHA MLA TMA Warpspecialized

From Leeroopedia


Knowledge Sources
Domains Attention, CUDA, GPU_Kernels, CUTLASS
Last Updated 2026-02-08 00:00 GMT

Overview

Implements the warp-specialized SM100 FMHA (fused multi-head attention) kernel for Multi-Latent Attention (MLA), using TMA (Tensor Memory Accelerator) for efficient asynchronous memory access on NVIDIA Blackwell GPUs.
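To make the computed quantity concrete, here is a naive CPU reference for a single (query head, batch entry) row. It assumes the standard absorbed-MLA formulation, in which keys are the concatenation [c_latent | k_rope] and values are the latent cache c_latent itself, plus simple row-major, contiguous layouts. It is an illustrative sketch, not the kernel's code, and it ignores tiling, paging, and split-KV. The function name mla_reference_row is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive reference for one (head, batch) query row (hypothetical helper).
// Assumed row-major layouts:
//   q_latent[d_latent], q_rope[d_rope]
//   c_latent[seqlen][d_latent], k_rope[seqlen][d_rope]
// Returns o[d_latent] and writes the log-sum-exp of the scaled scores to *lse.
std::vector<float> mla_reference_row(const std::vector<float>& q_latent,
                                     const std::vector<float>& q_rope,
                                     const std::vector<float>& c_latent,
                                     const std::vector<float>& k_rope,
                                     std::size_t seqlen, float softmax_scale,
                                     float* lse) {
  const std::size_t d_latent = q_latent.size();
  const std::size_t d_rope = q_rope.size();
  assert(seqlen > 0);
  assert(c_latent.size() == seqlen * d_latent && k_rope.size() == seqlen * d_rope);

  // QK^T over the concatenated (d_latent + d_rope) head dimension.
  std::vector<float> p(seqlen);
  for (std::size_t j = 0; j < seqlen; ++j) {
    float acc = 0.f;
    for (std::size_t d = 0; d < d_latent; ++d) acc += q_latent[d] * c_latent[j * d_latent + d];
    for (std::size_t d = 0; d < d_rope; ++d) acc += q_rope[d] * k_rope[j * d_rope + d];
    p[j] = acc * softmax_scale;
  }

  // Stable softmax: subtract the max score before exponentiating.
  const float m = *std::max_element(p.begin(), p.end());
  float sum = 0.f;
  for (float& x : p) {
    x = std::exp(x - m);
    sum += x;
  }
  *lse = m + std::log(sum);

  // PV: the value operand is the latent cache, so the output has d_latent entries.
  std::vector<float> o(d_latent, 0.f);
  for (std::size_t j = 0; j < seqlen; ++j)
    for (std::size_t d = 0; d < d_latent; ++d) o[d] += (p[j] / sum) * c_latent[j * d_latent + d];
  return o;
}
```

Note that the output width is d_latent rather than d_latent + d_rope: the rope part only contributes to the scores.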

Description

The Sm100FmhaMlaKernelTmaWarpspecialized template struct defines the core compute kernel for multi-latent attention on the SM100 architecture. It runs in 2-SM mode (kIs2Sm = true) and uses warp specialization, assigning warps to MMA, Load, Compute, and LoadPageTable roles. Operands move through TMA-based asynchronous copy pipelines, and the QK and PV matrix multiplications are built with the CUTLASS collective builder APIs. The kernel splits the query/key head dimension into a latent part and a rope part, and supports paged attention via page tables, split-KV for long sequences, and reuse of shared-memory staging buffers between the K and V operands.
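The producer/consumer structure behind warp specialization and asynchronous copy pipelines can be illustrated with a CPU analogy. In the sketch, one thread plays the Load role, filling a ring of pipeline stages, while another plays the consuming role (MMA/Compute), draining each stage and handing it back. On the GPU, the stages live in shared memory and are guarded by hardware barriers rather than a mutex and condition variable. StagePipeline and run_pipeline are hypothetical names, and the two-stage depth is arbitrary. This is not CUTLASS code.

```cpp
#include <array>
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <thread>

// CPU analogy of a multi-stage pipeline: each stage is either "full"
// (data ready for the consumer) or "empty" (free for the producer).
template <std::size_t kStages>
class StagePipeline {
 public:
  // Producer (load role): wait until the stage is empty, fill it, mark it full.
  void produce(std::size_t tile, int value) {
    std::unique_lock<std::mutex> lock(mu_);
    const std::size_t s = tile % kStages;
    cv_.wait(lock, [&] { return !full_[s]; });
    data_[s] = value;
    full_[s] = true;
    cv_.notify_all();
  }

  // Consumer (compute role): wait until the stage is full, read it, release it.
  int consume(std::size_t tile) {
    std::unique_lock<std::mutex> lock(mu_);
    const std::size_t s = tile % kStages;
    cv_.wait(lock, [&] { return full_[s]; });
    const int value = data_[s];
    full_[s] = false;
    cv_.notify_all();
    return value;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::array<int, kStages> data_{};
  std::array<bool, kStages> full_{};  // all stages start empty
};

// Runs the two roles concurrently over num_tiles tiles and returns the
// sum accumulated by the consuming role.
long long run_pipeline(std::size_t num_tiles) {
  StagePipeline<2> pipe;
  long long total = 0;
  std::thread load([&] {
    for (std::size_t i = 0; i < num_tiles; ++i) pipe.produce(i, static_cast<int>(i));
  });
  std::thread compute([&] {
    for (std::size_t i = 0; i < num_tiles; ++i) total += pipe.consume(i);
  });
  load.join();
  compute.join();
  return total;
}
```

With two stages, the producer can run at most two tiles ahead of the consumer. This is the overlap that lets memory traffic hide behind math when each role is given its own warps.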

Usage

This header is compiled as part of the vLLM CUTLASS MLA attention backend for SM100 GPUs. It is instantiated by the device-level MLA class and executed as a CUDA kernel during inference when Blackwell hardware is detected.

Code Reference

Source Location

Signature

namespace cutlass::fmha::kernel {

template<
    class TileShape,
    class Element_,
    class ElementAcc_,
    class ElementOut_,
    class ElementLSE_,
    class TileScheduler,
    bool kIsCpAsync = false
>
struct Sm100FmhaMlaKernelTmaWarpspecialized {
  using Element = Element_;
  using ElementAcc = ElementAcc_;
  using ElementOut = ElementOut_;
  using ElementLSE = ElementLSE_;

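  static const bool kIs2Sm = true;           // CTA-pair (2-SM) execution
  static const int MaxThreadsPerBlock = 256; // 8 warps of 32 threads

  // Problem shape: (H, seqlen, (D_latent, D_rope), batch). TileShapeH and
  // TileShapeD name the head and head-dim modes of TileShape (definitions elided).
  using ProblemShape = Shape<TileShapeH, int, TileShapeD, int>;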

  struct MainloopArguments { ... };
  struct EpilogueArguments { ... };
  struct Arguments { ... };
  struct Params { ... };

  static Status can_implement(Arguments const& args);
  static size_t get_workspace_size(Arguments const& args);
  static Params to_underlying_arguments(Arguments const& args, void* workspace);
};

} // namespace cutlass::fmha::kernel
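With MaxThreadsPerBlock = 256, a block holds 8 warps, and a warp-specialized kernel selects each warp's role from its warp index. The sketch below shows that dispatch pattern. The warp-to-role table is HYPOTHETICAL and only illustrates the mechanism: the role names come from the description above, but the real kernel's assignment may differ. WarpRole and role_of_thread are illustrative names.

```cpp
#include <array>
#include <cassert>

// Roles named in this kernel's description, plus an idle slot.
enum class WarpRole { kCompute, kMma, kLoad, kLoadPageTable, kEmpty };

constexpr int kThreadsPerWarp = 32;
constexpr int kMaxThreadsPerBlock = 256;                               // mirrors MaxThreadsPerBlock
constexpr int kWarpsPerBlock = kMaxThreadsPerBlock / kThreadsPerWarp;  // 8 warps

// HYPOTHETICAL assignment, for illustration only.
constexpr std::array<WarpRole, kWarpsPerBlock> kWarpRoles = {
    WarpRole::kCompute, WarpRole::kCompute, WarpRole::kCompute, WarpRole::kCompute,
    WarpRole::kMma,     WarpRole::kLoad,    WarpRole::kLoadPageTable, WarpRole::kEmpty};

// Maps a linear thread index (threadIdx.x in CUDA) to the role of its warp.
inline WarpRole role_of_thread(int thread_idx) {
  assert(thread_idx >= 0 && thread_idx < kMaxThreadsPerBlock);
  return kWarpRoles[thread_idx / kThreadsPerWarp];
}
```

In device code, a `switch` on this role sends each warp into its own loop (issuing loads, issuing MMAs, or running softmax/epilogue math). All 32 threads of a warp take the same branch, so the specialization causes no intra-warp divergence.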

Import

#include "cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp"

I/O Contract

Inputs

Name | Type | Required | Description
ptr_q_latent | Element* | Yes | Pointer to the query latent tensor with stride (num_heads, 1, batch)
ptr_q_rope | Element* | Yes | Pointer to the query rope tensor for rotary position encoding
ptr_c_latent | Element* | Yes | Pointer to the compressed KV latent cache tensor
ptr_k_rope | Element* | Yes | Pointer to the key rope tensor for rotary position encoding
softmax_scale | ElementAcc | Yes | Softmax scaling factor (typically 1/sqrt(head_dim))
ptr_seq | int* | No | Pointer to the per-batch sequence length array; nullptr for fixed-length sequences
ptr_page_table | int* | No | Pointer to the page table for paged attention; nullptr for contiguous KV
problem_shape | ProblemShape | Yes | Tuple of (num_heads=128, seqlen, (d_latent, d_rope), batch_count)
split_kv | int | No | Number of KV splits for long-sequence handling
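When ptr_page_table is non-null, a sequence's KV tokens are stored in fixed-size pages scattered through the cache, and each logical token index must be translated to a physical location. The sketch below shows that translation under assumed conventions: a row-major [batch][max_pages_per_seq] table of page ids, page_size tokens per page, and physical row = page_id * page_size + offset. The helper name physical_kv_row and these layout details are hypothetical and are not taken from the kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: logical token index of one batch entry -> physical KV row.
std::size_t physical_kv_row(const std::vector<int>& page_table,
                            std::size_t max_pages_per_seq, std::size_t batch_idx,
                            std::size_t token_idx, std::size_t page_size) {
  const std::size_t logical_page = token_idx / page_size;  // which page of the sequence
  const std::size_t offset = token_idx % page_size;        // position inside that page
  assert(logical_page < max_pages_per_seq);
  const std::size_t entry = batch_idx * max_pages_per_seq + logical_page;
  assert(entry < page_table.size());
  const int page_id = page_table[entry];                   // physical page in the cache
  assert(page_id >= 0);
  return static_cast<std::size_t>(page_id) * page_size + offset;
}
```

For example, with page_size = 64 and a table row {7, 2} for batch entry 0, token 65 falls on logical page 1 at offset 1 and maps to physical row 2 * 64 + 1 = 129.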

Outputs

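Name | Type | Description
ptr_o | ElementOut* | Output attention tensor; each head row has d_latent entries, laid out like the query latent tensor
ptr_lse | ElementLSE* | Log-sum-exp of the scaled attention scores per query head (the softmax normalizer in log space), usable for merging partial attention results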
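With split_kv > 1, each split attends over one slice of the sequence and yields a partial output normalized by its own softmax denominator, together with the log-sum-exp of that denominator. The standard way to combine the splits exactly is to weight each partial output by exp(lse_i) relative to the total. The sketch below shows an even ceil-division partition and that LSE-weighted merge. The split policy, the natural-log LSE convention, and the helper names split_range and merge_splits are assumptions, not the kernel's actual reduction code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// [begin, end) token range of split `split_idx` when seqlen is divided into
// split_kv chunks of ceil(seqlen / split_kv) tokens (assumed policy).
std::pair<std::size_t, std::size_t> split_range(std::size_t seqlen, std::size_t split_kv,
                                                std::size_t split_idx) {
  assert(split_kv > 0);
  const std::size_t chunk = (seqlen + split_kv - 1) / split_kv;
  const std::size_t begin = std::min(seqlen, split_idx * chunk);
  const std::size_t end = std::min(seqlen, begin + chunk);
  return {begin, end};
}

// Merges partial outputs o_parts[i] (each normalized by its own softmax sum)
// using their log-sum-exp values lse_parts[i]. Writes the combined LSE to
// *lse_out and returns the combined output row.
std::vector<float> merge_splits(const std::vector<std::vector<float>>& o_parts,
                                const std::vector<float>& lse_parts, float* lse_out) {
  assert(!o_parts.empty() && o_parts.size() == lse_parts.size());
  const float m = *std::max_element(lse_parts.begin(), lse_parts.end());
  float denom = 0.f;
  for (float l : lse_parts) denom += std::exp(l - m);  // sum of exp(lse_i), rescaled by m
  std::vector<float> o(o_parts[0].size(), 0.f);
  for (std::size_t i = 0; i < o_parts.size(); ++i) {
    const float w = std::exp(lse_parts[i] - m) / denom;  // share of split i in the total
    for (std::size_t d = 0; d < o.size(); ++d) o[d] += w * o_parts[i][d];
  }
  *lse_out = m + std::log(denom);
  return o;
}
```

Subtracting the largest LSE before exponentiating keeps the weights in range, and a split that covers no tokens (LSE of negative infinity) receives zero weight as long as at least one split is non-empty.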

Usage Examples

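// Assumes the kernel header above and the header declaring
// cutlass::fmha::device::MLA are included.
using namespace cute;

// Define the kernel type: (H, S, (D_latent, D_rope)) tile shape and element types
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
using MLAKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
    TileShape,
    cutlass::half_t,   // Element
    float,             // ElementAcc
    cutlass::half_t,   // ElementOut
    float,             // ElementLSE
    MyTileScheduler    // placeholder: substitute a concrete tile scheduler type
>;

// Set up arguments (stride fields and the optional ptr_seq, ptr_page_table,
// and split_kv inputs are omitted here)
typename MLAKernel::Arguments args;
args.problem_shape = {_128{}, seq_len, make_shape(_512{}, _64{}), batch};
args.mainloop.softmax_scale = scale;
args.mainloop.ptr_q_latent = q_latent;
args.mainloop.ptr_c_latent = c_latent;
args.mainloop.ptr_q_rope = q_rope;
args.mainloop.ptr_k_rope = k_rope;
args.epilogue.ptr_o = output;
args.epilogue.ptr_lse = lse;

// Check support and size the workspace before launching
if (MLAKernel::can_implement(args) != cutlass::Status::kSuccess) {
  // handle unsupported configuration
}
size_t workspace_size = MLAKernel::get_workspace_size(args);
// ... allocate `workspace` with at least workspace_size bytes ...

// Launch via the device-level MLA wrapper
cutlass::fmha::device::MLA<MLAKernel> mla;
mla.run(args, workspace, stream);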
