Implementation:Deepspeedai DeepSpeed Evoformer MMA Multistage

From Leeroopedia


Knowledge Sources
Domains Attention, CUTLASS_Kernels, DeepSpeed4Science
Last Updated 2026-02-09 00:00 GMT

Overview

A multi-stage threadblock-scoped GEMM implementation using software pipelining with async copy operations for overlapping data movement and computation on Ampere GPUs and later.

Description

CustomMmaMultistage implements a software-pipelined matrix multiplication kernel that uses multiple stages (typically 3-5) of shared memory buffers to hide global memory latency. It leverages CUDA's cp.async instructions (SM80+) to overlap asynchronous data transfers from global to shared memory with warp-level tensor core computations. The template is parameterized by tile shapes, iterator types for A/B operands, cache operation hints (CacheOpA/CacheOpB), and the number of pipeline stages. The implementation carefully orchestrates the prologue (initial loads), main loop (pipelined execution), and epilogue phases to maximize throughput while respecting shared memory constraints.
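The prologue and main-loop ordering described above can be illustrated with a small host-side model. This is only a sketch under stated assumptions, not the kernel's code: the function name pipelined_dot is hypothetical, "loads" are ordinary synchronous copies into a ring of buffers rather than cp.async transfers into shared memory, and the "GEMM" is reduced to a dot product so that only the K-dimension schedule is visible.

```cpp
#include <cassert>
#include <vector>

// Host-side model of a multistage mainloop (hypothetical helper, not DeepSpeed API).
// Each "stage" is one slot in a ring of buffers; on a GPU the copies below would be
// asynchronous cp.async transfers into shared memory.
template <int kStages>
float pipelined_dot(const std::vector<float>& a,
                    const std::vector<float>& b,
                    int tile_k) {
    static_assert(kStages >= 2, "need at least two stages to overlap load and compute");
    assert(a.size() == b.size() && tile_k > 0);
    const int k_total = static_cast<int>(a.size());
    const int k_iterations = (k_total + tile_k - 1) / tile_k;

    std::vector<std::vector<float>> smem_a(kStages, std::vector<float>(tile_k, 0.0f));
    std::vector<std::vector<float>> smem_b(kStages, std::vector<float>(tile_k, 0.0f));

    // Copy tile `tile` into ring slot `tile % kStages`; out-of-range elements become zero.
    auto load_tile = [&](int tile) {
        const int slot = tile % kStages;
        for (int i = 0; i < tile_k; ++i) {
            const int k = tile * tile_k + i;
            smem_a[slot][i] = (k < k_total) ? a[k] : 0.0f;
            smem_b[slot][i] = (k < k_total) ? b[k] : 0.0f;
        }
    };

    // Prologue: bring in the first kStages - 1 tiles before any compute starts.
    for (int tile = 0; tile < kStages - 1 && tile < k_iterations; ++tile) {
        load_tile(tile);
    }

    float accum = 0.0f;
    for (int tile = 0; tile < k_iterations; ++tile) {
        // Main loop: request the tile that is kStages - 1 steps ahead. It lands in
        // the slot that the previous iteration has just finished reading.
        const int prefetch = tile + kStages - 1;
        if (prefetch < k_iterations) load_tile(prefetch);

        // Compute on the tile that is already resident in its slot.
        const int slot = tile % kStages;
        for (int i = 0; i < tile_k; ++i) accum += smem_a[slot][i] * smem_b[slot][i];
    }
    return accum;
}
```

Calling pipelined_dot<3>(a, b, 2) on two length-7 vectors walks four K tiles through a three-slot ring. The point of the model is the index arithmetic: the prefetch for tile t + Stages - 1 never targets the slot that tile t is being read from.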

Usage

This MMA operator is instantiated in the Evoformer attention kernels when targeting Ampere (SM80) or newer GPUs with sufficient shared memory. Its deeper software pipeline can hide more global memory latency than the double-buffered pipelined variant, at the cost of a larger shared memory footprint.
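To make "sufficient shared memory" concrete, the operand footprint can be estimated from the tile shape and stage count. The helper names below are hypothetical, and the estimate assumes each stage holds one M x K tile of A and one K x N tile of B with no padding or swizzle overhead, so treat it as a rough lower bound rather than the value the kernel actually reserves.

```cpp
#include <cassert>
#include <cstddef>

// Rough operand footprint of a multistage mainloop (hypothetical helper).
// Assumption: per stage, one tile_m x tile_k tile of A plus one tile_k x tile_n
// tile of B, with no padding.
constexpr std::size_t multistage_smem_bytes(int tile_m, int tile_n, int tile_k,
                                            int stages, std::size_t element_bytes) {
    return (static_cast<std::size_t>(tile_m) * tile_k +
            static_cast<std::size_t>(tile_k) * tile_n) *
           element_bytes * static_cast<std::size_t>(stages);
}

// True when the estimate fits in the shared memory budget available to one threadblock.
constexpr bool fits_in_smem(std::size_t needed_bytes, std::size_t available_bytes) {
    return needed_bytes <= available_bytes;
}
```

For a 128 x 128 x 32 tile with 2-byte elements, four stages come to 64 KiB, which exceeds the 48 KiB a CUDA kernel can use without opting in to a larger dynamic shared memory size, while two stages come to 32 KiB. This is one reason deeper pipelines are tied to architectures with larger shared memory.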

Code Reference

Source Location

Signature

template <
    typename Shape_,                                    // Threadblock tile shape (cutlass::gemm::GemmShape)
    typename IteratorA_,                                // Global memory iterator for A
    typename SmemIteratorA_,                            // Shared memory iterator for A
    cutlass::arch::CacheOperation::Kind CacheOpA,       // Cache hint for A
    typename IteratorB_,                                // Global memory iterator for B
    typename SmemIteratorB_,                            // Shared memory iterator for B
    cutlass::arch::CacheOperation::Kind CacheOpB,       // Cache hint for B
    typename ElementC_,                                 // Accumulator data type
    typename LayoutC_,                                  // Accumulator layout
    typename Policy_,                                   // MMA policy
    int Stages,                                         // Number of pipeline stages
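    // How out-of-bounds shared memory is cleared (none by default)
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    // Upper bound on the K dimension, if known at compile time
    int kMaxK = cutlass::platform::numeric_limits<int>::max(),
    // Used for partial specialization
    typename Enable = bool>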
class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages>;

Import

#include "csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h"

I/O Contract

Parameter              Type        Description

Inputs
  iterator_A           IteratorA   Iterator over global memory tiles of operand A
  iterator_B           IteratorB   Iterator over global memory tiles of operand B
  gemm_k_iterations    int         Number of iterations along the K dimension

Outputs
  accum                FragmentC   Accumulator fragment containing C = A × B results

Configuration
  Stages               int         Number of shared memory stages (3-5 typical)
  kWarpGemmIterations  int         Warp-level GEMM iterations per threadblock tile
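The two iteration counts in the contract are simple quotients. The sketch below uses hypothetical helper names and assumes the usual CUTLASS conventions: the threadblock trip count is the K extent divided by the threadblock tile's K, rounded up, and the warp-level count is the warp tile's K divided by the K of a single tensor core instruction.

```cpp
#include <cassert>

// Integer division rounded up.
constexpr int ceil_div(int n, int d) { return (n + d - 1) / d; }

// Number of threadblock-level steps along K (the gemm_k_iterations input).
// A partial last tile still costs one full iteration.
constexpr int gemm_k_iterations(int problem_k, int threadblock_k) {
    return ceil_div(problem_k, threadblock_k);
}

// Number of warp-level MMA steps needed to consume one threadblock tile along K
// (assumed to correspond to kWarpGemmIterations).
constexpr int warp_gemm_iterations(int warp_k, int instruction_k) {
    return warp_k / instruction_k;
}
```

With a threadblock tile K of 32, a problem with K = 256 takes 8 iterations and one with K = 100 takes 4, the last of which is only partially filled. A warp tile K of 32 with an instruction K of 16 gives 2 warp-level iterations per tile.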

Usage Examples

// Configure multistage MMA for Ampere with 4 stages
using MmaMultistage = cutlass::gemm::threadblock::CustomMmaMultistage<
    cutlass::gemm::GemmShape<128, 128, 32>,  // Threadblock shape
    IteratorA,                                // A operand iterator
    SmemIteratorA,                            // A shared memory writer
    cutlass::arch::CacheOperation::Global,    // Cache globally
    IteratorB,                                // B operand iterator
    SmemIteratorB,                            // B shared memory writer
    cutlass::arch::CacheOperation::Global,    // Cache globally
    float,                                    // Accumulator type
    cutlass::layout::RowMajor,                // Accumulator layout
    MmaPolicy,                                // Warp MMA policy
    4                                         // 4-stage pipeline
>;

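// Execute GEMM with software pipelining
MmaMultistage mma(shared_storage_A, shared_storage_B, thread_idx, warp_idx, lane_idx);
accum.clear();  // Start from a zeroed accumulator
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);  // Last argument: initial accumulator value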
