

Implementation:Deepspeedai DeepSpeed Evoformer Tile Iterator Residual

From Leeroopedia
Revision as of 14:46, 16 February 2026 by Admin (auto-imported from implementations/Deepspeedai_DeepSpeed_Evoformer_Tile_Iterator_Residual.md)


Knowledge Sources
Domains: Attention, CUTLASS_Kernels, DeepSpeed4Science
Last Updated: 2026-02-09 00:00 GMT

Overview

A global memory tile iterator with optimized residual handling. It keeps predication cheap by deferring the full bounds check to the final (residual) tile.

Description

PredicatedTileIteratorResidualLast loads and stores matrix tiles from global memory and handles partial tiles at problem boundaries. CUTLASS's standard PredicatedTileIterator visits the partial (residue) tile first; this variant visits it last. Full interior tiles use a "steady-state" predicate mask that only guards the dimension the iterator does not advance along, and the full bounds mask, which also checks the advance dimension, is applied only to the final residual tile. This keeps predicate work low in the common case of complete tiles. The iterator supports several memory layouts (RowMajor, ColumnMajor, and affine layouts), vectorized memory accesses with a configurable number of elements per access, and optional gather/scatter access patterns. Template parameters control the advance direction, alignment requirements, and memory access granularity.
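The two masks described above can be modeled on the host. This is a minimal sketch, not the DeepSpeed code: `Access`, `ResidualLastMasks`, and `build_masks` are hypothetical names, the sketch assumes a thread owns at most 32 accesses (one predicate bit each), and it assumes the iterator advances along rows, so the column bound is the steady-state check.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model: one thread's accesses as (row, col) offsets relative to
// the tile origin. The iterator is assumed to advance along rows.
struct Access {
  int row;
  int col;
};

// One predicate bit per access; both masks are computed once, up front.
struct ResidualLastMasks {
  std::uint32_t steady;    // column bound only (non-advance dimension)
  std::uint32_t residual;  // row and column bounds, used for the final tile
};

// residual_rows: number of valid rows in the final tile.
// cols: number of valid columns in the problem.
ResidualLastMasks build_masks(const std::vector<Access>& accesses,
                              int residual_rows, int cols) {
  ResidualLastMasks masks{0u, 0u};
  for (std::size_t i = 0; i < accesses.size() && i < 32; ++i) {
    bool col_ok = accesses[i].col < cols;
    if (col_ok) masks.steady |= (1u << i);
    if (col_ok && accesses[i].row < residual_rows) masks.residual |= (1u << i);
  }
  return masks;
}
```

Full tiles read `steady`; only the last tile swaps in `residual`, so the row comparison is absent from the steady-state mask.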

Usage

This iterator is the primary mechanism for loading input tensors (Q, K, V) and storing output tensors in Evoformer attention kernels. It lets those kernels handle arbitrary sequence lengths, including sizes that are not a multiple of the tile size, without paying full bounds checks on every tile.
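For a non-aligned length, the split into full tiles and one residual tile is plain integer arithmetic. The sketch below is hypothetical (`TileSplit` and `split_rows` are not part of the header) and assumes `seq_len > 0` and a sweep that starts at row 0.

```cpp
#include <cassert>

// Hypothetical helper: how seq_len rows split into tiles of tile_rows rows
// when the partial tile is visited last. Assumes seq_len > 0.
struct TileSplit {
  int num_tiles;      // all tiles, including a partial one
  int full_tiles;     // tiles that are completely in bounds
  int residual_rows;  // valid rows in the last tile (tile_rows if aligned)
};

constexpr TileSplit split_rows(int seq_len, int tile_rows) {
  return TileSplit{
      (seq_len + tile_rows - 1) / tile_rows,
      seq_len / tile_rows,
      seq_len % tile_rows == 0 ? tile_rows : seq_len % tile_rows};
}

// Example: 200 rows in 64-row tiles -> 4 tiles, 3 full, 8 rows in the last.
static_assert(split_rows(200, 64).num_tiles == 4, "ceil(200 / 64)");
static_assert(split_rows(200, 64).residual_rows == 8, "200 - 3 * 64");
```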

Code Reference

Source Location

Signature

template <typename Shape_,                // Tile shape (MatrixShape)
          typename Element_,              // Element type
          typename Layout_,               // Memory layout
          int AdvanceRank,                // 0 = advance along rows, 1 = along columns
          typename ThreadMap_,            // Thread-to-element mapping
          typename AccessType_,           // Vectorized access type
          bool Gather = false>            // Gather mode
class PredicatedTileIteratorResidualLast {
public:
    using Shape = Shape_;
    using Element = Element_;
    using Layout = Layout_;
    using ThreadMap = ThreadMap_;
    using AccessType = AccessType_;
    using TensorRef = TensorRef<Element, Layout>;
    using TensorCoord = typename Layout::TensorCoord;

    static int const kAdvanceRank = AdvanceRank;
    static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;

    using Fragment = Array<Element, ThreadMap::Iterations::kCount *
                                   ThreadMap::kElementsPerAccess>;

    struct Params {
        Layout layout;
        CUTLASS_HOST_DEVICE Params(Layout const& layout_);
    };

    CUTLASS_HOST_DEVICE
    PredicatedTileIteratorResidualLast(Params const& params,
                                       Element* pointer,
                                       TensorCoord extent,
                                       int thread_id,
                                       TensorCoord const& threadblock_offset);

    // Switch to the full-bounds (residual) mask before accessing the last tile
    CUTLASS_HOST_DEVICE void set_residual_tile(bool enable);

    CUTLASS_DEVICE void load(Fragment& frag);
    CUTLASS_DEVICE void store(Fragment const& frag);
    CUTLASS_HOST_DEVICE PredicatedTileIteratorResidualLast& operator++();
};
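The `kAccessesPerVector` and `Fragment` definitions in the signature reduce to compile-time arithmetic. The sketch mirrors those two expressions with a made-up thread map (`ExampleThreadMap` and `FragmentInfo` are hypothetical, not CUTLASS types), assuming 128 threads cover a 64x64 FP16 tile evenly with 8-element accesses.

```cpp
#include <cassert>

// Hypothetical thread map: Threads threads cover a TileRows x TileCols tile,
// each access moving ElementsPerAccess contiguous elements.
template <int TileRows, int TileCols, int Threads, int ElementsPerAccess>
struct ExampleThreadMap {
  static_assert((TileRows * TileCols) % (Threads * ElementsPerAccess) == 0,
                "tile must divide evenly among threads");
  static constexpr int kElementsPerAccess = ElementsPerAccess;
  // Accesses per thread per tile (plays the role of Iterations::kCount).
  static constexpr int kIterationsCount =
      (TileRows * TileCols) / (Threads * ElementsPerAccess);
};

// Mirrors the kAccessesPerVector and Fragment-size expressions.
template <typename ThreadMap, int AccessElements>
struct FragmentInfo {
  static constexpr int kAccessesPerVector =
      ThreadMap::kElementsPerAccess / AccessElements;
  static constexpr int kFragmentElements =
      ThreadMap::kIterationsCount * ThreadMap::kElementsPerAccess;
};

using Map64x64 = ExampleThreadMap<64, 64, 128, 8>;
```

Under these assumptions each thread holds 4 x 8 = 32 half_t values (64 bytes) per tile; with 4-element (64-bit) accesses, each 8-element vector would be split into two accesses.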

Import

#include "csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_residual_last.h"

I/O Contract

Constructor inputs
  layout (Layout, passed via Params): stride information for the tensor
  pointer (Element*): base pointer to global memory
  extent (TensorCoord): logical dimensions (rows, columns)
  thread_id (int): thread index within the threadblock
  threadblock_offset (TensorCoord): starting position for this threadblock

Methods
  set_residual_tile(enable) -> void: apply the full bounds mask (for the last tile) when enable is true
  load(frag) -> void: load a tile into the fragment, with residual handling on the last tile
  store(frag) -> void: store the fragment to global memory, with residual handling on the last tile
  operator++() -> Iterator&: advance to the next tile position
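How `operator++` moves the tile origin depends on `AdvanceRank`. The sketch below models only the tile-level bookkeeping, assuming `AdvanceRank` indexes a (row, column) coordinate as in CUTLASS's `MatrixCoord` (0 = row, 1 = column). `Coord` and `TileCursor` are hypothetical; the real iterator advances pointers by byte offsets rather than storing coordinates.

```cpp
#include <cassert>

// Hypothetical (row, column) coordinate.
struct Coord {
  int row;
  int col;
};

// Tile-level model of operator++: AdvanceRank selects which coordinate moves.
template <int AdvanceRank>
struct TileCursor {
  Coord offset;      // current tile origin (starts at threadblock_offset)
  Coord tile_shape;  // tile extent in rows and columns

  TileCursor& operator++() {
    if (AdvanceRank == 0) {
      offset.row += tile_shape.row;  // step down by one tile of rows
    } else {
      offset.col += tile_shape.col;  // step right by one tile of columns
    }
    return *this;
  }
};
```

Returning a reference matches the `Iterator&` return listed above and allows chained increments.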

Usage Examples

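// Configure an iterator for loading queries with 128-bit vectorized accesses.
// ThreadMap is assumed to be a CUTLASS thread map for a 64x64 tile defined
// elsewhere; query_ptr, ldq, seq_len, head_dim, block_row and thread_id are
// assumed kernel-side variables.
using QueryIterator = cutlass::transform::threadblock::PredicatedTileIteratorResidualLast<
    cutlass::MatrixShape<64, 64>,                     // 64x64 tile
    cutlass::half_t,                                  // FP16 elements
    cutlass::layout::RowMajor,                        // Row-major layout
    0,                                                // Advance along rows (seq_len)
    ThreadMap,                                        // Thread mapping
    cutlass::AlignedArray<cutlass::half_t, 8>         // 8-element vectors (128-bit)
>;

cutlass::layout::RowMajor layout(ldq);                // ldq: row stride of Q in elements
typename QueryIterator::Params params(layout);
QueryIterator iter(params, query_ptr, {seq_len, head_dim}, thread_id, {block_row, 0});

// Number of 64-row tiles from block_row onward (the last one may be partial)
int const num_tiles = (seq_len - block_row + 63) / 64;

for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
    // Use the full bounds mask only for the last, possibly partial, tile
    iter.set_residual_tile(tile_idx == num_tiles - 1);

    typename QueryIterator::Fragment frag;
    iter.load(frag);
    ++iter;           // Move to the next tile

    // Process fragment...
}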
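When testing a kernel built around a tile loop like this, a CPU reference of the same sweep can help. This is a hypothetical sketch (`load_all_tiles` is not part of DeepSpeed): it walks a row-major matrix in full-width row tiles, checks the row bound only on the last tile in keeping with the residual-last behavior described above, and zero-fills masked-off elements, which is a choice of this sketch rather than documented iterator behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference: sweeps a row-major rows x cols matrix (leading
// dimension ld) in tiles of tile_rows full-width rows, advancing along rows.
std::vector<std::vector<float>> load_all_tiles(const std::vector<float>& data,
                                               int rows, int cols, int ld,
                                               int tile_rows) {
  const int num_tiles = (rows + tile_rows - 1) / tile_rows;
  std::vector<std::vector<float>> tiles;
  for (int t = 0; t < num_tiles; ++t) {
    // Only the last tile can be partial, so only it checks the row bound.
    const bool residual = (t == num_tiles - 1);
    // Masked-off elements stay zero (a choice of this sketch).
    std::vector<float> frag(static_cast<std::size_t>(tile_rows) * cols, 0.0f);
    for (int r = 0; r < tile_rows; ++r) {
      const int row = t * tile_rows + r;
      if (residual && row >= rows) continue;  // predicated-off access
      for (int c = 0; c < cols; ++c) {
        frag[r * cols + c] = data[row * ld + c];
      }
    }
    tiles.push_back(frag);
  }
  return tiles;
}
```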
