
Implementation:Deepspeedai DeepSpeed Evoformer Warp Iterator Smem


Knowledge Sources
Domains: Attention, CUTLASS_Kernels, DeepSpeed4Science
Last Updated: 2026-02-09 00:00 GMT

Overview

A warp-level iterator that loads matrix tiles from shared memory into tensor core registers using ldmatrix instructions with optional transpose support.

Description

WarpIteratorFromSmem loads matrix operands from shared memory for tensor core operations on Ampere GPUs. Unlike CUTLASS warp iterators that issue manual element-wise loads, it uses the ldmatrix.sync.aligned instruction family, where a single instruction loads one, two, or four 8×8 fragments of 16-bit elements directly into the register layout that tensor cores expect. The kTranspose template parameter selects between normal and transposed access, so A^T or B^T can be loaded without an explicit transposition pass. Per-thread coordinate offsets are computed from the lane ID and the warp's position, so each thread loads the portion of the fragment assigned to it by the tensor core distribution pattern. The 32×32 tile shape is hardcoded to match common attention block sizes.
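
To make the per-thread offsets concrete, the following host-only C++ sketch models the lane mapping that the PTX ldmatrix .m8n8.x4 form defines for 16-bit elements. The struct and function names are illustrative assumptions and do not appear in DeepSpeed. The real iterator also folds in the warp tile position and the kTranspose flag, which this sketch leaves out.

```cpp
// Host-side model of the lane mapping for ldmatrix.sync.aligned.m8n8.x4
// with 16-bit elements. No GPU code runs here; all names are hypothetical.

struct RowSource {
    int matrix;  // which of the four 8x8 matrices this lane addresses
    int row;     // which row of that matrix the lane's pointer refers to
};

// Lane 8*i + r supplies the shared-memory address of row r of matrix i.
constexpr RowSource address_provided_by(int lane_id) {
    return RowSource{lane_id / 8, lane_id % 8};
}

struct RegisterContents {
    int row;        // row of the 8x8 matrix held in the register
    int first_col;  // first of the two adjacent 16-bit columns packed in it
};

// After the load, register i of every lane holds two elements of matrix i:
// row lane/4, columns 2*(lane%4) and 2*(lane%4) + 1.
constexpr RegisterContents register_holds(int lane_id) {
    return RegisterContents{lane_id / 4, 2 * (lane_id % 4)};
}
```

For example, lane 13 supplies the address of row 5 of matrix 1, and after the load each of its four registers holds row 3, columns 2 and 3 of the corresponding matrix.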

Usage

This iterator is used in the MmaFromSmem operations when loading the value matrix from shared memory in attention computations, providing 2-3× speedup over element-wise loads through direct ldmatrix utilization.

Code Reference

Source Location

Signature

template <
    Operand Operand_,           // Operand::kA or Operand::kB
    typename Element_,          // Element type (must be 16-bit)
    bool kTranspose = false>    // Load transposed
class WarpIteratorFromSmem {
public:
    using Shape = cutlass::MatrixShape<32, 32>;
    using Element = Element_;
    using Layout = cutlass::layout::RowMajor;
    using InstructionShape = cutlass::MatrixShape<16, 8>;

    static_assert(sizeof_bits<Element>::value == 16, "Only 16-bit types supported");

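    static Operand const kOperand = Operand_;  // Operand is a scoped enum, not an int

    // Number of instruction-sized steps along each dimension of the warp tile
    using InstructionCount = cutlass::MatrixShape<
        Shape::kRow / InstructionShape::kRow,
        Shape::kColumn / InstructionShape::kColumn>;

    static int const kIterations = (kOperand == Operand::kA) ?
        InstructionCount::kColumn : InstructionCount::kRow;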

    using Fragment = Array<Element, kOperand == Operand::kA ?
        (Shape::kRow * InstructionShape::kColumn / 32) :
        (Shape::kColumn * InstructionShape::kRow / 32)>;

    using TensorRef = TensorRef<Element, Layout>;
    using AccessType = Array<unsigned, 4>;  // ldmatrix loads 4 x 32-bit

    CUTLASS_HOST_DEVICE
    WarpIteratorFromSmem(TensorRef const& ref, int lane_id);

    CUTLASS_DEVICE void load(Fragment& frag) const;
    CUTLASS_DEVICE void advance();
    CUTLASS_HOST_DEVICE WarpIteratorFromSmem& operator++();
};
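
The sizes implied by the signature can be checked on the host. The sketch below re-declares minimal stand-ins for MatrixShape and Operand so that it compiles without CUTLASS. It assumes InstructionCount is Shape divided by InstructionShape in each dimension and that a warp has 32 threads. The Derived helper and kLdmatrixX4PerLoad are hypothetical names introduced only for illustration.

```cpp
// Minimal stand-ins for the CUTLASS types (assumptions, not the real headers).
template <int Row, int Column>
struct MatrixShape {
    static constexpr int kRow = Row;
    static constexpr int kColumn = Column;
};
enum class Operand { kA, kB };

using Shape = MatrixShape<32, 32>;            // warp tile, as in the signature
using InstructionShape = MatrixShape<16, 8>;  // as in the signature
constexpr int kThreads = 32;                  // lanes per warp

// Assumed: instruction-sized steps along each dimension of the tile.
using InstructionCount = MatrixShape<Shape::kRow / InstructionShape::kRow,
                                     Shape::kColumn / InstructionShape::kColumn>;

template <Operand Op>
struct Derived {
    // Steps needed to walk the tile along the reduction dimension.
    static constexpr int kIterations =
        (Op == Operand::kA) ? InstructionCount::kColumn : InstructionCount::kRow;

    // 16-bit elements held per thread after one load().
    static constexpr int kFragmentElements =
        (Op == Operand::kA) ? (Shape::kRow * InstructionShape::kColumn / kThreads)
                            : (Shape::kColumn * InstructionShape::kRow / kThreads);

    // One ldmatrix .x4 returns four 32-bit registers per lane,
    // i.e. eight 16-bit elements.
    static constexpr int kLdmatrixX4PerLoad = kFragmentElements / 8;
};
```

Under these assumptions operand A takes 4 iterations with 8 elements per thread, and operand B takes 2 iterations with 16 elements per thread, so a kB fragment corresponds to two .x4 loads rather than one.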

Import

#include "csrc/deepspeed4science/evoformer_attn/iterators/warp_iterator_from_smem.h"

I/O Contract

Method | Input | Output | Description
Constructor | TensorRef, lane_id | Iterator | Initialize with shared memory reference and thread lane
load | Fragment& (output) | void | Load this thread's fragment of the current slice of the 32×32 tile from shared memory using ldmatrix
advance | none | void | Move to next tile position (column for A, row for B)
operator++ | none | Iterator& | Increment and advance to next iteration

Configuration
Parameter | Type | Description
kTranspose | bool | If true, load A^T or B^T
Operand | enum | Operand::kA or Operand::kB
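
The stepping behaviour in the table can be modelled on the host. This is a sketch under two assumptions that the page does not confirm: operator++ walks through kIterations slices of a tile and then moves to the next tile, and kTranspose is false. TileCursor, advance_tile, and step are hypothetical names, not part of DeepSpeed or CUTLASS.

```cpp
enum class Operand { kA, kB };

// Hypothetical cursor that tracks where the iterator would be.
struct TileCursor {
    Operand operand;
    int iterations_per_tile;  // e.g. 4 for kA, 2 for kB with a 32x32 tile
    int iteration;            // slice index inside the current tile
    int tile_row;             // tile offset along rows
    int tile_col;             // tile offset along columns
};

// advance(): next tile along columns for A, along rows for B.
constexpr TileCursor advance_tile(TileCursor c) {
    return c.operand == Operand::kA
        ? TileCursor{c.operand, c.iterations_per_tile, 0, c.tile_row, c.tile_col + 1}
        : TileCursor{c.operand, c.iterations_per_tile, 0, c.tile_row + 1, c.tile_col};
}

// operator++ (assumed): next slice, or next tile once all slices are done.
constexpr TileCursor step(TileCursor c) {
    return (c.iteration + 1 >= c.iterations_per_tile)
        ? advance_tile(c)
        : TileCursor{c.operand, c.iterations_per_tile, c.iteration + 1,
                     c.tile_row, c.tile_col};
}
```

With these assumptions, stepping a kB cursor twice from the origin lands on tile row 1, and stepping a kA cursor four times lands on tile column 1.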

Usage Examples

// Load value matrix from shared memory with ldmatrix
using ValueIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
    cutlass::gemm::Operand::kB,  // B operand in A×B
    cutlass::half_t,              // FP16
    false                         // No transpose
>;

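// Shared memory reference for V matrix
// +8 halves (16 bytes) of padding keeps every row 16-byte aligned, as ldmatrix requires
__shared__ __align__(16) cutlass::half_t smem_V[32][32 + 8];
typename ValueIterator::TensorRef ref_V(&smem_V[0][0], cutlass::layout::RowMajor(32 + 8));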

// Create iterator for this warp
ValueIterator iter_V(ref_V, lane_id);

// Load this thread's fragment; load() issues ldmatrix.sync.aligned.m8n8.x4 internally
typename ValueIterator::Fragment frag_V;
iter_V.load(frag_V);

// Use fragment in tensor core MMA (mma, accum, and frag_A are assumed to be set up elsewhere)
mma(accum, frag_A, frag_V, accum);

// Advance to next K-dimension tile
++iter_V;
iter_V.load(frag_V);  // Load next chunk
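
The row padding of the shared memory buffer affects both correctness and speed. The host-only sketch below checks two properties of a candidate row stride: that every row start is 16-byte aligned, which ldmatrix requires of the addresses it is given, and that the eight rows of one 8×8 matrix fall into distinct 16-byte bank groups. The bank check uses a simplified model of 32 banks of 4 bytes, with one 8×8 matrix read per phase. Both function names are hypothetical.

```cpp
// Every row address passed to ldmatrix must be 16-byte aligned, so the
// row stride in bytes has to be a multiple of 16 (given an aligned base).
constexpr bool rows_are_ldmatrix_aligned(int stride_elements, int element_bytes) {
    return (stride_elements * element_bytes) % 16 == 0;
}

// Simplified bank model: 32 banks x 4 bytes = 128 bytes per bank window,
// i.e. eight 16-byte groups. The eight 16-byte rows of one 8x8 matrix are
// conflict-free when they land in eight different groups.
constexpr bool matrix_rows_hit_distinct_banks(int stride_bytes) {
    unsigned seen = 0;
    for (int row = 0; row < 8; ++row) {
        int group = (row * stride_bytes / 16) % 8;
        if (seen & (1u << group)) return false;  // two rows share a group
        seen |= 1u << group;
    }
    return true;
}
```

For 16-bit elements, a stride of 36 elements (72 bytes) fails the alignment check, an unpadded stride of 32 elements (64 bytes) is aligned but maps rows onto only two bank groups, and a stride of 40 elements (80 bytes) passes both checks in this model.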
