Implementation:Deepspeedai DeepSpeed Evoformer Tile Iterator Residual
| Knowledge Sources | |
|---|---|
| Domains | Attention, CUTLASS_Kernels, DeepSpeed4Science |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A global memory tile iterator with optimized residual handling that minimizes predication overhead by deferring bounds checks to the final tile access.
Description
PredicatedTileIteratorResidualLast loads and stores matrix tiles in global memory and handles partial tiles at problem boundaries. Instead of checking bounds on every access, it uses a "residual last" strategy: full interior tiles are accessed without bounds predicates, and only the final, partial (residual) tile is bounds-checked. This reduces instruction overhead in the common case where most tiles are complete. The iterator supports several memory layouts (RowMajor, ColumnMajor, and affine layouts), vectorized memory accesses with a configurable number of elements per access, and optional gather/scatter addressing. Template parameters control the advance direction, alignment requirements, and memory access granularity.
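The "residual last" idea can be illustrated with a small host-side analogy. This is a minimal sketch in plain C++, not the CUTLASS implementation: `kTile`, `load_tile`, and `tiled_sum` are hypothetical names, the data is one-dimensional, and out-of-range slots of the residual tile are assumed to be zero-filled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical tile width for illustration; not taken from the DeepSpeed source.
constexpr std::size_t kTile = 4;

// Copies one tile starting at `offset` into `frag`.
// Interior tiles skip bounds checks; the residual tile guards every element
// and zero-fills the slots that fall past the end of `src`.
inline void load_tile(const std::vector<float>& src, std::size_t offset,
                      bool is_residual, float (&frag)[kTile]) {
    if (!is_residual) {
        // Interior tile: every element is known to be in range.
        for (std::size_t i = 0; i < kTile; ++i) frag[i] = src[offset + i];
    } else {
        // Residual tile: the only place where bounds are checked.
        for (std::size_t i = 0; i < kTile; ++i)
            frag[i] = (offset + i < src.size()) ? src[offset + i] : 0.0f;
    }
}

// Sums all elements tile by tile, visiting the residual tile last.
inline float tiled_sum(const std::vector<float>& src) {
    const std::size_t full_tiles = src.size() / kTile;
    const bool has_residual = (src.size() % kTile) != 0;
    float frag[kTile];
    float total = 0.0f;
    for (std::size_t t = 0; t < full_tiles; ++t) {
        load_tile(src, t * kTile, /*is_residual=*/false, frag);
        for (float v : frag) total += v;
    }
    if (has_residual) {
        load_tile(src, full_tiles * kTile, /*is_residual=*/true, frag);
        for (float v : frag) total += v;
    }
    return total;
}
```

With six input elements and a tile width of four, the first tile takes the unchecked path and the second, partial tile takes the guarded path. On a GPU the same split lets interior tiles avoid per-access predicate evaluation.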
Usage
This iterator is the primary mechanism for loading the input tensors (Q, K, V) and storing the output tensors in the Evoformer attention kernels. It handles sequence lengths that are not a multiple of the tile size correctly while keeping the unchecked fast path for full tiles.
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_residual_last.h
Signature
template <typename Shape_,       // Tile shape (MatrixShape)
          typename Element_,     // Element type
          typename Layout_,      // Memory layout
          int AdvanceRank,       // 0=column, 1=row advancement
          typename ThreadMap_,   // Thread to element mapping
          typename AccessType_,  // Vectorized access type
          bool Gather = false>   // Gather mode
class PredicatedTileIteratorResidualLast {
public:
    using Shape = Shape_;
    using Element = Element_;
    using Layout = Layout_;
    using ThreadMap = ThreadMap_;
    using AccessType = AccessType_;
    using TensorRef = TensorRef<Element, Layout>;
    using TensorCoord = typename Layout::TensorCoord;

    static int const kAdvanceRank = AdvanceRank;
    // Number of AccessType-sized accesses needed per ThreadMap access
    static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;

    // Per-thread register storage for one tile
    using Fragment = Array<Element, ThreadMap::Iterations::kCount *
                                        ThreadMap::kElementsPerAccess>;

    struct Params {
        Layout layout;
        CUTLASS_HOST_DEVICE Params(Layout const& layout_);
    };

    CUTLASS_HOST_DEVICE
    PredicatedTileIteratorResidualLast(Params const& params,
                                       Element* pointer,
                                       TensorCoord extent,
                                       int thread_id,
                                       TensorCoord const& threadblock_offset);

    CUTLASS_DEVICE void load(Fragment& frag);
    CUTLASS_DEVICE void store(Fragment const& frag);
    // Advances to the next tile and returns a reference to this iterator
    CUTLASS_DEVICE PredicatedTileIteratorResidualLast& operator++();
};
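The two compile-time quantities in the signature, `kAccessesPerVector` and the `Fragment` element count, are simple arithmetic over the thread map and access type. The sketch below reproduces that arithmetic with hypothetical stand-in traits (`ExampleThreadMap`, `WideAccess`, `NarrowAccess`, and `FragmentTraits` are illustrative names, and the numeric values are assumptions rather than values from the Evoformer kernels).

```cpp
#include <cassert>

// Hypothetical stand-ins for the ThreadMap and AccessType traits.
struct ExampleIterations {
    static constexpr int kCount = 4;  // assumed accesses per thread per tile
};
struct ExampleThreadMap {
    using Iterations = ExampleIterations;
    static constexpr int kElementsPerAccess = 8;  // assumed: 8 x FP16 = 128 bits
};
struct WideAccess {
    static constexpr int kElements = 8;  // one access covers the whole vector
};
struct NarrowAccess {
    static constexpr int kElements = 4;  // two accesses needed per vector
};

// Mirrors the arithmetic shown in the class signature above.
template <typename ThreadMap, typename AccessType>
struct FragmentTraits {
    static constexpr int kAccessesPerVector =
        ThreadMap::kElementsPerAccess / AccessType::kElements;
    static constexpr int kFragmentElements =
        ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess;
};

static_assert(FragmentTraits<ExampleThreadMap, WideAccess>::kAccessesPerVector == 1,
              "full-width access: one access per vector");
static_assert(FragmentTraits<ExampleThreadMap, NarrowAccess>::kAccessesPerVector == 2,
              "half-width access: two accesses per vector");
```

Under these assumed values each thread holds 4 * 8 = 32 elements per fragment, regardless of which access width is chosen; the access type only changes how many memory instructions are issued per vector.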
Import
#include "csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_residual_last.h"
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| Constructor Inputs | | |
| layout | Layout | Stride information for tensor |
| pointer | Element* | Base pointer to global memory |
| extent | TensorCoord | Logical dimensions (rows, cols) |
| thread_id | int | Thread index in threadblock |
| threadblock_offset | TensorCoord | Starting position for this block |
| Methods | | |
| load(frag) | void | Load tile into fragment with residual handling |
| store(frag) | void | Store fragment to global memory with residual handling |
| operator++() | Iterator& | Advance to next tile position |
Usage Examples
// Configure iterator for loading queries with 128-bit aligned accesses.
// ThreadMap, layout, query_ptr, seq_len, head_dim, thread_id, block_row,
// and num_tiles are assumed to be defined by the surrounding kernel.
using QueryIterator = cutlass::transform::threadblock::PredicatedTileIteratorResidualLast<
    cutlass::MatrixShape<64, 64>,              // 64x64 tile
    cutlass::half_t,                           // FP16 elements
    cutlass::layout::RowMajor,                 // Row-major layout
    0,                                         // Advance along columns
    ThreadMap,                                 // Thread mapping
    cutlass::AlignedArray<cutlass::half_t, 8>  // 8-element vectors (128-bit)
>;

typename QueryIterator::Params params(layout);
QueryIterator iter(params, query_ptr, {seq_len, head_dim}, thread_id, {block_row, 0});

// Load multiple tiles efficiently
for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
    typename QueryIterator::Fragment frag;
    // Bounds checks apply only to the last (residual) tile, and only when the
    // extent along the advance dimension is not a multiple of 64
    iter.load(frag);
    ++iter;  // Move to next tile
    // Process fragment...
}
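The loop above uses `num_tiles` without showing how it is obtained. One plausible way to derive it, along with the size of the residual tile, is a ceiling division of the extent along the advance dimension by the tile size. The helper names below (`ceil_div`, `TileCount`, `count_tiles`) are hypothetical and are not part of DeepSpeed or CUTLASS.

```cpp
#include <cassert>

// Hypothetical helper type: tile count plus the size of the partial tile.
struct TileCount {
    int num_tiles;  // total tiles to visit, including the residual one
    int residual;   // elements in the final partial tile; 0 if none
};

// Integer ceiling division for non-negative a and positive b.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Splits `extent` elements into tiles of `tile` elements each.
inline TileCount count_tiles(int extent, int tile) {
    return TileCount{ceil_div(extent, tile), extent % tile};
}
```

For an extent of 200 and a tile size of 64, this yields four tiles, of which the last holds only 8 valid elements; that last tile is the only one that needs bounds checks under the residual-last strategy.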