Implementation:Deepspeedai DeepSpeed Evoformer MMA From Smem
| Knowledge Sources | |
|---|---|
| Domains | Attention, CUTLASS_Kernels, DeepSpeed4Science |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
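A specialized threadblock-level GEMM in which operand A already resides in shared memory (typically the accumulator of a preceding GEMM) and operand B is streamed from global memory through a staged shared-memory buffer. This lets memory-efficient attention chain GEMMs back to back without writing intermediate results to global memory.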
Description
MmaBaseFromSharedMemory is the base class for a GEMM whose operand A is read from shared memory, typically the accumulator tile of a preceding GEMM held in an AccumulatorSharedStorage buffer. Operand B is copied from global memory into a staged shared-memory buffer (SharedStorage::operand_B). Chaining GEMMs this way keeps the intermediate result on chip. In memory-efficient attention, the probabilities P produced from the Q×Kᵀ GEMM become operand A of the P×V GEMM without a round-trip through global memory.

The operand_B buffer holds Stages tiles of Shape::kK rows each. When kMaxK <= Shape::kK × Stages (kSmemContainsEntireB), all of B fits, and the buffer is filled once at the start. Otherwise, the stages are reused as a circular buffer as the mainloop advances along K. Fragments of A and B are loaded from shared memory into registers and fed to the warp-level MMA operator (Policy::Operator), typically tensor core instructions.
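The staged-buffer logic can be modelled on the host. The sketch below is a plain C++ emulation under stated assumptions, not the CUTLASS mainloop. A is a fully resident M×K array standing in for the shared-memory accumulator. B is copied one K-tile at a time into a ring of `stages` slots, the role SharedStorage::operand_B plays, and tile t is always read from slot `t % stages`. The names `gemm_a_resident_b_staged`, `tile_k` and `stages` are hypothetical. The real kernel also overlaps the copy of later tiles with warp-level MMAs on earlier ones instead of running them one after another.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side emulation (hypothetical names, not CUTLASS): C = A * B, row-major.
// A is M x K and fully resident; B is K x N and streamed in K-tiles of depth
// tile_k through a ring buffer of `stages` slots.
std::vector<float> gemm_a_resident_b_staged(const std::vector<float>& A,
                                            const std::vector<float>& B,
                                            int M, int N, int K,
                                            int tile_k, int stages) {
  assert(tile_k > 0 && stages > 0);
  assert(A.size() == static_cast<std::size_t>(M) * K);
  assert(B.size() == static_cast<std::size_t>(K) * N);

  // Same test as kSmemContainsEntireB, with the actual K standing in for kMaxK.
  const bool entire_b_fits = K <= tile_k * stages;
  const int k_tiles = (K + tile_k - 1) / tile_k;

  // Stand-in for SharedStorage::operand_B: `stages` slots of tile_k x N.
  std::vector<float> ring(static_cast<std::size_t>(stages) * tile_k * N, 0.0f);
  auto at = [&](int slot, int kk, int n) -> float& {
    return ring[(static_cast<std::size_t>(slot) * tile_k + kk) * N + n];
  };
  // Copy K-tile t into slot t % stages; rows past K are zero-filled,
  // like a predicated (masked) global-memory load.
  auto load_tile = [&](int t) {
    for (int kk = 0; kk < tile_k; ++kk) {
      const int k = t * tile_k + kk;
      for (int n = 0; n < N; ++n)
        at(t % stages, kk, n) = k < K ? B[static_cast<std::size_t>(k) * N + n] : 0.0f;
    }
  };

  // Fast path: every tile gets its own slot, so fill the buffer once up front.
  if (entire_b_fits)
    for (int t = 0; t < k_tiles; ++t) load_tile(t);

  std::vector<float> C(static_cast<std::size_t>(M) * N, 0.0f);
  for (int t = 0; t < k_tiles; ++t) {
    if (!entire_b_fits) load_tile(t);  // circular reuse: overwrite slot t % stages
    for (int m = 0; m < M; ++m)
      for (int kk = 0; kk < tile_k && t * tile_k + kk < K; ++kk) {
        const float a = A[static_cast<std::size_t>(m) * K + t * tile_k + kk];
        for (int n = 0; n < N; ++n)
          C[static_cast<std::size_t>(m) * N + n] += a * at(t % stages, kk, n);
      }
  }
  return C;
}
```

Calling it once with `tile_k * stages >= K` (fill once) and once with a smaller ring (circular refills) gives the same product. Only the load schedule differs, which is what makes the kSmemContainsEntireB shortcut safe.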
Usage
This operator is used in the Evoformer attention forward pass for the P×V product. There, the attention probabilities P are kept in shared memory as operand A and blocks of V are loaded as operand B. It is also used in the backward pass for gradient GEMMs whose A operand is an intermediate already held in shared memory.
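A host-side reference makes the forward-pass data flow concrete. In this sketch (plain C++, hypothetical function name), GEMM 0 computes the scaled scores for one block of queries, and a row-wise softmax turns them into P. P is kept in a local buffer that stands in for AccumulatorSharedStorage and is used directly as operand A of GEMM 1, with V as operand B. The sketch processes all keys at once. The actual kernel iterates over blocks of keys, which this sketch does not model.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Row-major: q is Mq x D, k and v are Nk x D.
// Returns O = softmax(q * k^T / sqrt(D)) * v, an Mq x D matrix.
std::vector<float> attention_block_reference(const std::vector<float>& q,
                                             const std::vector<float>& k,
                                             const std::vector<float>& v,
                                             int Mq, int Nk, int D) {
  assert(Nk > 0 && D > 0);
  assert(q.size() == static_cast<std::size_t>(Mq) * D);
  assert(k.size() == static_cast<std::size_t>(Nk) * D);
  assert(v.size() == static_cast<std::size_t>(Nk) * D);
  const float scale = 1.0f / std::sqrt(static_cast<float>(D));

  // GEMM 0: S = Q * K^T. The result `p` (Mq x Nk) plays the role of the
  // on-chip AccumulatorSharedStorage buffer.
  std::vector<float> p(static_cast<std::size_t>(Mq) * Nk);
  for (int i = 0; i < Mq; ++i) {
    for (int j = 0; j < Nk; ++j) {
      float s = 0.0f;
      for (int d = 0; d < D; ++d) s += q[i * D + d] * k[j * D + d];
      p[i * Nk + j] = s * scale;
    }
    // Row-wise softmax in place, max-subtracted for numerical stability.
    float row_max = p[i * Nk];
    for (int j = 1; j < Nk; ++j) row_max = std::max(row_max, p[i * Nk + j]);
    float sum = 0.0f;
    for (int j = 0; j < Nk; ++j) {
      p[i * Nk + j] = std::exp(p[i * Nk + j] - row_max);
      sum += p[i * Nk + j];
    }
    for (int j = 0; j < Nk; ++j) p[i * Nk + j] /= sum;
  }

  // GEMM 1: O = P * V. P is operand A (already resident), V is operand B,
  // and the reduction (K) dimension is Nk, the number of keys.
  std::vector<float> o(static_cast<std::size_t>(Mq) * D, 0.0f);
  for (int i = 0; i < Mq; ++i)
    for (int j = 0; j < Nk; ++j)
      for (int d = 0; d < D; ++d) o[i * D + d] += p[i * Nk + j] * v[j * D + d];
  return o;
}
```

The reduction dimension of GEMM 1 is the number of keys, Nk. In the forward pass, therefore, a K-extent bound such as kMaxK has to cover the key-block size, not the head dimension.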
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/deepspeed4science/evoformer_attn/gemm/mma_from_smem.h
Signature
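```cpp
// Shared storage for an accumulator tile passed between GEMMs
template <typename Shape_, typename Element_, typename Layout_, typename Padding_>
class AccumulatorSharedStorage {
 public:
  using Shape = Shape_;
  using Element = Element_;
  using Layout = Layout_;
  using Padding = Padding_;
  using TensorRefAccum = cutlass::TensorRef<Element, Layout>;
  // Accumulator tile plus padding, as laid out in shared memory
  using ShapeAccum = cutlass::MatrixShape<Shape::kM + Padding::kRow,
                                          Shape::kN + Padding::kColumn>;
  cutlass::AlignedBuffer<Element, ShapeAccum::kCount> accum;
  // TensorRef to the accumulator buffer (used as operand A of the next GEMM)
  CUTLASS_HOST_DEVICE
  TensorRefAccum accum_ref();
};

// Base for a GEMM whose A operand is read from shared memory and whose
// B operand is staged through a shared-memory buffer
template <
    typename Shape_,   // threadblock GEMM shape (GemmShape<M, N, K>)
    int kMaxK,         // upper bound on the K extent
    typename Policy_,  // MMA policy (warp-level operator, smem padding)
    int Stages,        // K-tiles held by the operand_B buffer
    typename Enable = bool>
class MmaBaseFromSharedMemory {
 public:
  using Shape = Shape_;
  using Policy = Policy_;
  using Operator = typename Policy::Operator;
  static int const kStages = Stages;
  using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
  using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
  // If true, B fits entirely in the buffer and is loaded once at the start
  // instead of cycling through the stages circularly
  static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * Stages;

  class SharedStorage {
   public:
    using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
                               Shape::kN + Policy::SmemPaddingB::kColumn>;
    AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
  };
};
```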
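The declarations above leave the buffer sizes implicit. The constexpr sketch below works them out for one assumed configuration: a 64×64×32 tile, 2 stages, 16-bit elements and 8 columns of padding, none of which come from the DeepSpeed source. It takes the accumulator buffer to be the M×N tile plus padding, and operand_B to be `Shape::kK * Stages` rows by N columns plus padding. The helper names are hypothetical.

```cpp
#include <cstddef>

// Element count of a padded row-major tile (hypothetical helper).
constexpr std::size_t padded_count(int rows, int cols, int pad_rows, int pad_cols) {
  return static_cast<std::size_t>(rows + pad_rows) * static_cast<std::size_t>(cols + pad_cols);
}

// Assumed configuration, for illustration only.
constexpr int kTileM = 64, kTileN = 64, kTileK = 32, kStages = 2;
constexpr int kPadRow = 0, kPadCol = 8;      // assumed padding
constexpr std::size_t kBytesPerElement = 2;  // 16-bit elements (fp16/bf16)

// AccumulatorSharedStorage::accum: (kM + pad) x (kN + pad) -> 64 x 72 x 2 B = 9216 B.
constexpr std::size_t kAccumBytes =
    padded_count(kTileM, kTileN, kPadRow, kPadCol) * kBytesPerElement;

// SharedStorage::operand_B: (kK * Stages + pad) x (kN + pad) -> 64 x 72 x 2 B = 9216 B.
constexpr std::size_t kOperandBBytes =
    padded_count(kTileK * kStages, kTileN, kPadRow, kPadCol) * kBytesPerElement;

// kSmemContainsEntireB as a function of kMaxK for this tile depth and stage count.
constexpr bool smem_contains_entire_b(int max_k) { return max_k <= kTileK * kStages; }
```

With kMaxK = 128, the same 9216-byte operand_B buffer would have to be cycled circularly, because it holds only 64 of the 128 rows of K.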
Import
```cpp
#include "csrc/deepspeed4science/evoformer_attn/gemm/mma_from_smem.h"
```
I/O Contract
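| Parameter | Type | Description |
|---|---|---|
| Inputs | | |
| operand_A | TensorRefA | Operand A in shared memory, typically the previous GEMM's accumulator exposed by AccumulatorSharedStorage::accum_ref() |
| operand_B (smem) | SharedStorage | Staged shared-memory buffer (operand_B) into which B is copied from global memory |
| gemm_k_iterations | int | Number of threadblock-tile steps along K (each covers Shape::kK) |
| Outputs | | |
| accum | FragmentC | Per-thread accumulator fragment; can be stored to an AccumulatorSharedStorage buffer to feed a following GEMM |
| Configuration | | |
| kMaxK | int | Upper bound on the K extent, used for buffer sizing |
| kSmemContainsEntireB | bool | True when kMaxK <= Shape::kK * Stages: all of B fits in the staged buffer, so it is loaded once instead of cycled circularly |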
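The relationship between gemm_k_iterations, kMaxK and kSmemContainsEntireB can be captured in a small helper. This is a sketch under two assumptions: the iteration count is K divided by Shape::kK, rounded up, and K > kMaxK is a precondition violation. `plan_k_loop` and `KLoopPlan` are hypothetical names, not part of DeepSpeed.

```cpp
#include <cassert>

// Hypothetical planning result; names are illustrative, not DeepSpeed APIs.
struct KLoopPlan {
  int gemm_k_iterations;  // threadblock-tile steps along K
  bool entire_b_in_smem;  // same condition as kSmemContainsEntireB
};

// k: K extent of this problem; tile_k: Shape::kK; stages: Stages; max_k: kMaxK.
KLoopPlan plan_k_loop(int k, int tile_k, int stages, int max_k) {
  assert(tile_k > 0 && stages > 0);
  assert(k >= 0 && k <= max_k);  // kMaxK is the bound the buffers are sized for
  return KLoopPlan{(k + tile_k - 1) / tile_k,  // ceil(k / tile_k)
                   max_k <= tile_k * stages};
}
```

kSmemContainsEntireB depends only on compile-time values (kMaxK, Shape::kK, Stages). A kernel instantiated with a large kMaxK therefore takes the circular path even when a particular call has a small K.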
Usage Examples
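```cpp
// Illustrative configuration for the forward-pass P x V GEMM.
// WarpMmaPolicy is a placeholder for the kernel's warp-level MmaPolicy.
using MmaBase = cutlass::gemm::threadblock::MmaBaseFromSharedMemory<
    cutlass::gemm::GemmShape<64, 64, 32>,  // M = queries, N = head dim, K-step = 32 keys
    64,                                    // kMaxK: keys per block (K extent of P x V)
    WarpMmaPolicy,                         // placeholder warp MMA policy
    2>;                                    // stages in SharedStorage::operand_B

// 64 <= 32 * 2: each block of V fits in operand_B and is loaded once.
static_assert(MmaBase::kSmemContainsEntireB, "V block fits in the staged buffer");

// Schematic flow inside one threadblock (one block of queries):
//   for each block of keys:
//     1. S = Q * K^T with a regular threadblock GEMM (accumulators in registers)
//     2. P = softmax(S), stored to an AccumulatorSharedStorage buffer
//     3. O += P * V with a pipelined or multistage mainloop derived from
//        MmaBaseFromSharedMemory: A (P) is read via accum_ref(), B (V) is copied
//        from global memory into SharedStorage::operand_B, and the result is
//        accumulated into a register FragmentC
```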