Implementation:Deepspeedai DeepSpeed Evoformer Warp Iterator Smem
| Knowledge Sources | |
|---|---|
| Domains | Attention, CUTLASS_Kernels, DeepSpeed4Science |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A warp-level iterator that loads matrix tiles from shared memory into tensor core registers using ldmatrix instructions with optional transpose support.
Description
WarpIteratorFromSmem loads GEMM operand tiles from a row-major shared-memory layout into registers for tensor core operations on Ampere GPUs. It is modeled on CUTLASS's MmaTensorOpMultiplicandTileAccessIterator, which fills fragments with manual shared-memory reads. This implementation instead uses the ldmatrix.sync.aligned instruction family: a single .x4 instruction loads four 8×8 matrices of 16-bit elements directly into the register layout that the tensor core MMA instructions expect. The kTranspose template parameter selects the transposing variant, so the iterator can read the transpose of the stored tile (for example A^T when shared memory holds A) without an explicit transposition pass. The constructor derives each thread's starting coordinate from its lane ID, so every lane supplies the row address that ldmatrix expects from it and receives its share of the tensor core fragment. The 32×32 tile shape is hardcoded to match common attention block sizes.
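To see which elements each lane ends up holding, the ldmatrix behavior described above can be modeled on the host. The sketch below is a simplified model of ldmatrix.sync.aligned.m8n8.x4.b16 written from the PTX fragment layout, not DeepSpeed code. The function name ldmatrix_x4 is hypothetical, and the usage below assumes that lane l supplies row l of a 32×8 slice.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Host-side model of ldmatrix.sync.aligned.m8n8.x4[.trans].shared.b16.
// Each lane ends up with four 32-bit registers, one per 8x8 matrix, each
// packing two 16-bit elements (lower-indexed element in the low half).
using LaneRegs = std::array<uint32_t, 4>;

// tile:     16-bit elements in shared memory (flattened)
// row_addr: element offset of the 8-element row supplied by each lane;
//           lanes 8*m .. 8*m+7 supply the rows of matrix m
// trans:    models the .trans qualifier
std::array<LaneRegs, 32> ldmatrix_x4(const std::vector<uint16_t>& tile,
                                     const std::array<int, 32>& row_addr,
                                     bool trans) {
  for (int l = 0; l < 32; ++l) {
    assert(row_addr[l] >= 0 && row_addr[l] + 8 <= static_cast<int>(tile.size()));
  }
  std::array<LaneRegs, 32> regs{};
  for (int lane = 0; lane < 32; ++lane) {
    int r = lane / 4;        // fragment row owned by this lane
    int c = 2 * (lane % 4);  // first of the two fragment columns
    for (int m = 0; m < 4; ++m) {
      uint16_t lo, hi;
      if (!trans) {
        // Row r of matrix m, columns c and c+1
        lo = tile[row_addr[8 * m + r] + c];
        hi = tile[row_addr[8 * m + r] + c + 1];
      } else {
        // Transposed: column r of matrix m, rows c and c+1
        lo = tile[row_addr[8 * m + c] + r];
        hi = tile[row_addr[8 * m + c + 1] + r];
      }
      regs[lane][m] = static_cast<uint32_t>(lo) | (static_cast<uint32_t>(hi) << 16);
    }
  }
  return regs;
}
```

Consider a 32×8 tile where lane l points at row l. Lane 0's first register then holds elements (0,0) and (0,1). With trans set, it holds (0,0) and (1,0), which is the transposed fragment that the .trans variant provides.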
Usage
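This iterator is used in the MmaFromSmem operations, which read a GEMM operand that an earlier step of the attention computation left in shared memory, for example the attention probabilities that are then multiplied by the value matrix. Replacing manual per-thread shared-memory loads with ldmatrix makes these reads somewhat faster.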
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/deepspeed4science/evoformer_attn/iterators/warp_iterator_from_smem.h
Signature
```cpp
template <
    Operand Operand_,         // Operand::kA or Operand::kB
    typename Element_,        // Element type (must be 16-bit)
    bool kTranspose = false>  // Load the transpose of the stored tile
class WarpIteratorFromSmem {
 public:
  using Shape = cutlass::MatrixShape<32, 32>;
  using Element = Element_;
  using Layout = cutlass::layout::RowMajor;
  using InstructionShape = cutlass::MatrixShape<16, 8>;
  static_assert(sizeof_bits<Element>::value == 16, "Only 16-bit types supported");

  // Operand is a scoped enum, so the constant keeps the enum type
  static Operand const kOperand = Operand_;

  // Number of instruction-sized blocks in the tile: MatrixShape<2, 4>
  using InstructionCount = MatrixShape<
      Shape::kRow / InstructionShape::kRow,
      Shape::kColumn / InstructionShape::kColumn>;
  static int const kIterations = (kOperand == Operand::kA) ?
      InstructionCount::kColumn : InstructionCount::kRow;

  using Fragment = Array<Element, kOperand == Operand::kA ?
      (Shape::kRow * InstructionShape::kColumn / 32) :
      (Shape::kColumn * InstructionShape::kRow / 32)>;
  using TensorRef = TensorRef<Element, Layout>;
  using AccessType = Array<unsigned, 4>;  // ldmatrix .x4 loads 4 x 32-bit registers

  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem(TensorRef const& ref, int lane_id);
  CUTLASS_DEVICE void load(Fragment& frag) const;          // load the current K-slice
  CUTLASS_DEVICE void advance();                           // move to the next 32x32 tile
  CUTLASS_HOST_DEVICE WarpIteratorFromSmem& operator++();  // move to the next K-slice
};
```
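The derived constants in the signature can be checked by hand. The sketch below recomputes them with plain constexpr integers. Shape2 and the local Operand enum are stand-ins for the CUTLASS types, and fragment_registers and ldmatrix_x4_count are hypothetical helpers. The numbers show that operand A walks the tile in four 32×8 slices of 8 elements per thread, while operand B uses two 16×32 slices of 16 elements per thread.

```cpp
#include <cassert>

// Stand-ins for cutlass::gemm::Operand and cutlass::MatrixShape (assumed here
// so that the arithmetic compiles without CUTLASS).
enum class Operand { kA, kB };
struct Shape2 {
  int row;
  int column;
};

constexpr Shape2 kShape{32, 32};            // Shape
constexpr Shape2 kInstructionShape{16, 8};  // InstructionShape
constexpr Shape2 kInstructionCount{kShape.row / kInstructionShape.row,
                                   kShape.column / kInstructionShape.column};  // {2, 4}
constexpr int kThreads = 32;  // one warp

// kIterations: K-slices per 32x32 tile
constexpr int iterations(Operand op) {
  return op == Operand::kA ? kInstructionCount.column : kInstructionCount.row;
}

// Fragment size: 16-bit elements each thread holds after one load()
constexpr int fragment_elements(Operand op) {
  return op == Operand::kA ? kShape.row * kInstructionShape.column / kThreads
                           : kShape.column * kInstructionShape.row / kThreads;
}

// 32-bit registers per thread, and how many ldmatrix.x4 loads (4 registers
// each) would be needed to fill them
constexpr int fragment_registers(Operand op) { return fragment_elements(op) * 16 / 32; }
constexpr int ldmatrix_x4_count(Operand op) { return fragment_registers(op) / 4; }

static_assert(iterations(Operand::kA) * kInstructionShape.column == kShape.column,
              "A slices cover all 32 columns");
static_assert(iterations(Operand::kB) * kInstructionShape.row == kShape.row,
              "B slices cover all 32 rows");
```

An operand A fragment occupies exactly four 32-bit registers per thread, which a single ldmatrix.x4 fills. An operand B fragment occupies eight registers, so one x4 load would fill only half of it.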
Import
```cpp
#include "csrc/deepspeed4science/evoformer_attn/iterators/warp_iterator_from_smem.h"
```
I/O Contract
| Method | Input | Output | Description |
|---|---|---|---|
| Constructor | TensorRef, lane_id | Iterator | Initialize with shared memory reference and thread lane |
| load | Fragment& (output) | void | Load the current K-slice of the 32×32 tile with ldmatrix (32×8 for operand A, 16×32 for operand B) |
| advance | void | void | Move to the next 32×32 tile (along columns for A, along rows for B) |
| operator++ | void | Iterator& | Move to the next K-slice; after kIterations slices, move to the next tile |
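
Configuration
| Template parameter | Type | Description |
|---|---|---|
| Operand_ | Operand (enum) | Operand::kA or Operand::kB |
| Element_ | typename | Element type; must be 16 bits wide |
| kTranspose | bool | If true, load the transpose of the stored tile (for example A^T when shared memory holds A) |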
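The table describes two levels of movement: operator++ steps through K-slices inside a tile, and advance moves to the next tile. The sketch below is a hypothetical host-side model of that bookkeeping. It assumes that operator++ calls advance after kIterations slices and that advance restarts the slice count. SliceCursor and visited_offsets are illustrative names, not DeepSpeed code.

```cpp
#include <cassert>
#include <vector>

// Hypothetical model of the iterator's position bookkeeping. Assumed behavior:
// operator++ moves to the next instruction-wide K-slice, and after kIterations
// slices it calls advance(), which moves by one whole 32x32 tile and restarts
// the slice count.
template <bool kOperandA>
struct SliceCursor {
  static constexpr int kTile = 32;                        // Shape is 32x32
  static constexpr int kSliceStep = kOperandA ? 8 : 16;   // InstructionShape kColumn / kRow
  static constexpr int kIterations = kTile / kSliceStep;  // 4 for A, 2 for B

  int tile_offset = 0;  // K offset of the current tile
  int iteration = 0;    // slice index inside the tile

  // K coordinate (column for A, row for B) of the slice load() would read
  int k_offset() const { return tile_offset + iteration * kSliceStep; }

  void advance() {
    tile_offset += kTile;
    iteration = 0;
  }

  SliceCursor& operator++() {
    if (++iteration >= kIterations) advance();
    return *this;
  }
};

// Collects the K offsets visited by `steps` consecutive load()/++ pairs
template <bool kOperandA>
std::vector<int> visited_offsets(int steps) {
  SliceCursor<kOperandA> cursor;
  std::vector<int> offsets;
  for (int i = 0; i < steps; ++i) {
    offsets.push_back(cursor.k_offset());
    ++cursor;
  }
  return offsets;
}
```

Under these assumptions, repeated load/++ pairs sweep K continuously. Operand A visits columns 0, 8, 16, 24 and then continues at 32 in the next tile. Operand B visits rows 0, 16, 32, 48.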
Usage Examples
```cpp
#include "csrc/deepspeed4science/evoformer_attn/iterators/warp_iterator_from_smem.h"

// Operand A iterator: each load() fills an 8-element (4 x 32-bit) fragment,
// which is exactly one ldmatrix.x4
using IteratorA = cutlass::gemm::warp::WarpIteratorFromSmem<
    cutlass::gemm::Operand::kA,  // A operand in A×B
    cutlass::half_t,             // FP16
    false                        // No transpose
    >;

// Row stride in elements: +8 padding keeps every ldmatrix row address 16-byte aligned
constexpr int kSmemStride = 32 + 8;

__global__ void example_kernel() {
  __shared__ __align__(16) cutlass::half_t smem_A[32][kSmemStride];
  int lane_id = threadIdx.x % 32;
  // ... fill smem_A, then __syncthreads() ...

  // Shared memory reference: base pointer + row-major leading dimension
  IteratorA::TensorRef ref_A(&smem_A[0][0], cutlass::layout::RowMajor(kSmemStride));
  IteratorA iter_A(ref_A, lane_id);

  // Load the first K-slice (32x8) of the tile
  IteratorA::Fragment frag_A;
  iter_A.load(frag_A);  // single ldmatrix.sync.aligned.m8n8.x4
  // ... pass frag_A to a warp-level tensor core MMA here ...

  // Step to the next K-slice and load it
  ++iter_A;
  iter_A.load(frag_A);
}
```
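Padding a shared-memory tile, as the example does, serves two purposes when the tile is read with ldmatrix. Each row address must be 16-byte aligned, and the eight rows of an 8×8 matrix should fall in different banks. The checks below are a rough sketch of both under the assumptions stated in the comments. rows_16B_aligned and rows_conflict_free are hypothetical helpers, not part of DeepSpeed or CUTLASS.

```cpp
#include <cassert>
#include <set>

// Sketch of a padding check for a row-major 16-bit tile read with ldmatrix.
// Assumptions: the tile base is 16-byte aligned, shared memory has 32 banks of
// 4 bytes, and each 8x8 matrix (eight 16-byte rows) is serviced together.

// Every 8-element row that ldmatrix reads must start on a 16-byte boundary.
bool rows_16B_aligned(int stride_elems) {
  return (stride_elems * 2) % 16 == 0;
}

// With aligned rows, each 16-byte row covers one group of four banks; the
// eight rows of an 8x8 matrix avoid conflicts if they hit eight distinct groups.
bool rows_conflict_free(int stride_elems) {
  assert(rows_16B_aligned(stride_elems));
  std::set<int> bank_groups;
  for (int r = 0; r < 8; ++r) {
    int first_bank = (r * stride_elems * 2 / 4) % 32;
    bank_groups.insert(first_bank / 4);
  }
  return bank_groups.size() == 8;
}
```

With these assumptions, a stride of 32 + 4 = 36 elements gives 72-byte rows, which are not 16-byte aligned. An unpadded stride of 32 is aligned but maps every other row to the same banks. A stride of 32 + 8 = 40 elements is both aligned and conflict-free.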