Implementation:Deepspeedai DeepSpeed Evoformer MMA From Smem


Knowledge Sources
Domains: Attention, CUTLASS_Kernels, DeepSpeed4Science
Last Updated: 2026-02-09 00:00 GMT

Overview

A specialized GEMM operator that performs matrix multiplication with one operand residing in shared memory and the other loaded on-the-fly, optimized for memory-efficient attention patterns.

Description

MmaBaseFromSharedMemory implements a GEMM variant in which operand B is pre-loaded into shared memory and reused across multiple GEMM operations with different A operands. This pattern matters for memory-efficient attention, where the value matrix (V) is loaded into shared memory once and reused for multiple attention-weighted sums. When the K dimension exceeds the available buffer (kMaxK > Shape::kK * Stages), the shared-memory stages are used as a circular buffer; otherwise all of B stays resident, which the compile-time flag kSmemContainsEntireB records. The class also coordinates warp-level tensor core operations with the shared-memory access pattern. The same header provides AccumulatorSharedStorage, which holds accumulator results in shared memory between successive GEMMs and so avoids round trips through global memory.
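The circular-buffering behaviour described above can be modelled on the host in plain C++. This is a minimal sketch, not DeepSpeed or CUTLASS code: `simulate_stage_ring` is a hypothetical helper, and the modulo slot assignment is an assumption about how a ring of stages is typically indexed.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side model of a ring of `stages` shared-memory slots.
// Returns, for each slot, the index of the last K tile written into it after
// streaming `num_k_tiles` tiles through the ring (-1 if the slot was never used).
inline std::vector<int> simulate_stage_ring(int num_k_tiles, int stages) {
    assert(num_k_tiles >= 0 && stages > 0);
    std::vector<int> slot_contents(stages, -1);
    for (int tile = 0; tile < num_k_tiles; ++tile) {
        int slot = tile % stages;    // wrap around once every slot has been used
        slot_contents[slot] = tile;  // a later tile overwrites an earlier one
    }
    return slot_contents;
}
```

With four K tiles and two stages, slot 0 ends up holding tile 2 and slot 1 holds tile 3, so tiles 0 and 1 are no longer resident. With at most two tiles nothing is overwritten, which corresponds to the case where the whole operand fits.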

Usage

This operator is used in the Evoformer attention backward pass and in the forward pass P×V computation, where V (values) is loaded into shared memory once and reused for computing outputs across multiple query blocks.
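To make the reuse pattern concrete, here is a host-side reference of the P×V step in plain C++. It is a sketch under stated assumptions: `Matrix` and `attend_block` are hypothetical names, the loop is a naive triple-nested GEMM, and nothing here reflects how the CUDA kernel tiles or schedules the work.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major matrix in a flat vector; a stand-in for a tile held in memory.
struct Matrix {
    std::size_t rows, cols;
    std::vector<float> data;
    Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0f) {}
    float& at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
    float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Computes out = P_block * V for one query block. V is taken by const
// reference, so every call reads the same copy: the "load once, reuse" idea.
inline Matrix attend_block(const Matrix& p_block, const Matrix& v) {
    assert(p_block.cols == v.rows);
    Matrix out(p_block.rows, v.cols);
    for (std::size_t i = 0; i < p_block.rows; ++i)
        for (std::size_t k = 0; k < v.rows; ++k)
            for (std::size_t j = 0; j < v.cols; ++j)
                out.at(i, j) += p_block.at(i, k) * v.at(k, j);
    return out;
}
```

A caller would build V once and then invoke `attend_block` for each query block's probability tile, which is the role the shared-memory copy of V plays in the kernel.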

Code Reference

Source Location

Signature

// Shared storage for accumulators between GEMMs
template <typename Shape_, typename Element_, typename Layout_, typename Padding_>
class AccumulatorSharedStorage {
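public:
    // Aliases for the template parameters used by the members below
    using Element = Element_;
    using Layout = Layout_;
    using TensorRefAccum = cutlass::TensorRef<Element, Layout>;
    // ShapeAccum (the accumulator tile shape) is defined in the full header; elided here
    cutlass::AlignedBuffer<Element, ShapeAccum::kCount> accum;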

    CUTLASS_HOST_DEVICE
    TensorRefAccum accum_ref();
};

// GEMM with B operand in shared memory
template <
    typename Shape_,        // GEMM shape
    int kMaxK,              // Maximum K dimension
    typename Policy_,       // MMA policy
    int Stages,             // Pipeline stages
    typename Enable = bool>
class MmaBaseFromSharedMemory {
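public:
    // Aliases for the template parameters used by the members below
    using Shape = Shape_;
    using Policy = Policy_;
    using Operator = typename Policy::Operator;
    using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
    using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

    // True when kMaxK elements along K fit in the Stages shared-memory tiles
    static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * Stages;

    class SharedStorage {
    public:
        // ShapeB (the staged B tile shape) is defined in the full header; elided here
        AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
    };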
};
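The compile-time flag in the signature can be checked in isolation. The sketch below reproduces only the expression `kMaxK <= Shape::kK * Stages`; `TileShape` and `SmemResidency` are hypothetical stand-ins introduced for illustration and are not CUTLASS or DeepSpeed types.

```cpp
// Hypothetical stand-in for cutlass::gemm::GemmShape, reduced to the fields used here.
template <int M, int N, int K>
struct TileShape {
    static constexpr int kM = M;
    static constexpr int kN = N;
    static constexpr int kK = K;
};

// Mirrors only the residency flag from the signature; everything else is omitted.
template <typename Shape, int kMaxK, int Stages>
struct SmemResidency {
    static constexpr bool kSmemContainsEntireB = kMaxK <= Shape::kK * Stages;
};

using Tile = TileShape<64, 64, 32>;

// 64 <= 32 * 2: the whole operand fits in two stages.
static_assert(SmemResidency<Tile, 64, 2>::kSmemContainsEntireB, "B should fit");
// 128 > 32 * 2: two stages are not enough, so the stages must be recycled.
static_assert(!SmemResidency<Tile, 128, 2>::kSmemContainsEntireB, "B should not fit");
// 128 <= 32 * 4: doubling the stage count makes the operand fit again.
static_assert(SmemResidency<Tile, 128, 4>::kSmemContainsEntireB, "B should fit");
```

Because the flag is a constant expression, a kernel can branch on it at compile time and pay no runtime cost for the check.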

Import

#include "csrc/deepspeed4science/evoformer_attn/gemm/mma_from_smem.h"

I/O Contract

| Parameter | Type | Description |
| --- | --- | --- |
| Inputs | | |
| operand_A | TensorRefA | Operand A in global or shared memory |
| operand_B (smem) | SharedStorage | Operand B pre-loaded in shared memory |
| gemm_k_iterations | int | Number of K-dimension iterations |
| Outputs | | |
| accum | FragmentC | Accumulator fragment or shared memory buffer |
| Configuration | | |
| kMaxK | int | Upper bound on K dimension for buffer sizing |
| kSmemContainsEntireB | bool | True if entire B fits in shared memory stages |
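The table does not say how gemm_k_iterations is obtained. A common convention for tiled GEMMs is one iteration per K tile, rounded up; the sketch below assumes that convention, and `gemm_k_iterations_for` is a hypothetical helper rather than part of the header.

```cpp
#include <cassert>

// Assumed convention (not stated in the source): one mainloop iteration per
// K tile of width `tile_k`, rounding up so a partial final tile is still covered.
inline int gemm_k_iterations_for(int problem_k, int tile_k) {
    assert(problem_k >= 0 && tile_k > 0);
    return (problem_k + tile_k - 1) / tile_k;
}
```

Under this assumption, a K extent of 128 with Shape::kK = 32 gives four iterations, and an extent of 100 also gives four because the last tile is partial.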

Usage Examples

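// Illustrative sketch only: MmaPolicy, load_value_matrix, V_ptr, iterator_P and
// accum are placeholders that the surrounding kernel must supply, and the
// constructor and call arguments are schematic rather than exact signatures.

// Configure MMA with V matrix in shared memory
using MmaFromSmem = cutlass::gemm::threadblock::MmaBaseFromSharedMemory<
    cutlass::gemm::GemmShape<64, 64, 32>,  // Threadblock shape (Shape::kK = 32)
    128,                                    // kMaxK = max head dimension
    MmaPolicy,                              // Warp MMA policy
    2                                       // 2 stages: 32 * 2 = 64 < 128, so kSmemContainsEntireB is false
>;

// Load V into shared memory once
typename MmaFromSmem::SharedStorage shared_storage_V;
load_value_matrix(shared_storage_V, V_ptr);  // placeholder loader
__syncthreads();                             // make V visible to every warp before use

// Reuse V for multiple query blocks
for (int query_block = 0; query_block < num_query_blocks; ++query_block) {
    MmaFromSmem mma(shared_storage_V, thread_idx, warp_idx, lane_idx);
    mma(gemm_k_iterations, accum, iterator_P, shared_storage_V.operand_B);
}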
