Implementation:Deepspeedai DeepSpeed Evoformer MMA Multistage
| Knowledge Sources | |
|---|---|
| Domains | Attention, CUTLASS_Kernels, DeepSpeed4Science |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A multi-stage, threadblock-scoped GEMM that uses software pipelining with asynchronous copies (cp.async) to overlap global-to-shared-memory data movement with computation on NVIDIA Ampere (SM80) and later GPUs.
Description
CustomMmaMultistage implements a software-pipelined, threadblock-level matrix multiplication that cycles through multiple shared memory buffers (stages, typically 3-5) to hide global memory latency. It uses CUDA's cp.async instructions (SM80+) to copy tiles of A and B from global to shared memory asynchronously while warps run tensor core MMA instructions on tiles that have already arrived. The template is parameterized by the threadblock tile shape, the global and shared memory iterator types for the A and B operands, per-operand cache operation hints (CacheOpA/CacheOpB), the number of pipeline stages, an optional shared memory clearing policy (SharedMemoryClear), and a compile-time upper bound on the K extent (kMaxK). Execution has three phases: a prologue that issues the first Stages - 1 tile loads, a main loop that computes on the oldest resident stage while prefetching the next tile into the stage just freed, and a final drain that waits for all outstanding cp.async groups before returning.
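The prologue / main loop / drain structure can be modeled on the host to see which K-tile occupies which stage buffer. The sketch below is a simplified, hypothetical model (the names `PipelineEvent` and `simulate_multistage` are illustrative, not DeepSpeed or CUTLASS API): it records only the order of loads and computes, and assumes one tile load per main-loop iteration.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side model of a Stages-deep cp.async pipeline. It only
// records the order in which K-tiles are loaded and consumed; no data moves.
struct PipelineEvent {
  char kind;   // 'L' = issue async load of a K-tile, 'C' = compute on it
  int k_tile;  // K-tile index, 0 .. gemm_k_iterations - 1
  int slot;    // shared memory stage buffer used (k_tile % stages)
};

std::vector<PipelineEvent> simulate_multistage(int gemm_k_iterations, int stages) {
  assert(stages >= 2 && gemm_k_iterations >= 0);
  std::vector<PipelineEvent> events;
  int next_load = 0;
  // Prologue: prefetch up to stages - 1 tiles before any compute starts.
  while (next_load < stages - 1 && next_load < gemm_k_iterations) {
    events.push_back({'L', next_load, next_load % stages});
    ++next_load;
  }
  // Main loop: prefetch one more tile into the slot freed by the previous
  // iteration, then compute on the oldest resident tile. (In the real kernel,
  // roughly, a cp.async wait plus __syncthreads() guard the tile being read.)
  for (int k = 0; k < gemm_k_iterations; ++k) {
    if (next_load < gemm_k_iterations) {
      events.push_back({'L', next_load, next_load % stages});
      ++next_load;
    }
    events.push_back({'C', k, k % stages});
  }
  // Drain: the kernel waits for all outstanding cp.async groups here;
  // this model has no asynchronous work left to wait for.
  return events;
}
```

For `simulate_multistage(5, 3)` the sequence is L0 L1 L2 C0 L3 C1 L4 C2 C3 C4: tile 3 reuses slot 0 only after tile 0 has been consumed, which is the invariant the real pipeline enforces with barriers.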
Usage
This MMA operator is instantiated by the Evoformer attention kernels on Ampere (SM80) and newer GPUs when enough shared memory is available for the extra stages. Its deeper software pipeline generally hides global memory latency better than the two-stage (double-buffered) pipelined variant, which typically translates into higher throughput.
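Whether "enough shared memory" is available can be estimated from the tile shape and stage count. The sketch below assumes each stage holds an unpadded M x K tile of A and a K x N tile of B; real CUTLASS shared storage may add padding, so treat the result as a lower bound. The helper names are hypothetical.

```cpp
#include <cstddef>

// Bytes of operand staging buffers for one threadblock (no padding assumed).
constexpr std::size_t staging_smem_bytes(int m, int n, int k, int elem_bytes, int stages) {
  return static_cast<std::size_t>(stages) *
         (static_cast<std::size_t>(m) * k + static_cast<std::size_t>(k) * n) *
         static_cast<std::size_t>(elem_bytes);
}

// Largest stage count (>= 2) whose staging buffers fit in budget_bytes,
// or 0 if even double buffering does not fit.
constexpr int max_stages_for_budget(int m, int n, int k, int elem_bytes,
                                    std::size_t budget_bytes, int max_stages = 8) {
  int best = 0;
  for (int s = 2; s <= max_stages; ++s) {
    if (staging_smem_bytes(m, n, k, elem_bytes, s) <= budget_bytes) best = s;
  }
  return best;
}
```

For a 128 x 128 x 32 half-precision tile, each stage needs 16 KB, so 4 stages need 64 KB while the default 48 KB per-block limit fits only 3. Larger configurations must be opted into on the host, for example with `cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, bytes)`, up to the device's per-block opt-in limit.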
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h
Signature
```cpp
template <
    typename Shape_,                               // Threadblock tile shape (GemmShape<M, N, K>)
    typename IteratorA_,                           // Global memory iterator for A
    typename SmemIteratorA_,                       // Shared memory iterator for A
    cutlass::arch::CacheOperation::Kind CacheOpA,  // Cache hint for A
    typename IteratorB_,                           // Global memory iterator for B
    typename SmemIteratorB_,                       // Shared memory iterator for B
    cutlass::arch::CacheOperation::Kind CacheOpB,  // Cache hint for B
    typename ElementC_,                            // Accumulator data type
    typename LayoutC_,                             // Accumulator layout
    typename Policy_,                              // MMA policy
    int Stages,                                    // Number of pipeline stages
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    int kMaxK = cutlass::platform::numeric_limits<int>::max(),
    typename Enable = bool>
class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
  // ...
};
```
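The CacheOpA/CacheOpB hints select the cp.async variant. In CUTLASS, `CacheOperation::Always` issues `cp.async.ca` (cache at all levels) and `CacheOperation::Global` issues `cp.async.cg` (cache in L2, bypass L1); per the PTX ISA, `.ca` accepts 4-, 8- or 16-byte copies while `.cg` accepts only 16-byte copies. CUTLASS's default SM80 MMA configurations choose `Global` when each access is 128 bits wide and `Always` otherwise; we assume DeepSpeed's instantiations follow the same rule. The enum and helpers below are a hypothetical mirror of that logic, not the CUTLASS types themselves.

```cpp
// Hypothetical stand-in for cutlass::arch::CacheOperation::Kind (two values only).
enum class CacheHint { Always, Global };

// Which copy sizes each cp.async variant accepts (.ca: 4/8/16 B, .cg: 16 B only).
constexpr bool cp_async_size_ok(CacheHint hint, int bytes) {
  return hint == CacheHint::Global ? bytes == 16
                                   : (bytes == 4 || bytes == 8 || bytes == 16);
}

// Mirrors the CUTLASS default: use Global only for full 128-bit accesses.
constexpr CacheHint pick_cache_hint(int elem_bits, int alignment_elems) {
  return elem_bits * alignment_elems == 128 ? CacheHint::Global : CacheHint::Always;
}
```

For half-precision operands aligned to 8 elements (8 x 16 bits = 128 bits), `pick_cache_hint(16, 8)` yields `Global`; with only 4-element alignment the copies are 8 bytes and must fall back to `Always`.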
Import
```cpp
#include "csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h"
```
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| **Inputs** | | |
| iterator_A | IteratorA | Iterator over global memory tiles of operand A |
| iterator_B | IteratorB | Iterator over global memory tiles of operand B |
| gemm_k_iterations | int | Number of threadblock-tile iterations along the K dimension |
| src_accum | FragmentC const& | Initial accumulator value that the product is added to |
| **Outputs** | | |
| accum | FragmentC | Accumulator fragment holding C = A × B + src_accum |
| **Configuration** | | |
| Stages | int | Number of shared memory stages (3-5 typical) |
| kWarpGemmIterations | int | Warp-level MMA steps per threadblock K-tile (a static member derived from the policy, not a template argument) |
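Both iteration counts in the table follow from the tile shapes. The sketch below uses hypothetical helper names; in CUTLASS, `gemm_k_iterations` is computed by the caller as a ceiling division of the problem's K by the threadblock tile's K, and `kWarpGemmIterations` is the warp tile's K divided by the MMA instruction's K.

```cpp
// Number of threadblock K-tiles needed to cover problem_k (ceiling division).
// The last tile may be partial; CUTLASS iterators guard it with predicates.
constexpr int gemm_k_iterations(int problem_k, int tile_k) {
  return (problem_k + tile_k - 1) / tile_k;
}

// Warp-level MMA instructions issued along K for one threadblock K-tile.
constexpr int warp_gemm_iterations(int warp_k, int instruction_k) {
  return warp_k / instruction_k;
}
```

With a threadblock K of 32, a problem K of 100 needs 4 iterations (the last covers only 4 columns), and a warp tile K of 32 with the Ampere `m16n8k16` half-precision instruction gives 2 warp-level steps per stage.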
Usage Examples
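```cpp
// Configure a 4-stage multistage MMA for Ampere (SM80+).
// IteratorA/B, SmemIteratorA/B, MmaPolicy, shared_storage_A/B, and the
// thread/warp/lane indices are assumed to be defined by the enclosing kernel.
using MmaMultistage = cutlass::gemm::threadblock::CustomMmaMultistage<
    cutlass::gemm::GemmShape<128, 128, 32>,  // Threadblock tile shape (M, N, K)
    IteratorA,                               // A operand global memory iterator
    SmemIteratorA,                           // A shared memory writer
    cutlass::arch::CacheOperation::Global,   // cp.async.cg: L2 only, bypass L1 (16-byte copies)
    IteratorB,                               // B operand global memory iterator
    SmemIteratorB,                           // B shared memory writer
    cutlass::arch::CacheOperation::Global,   // cp.async.cg for B as well
    float,                                   // Accumulator type
    cutlass::layout::RowMajor,               // Accumulator layout
    MmaPolicy,                               // Warp-level MMA policy
    4                                        // 4-stage pipeline
    >;

// Run the software-pipelined mainloop. The last argument is src_accum,
// so passing accum accumulates in place: accum = A * B + accum.
MmaMultistage mma(shared_storage_A, shared_storage_B, thread_idx, warp_idx, lane_idx);
accum.clear();  // start from zero
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
```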
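When testing a kernel built on this operator, a host-side reference of the same contract (accum = A × B + src_accum, with K walked in threadblock-sized steps) is convenient for comparing results. The function below is a hypothetical sketch that assumes row-major float matrices; it reproduces the arithmetic, not the kernel's tiling across M and N, its data layout, or its rounding behavior for lower-precision inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Reference for accum = A (m x k) * B (k x n) + src (m x n), row-major,
// processing K in tile_k-sized chunks like the threadblock mainloop.
std::vector<float> reference_mma(const std::vector<float>& a,
                                 const std::vector<float>& b,
                                 const std::vector<float>& src,
                                 int m, int n, int k, int tile_k) {
  assert(tile_k > 0);
  std::vector<float> accum(src);  // start from src_accum
  const int iterations = (k + tile_k - 1) / tile_k;  // gemm_k_iterations
  for (int it = 0; it < iterations; ++it) {
    const int k_begin = it * tile_k;
    const int k_end = std::min(k_begin + tile_k, k);  // ragged last tile
    for (int i = 0; i < m; ++i)
      for (int j = 0; j < n; ++j)
        for (int kk = k_begin; kk < k_end; ++kk)
          accum[i * n + j] += a[i * k + kk] * b[kk * n + j];
  }
  return accum;
}
```

Note that splitting K into tiles changes the order of floating-point additions relative to a single pass, so comparisons against the GPU result should use a tolerance rather than exact equality.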