Implementation:Deepspeedai DeepSpeed Evoformer Kernel Forward
| Knowledge Sources | |
|---|---|
| Domains | Attention, CUTLASS_Kernels, DeepSpeed4Science |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
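CUDA kernel implementing the memory-efficient forward pass for Evoformer attention. It computes softmax(QK^T/√d)V without materializing the full attention matrix, so its extra memory grows linearly with sequence length rather than quadratically.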
Description
AttentionKernel implements a fused attention forward pass using a tiled algorithm that never materializes the full N×M attention matrix. The kernel processes queries in blocks of kQueriesPerBlock and keys in blocks of kKeysPerBlock, computing attention scores on the fly and applying them to the values immediately.
For each query block, it iterates over all key blocks and performs four steps: (1) compute the scaled scores S = Q × K^T; (2) update the running row maximum and rescale the previously accumulated sum and output to match it (online softmax); (3) compute P = exp(S - max); (4) accumulate O += P × V. The running statistics m_prime (row maximum) and s_prime (sum of exponentials) keep the softmax numerically stable across tiles. After the last key block, the output is normalized by s_prime and the per-row logsumexp is stored for the backward pass.
Architecture-specific optimizations include preloading the V matrix into shared memory on Ampere, a single-value-iteration mode used when the value head dimension fits in registers, and optional broadcast bias support for position/pair embeddings.
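The online softmax described above can be sketched on the host for a single query row. This is a minimal reference in plain C++, not the kernel's code: the function names `attend_tiled` and `attend_naive`, the row-major layouts, and the use of `std::vector<float>` are assumptions made for illustration. Only the roles of `m_prime` and `s_prime` follow the description.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Result for one query row: attention output (length Dv) and its logsumexp.
struct RowResult {
    std::vector<float> out;
    float lse;
};

// Scaled dot product of query q with key row j; K is row-major [M, D].
static float score(const std::vector<float>& q, const std::vector<float>& K,
                   std::size_t j, float scale) {
    float s = 0.0f;
    for (std::size_t d = 0; d < q.size(); ++d) s += q[d] * K[j * q.size() + d];
    return s * scale;
}

// Tiled attention for one query row; V is row-major [M, Dv].
// m_prime: running maximum score, s_prime: running sum of exp(score - m_prime).
RowResult attend_tiled(const std::vector<float>& q, const std::vector<float>& K,
                       const std::vector<float>& V, std::size_t M, std::size_t Dv,
                       float scale, std::size_t keys_per_block) {
    assert(M > 0 && keys_per_block > 0);
    float m_prime = -std::numeric_limits<float>::infinity();
    float s_prime = 0.0f;
    std::vector<float> acc(Dv, 0.0f);
    for (std::size_t start = 0; start < M; start += keys_per_block) {
        const std::size_t end = std::min(M, start + keys_per_block);
        // (1) Scores for this key tile, tracking the new running maximum.
        std::vector<float> s(end - start);
        float m_new = m_prime;
        for (std::size_t j = start; j < end; ++j) {
            s[j - start] = score(q, K, j, scale);
            m_new = std::max(m_new, s[j - start]);
        }
        // (2) Rescale the previous sum and accumulator to the new maximum.
        //     On the first tile exp(-inf) is 0, which leaves both at zero.
        const float correction = std::exp(m_prime - m_new);
        s_prime *= correction;
        for (float& a : acc) a *= correction;
        // (3) P = exp(S - max) and (4) O += P * V.
        for (std::size_t j = start; j < end; ++j) {
            const float p = std::exp(s[j - start] - m_new);
            s_prime += p;
            for (std::size_t c = 0; c < Dv; ++c) acc[c] += p * V[j * Dv + c];
        }
        m_prime = m_new;
    }
    for (float& a : acc) a /= s_prime;  // final normalization
    return {acc, m_prime + std::log(s_prime)};
}

// Untiled reference: materializes all M scores before the softmax.
RowResult attend_naive(const std::vector<float>& q, const std::vector<float>& K,
                       const std::vector<float>& V, std::size_t M, std::size_t Dv,
                       float scale) {
    std::vector<float> s(M);
    float mx = -std::numeric_limits<float>::infinity();
    for (std::size_t j = 0; j < M; ++j) {
        s[j] = score(q, K, j, scale);
        mx = std::max(mx, s[j]);
    }
    float sum = 0.0f;
    std::vector<float> out(Dv, 0.0f);
    for (std::size_t j = 0; j < M; ++j) {
        const float p = std::exp(s[j] - mx);
        sum += p;
        for (std::size_t c = 0; c < Dv; ++c) out[c] += p * V[j * Dv + c];
    }
    for (float& o : out) o /= sum;
    return {out, mx + std::log(sum)};
}
```

Calling both functions on the same inputs with any `keys_per_block` should give matching outputs and logsumexp values up to floating-point rounding. This is the property that lets the kernel visit key blocks one at a time.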
Usage
This kernel is the primary forward pass implementation for Evoformer attention in DeepSpeed4Science, enabling training of AlphaFold2-style models on long sequences by reducing attention memory from O(N²) to O(N).
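To make the memory claim concrete, the arithmetic below compares the size of a fully materialized score matrix with the per-row logsumexp buffer that the kernel keeps instead. The shapes (B = 1, H = 8, N = M = 4096, FP16 scores) and the helper names are assumptions chosen for illustration, not measurements.

```cpp
#include <cstdint>

// Bytes needed to store the full N x M score matrix for every batch and head.
constexpr std::uint64_t full_scores_bytes(std::uint64_t B, std::uint64_t H,
                                          std::uint64_t N, std::uint64_t M,
                                          std::uint64_t elem_bytes) {
    return B * H * N * M * elem_bytes;
}

// Bytes needed for one float logsumexp value per query row.
constexpr std::uint64_t logsumexp_bytes(std::uint64_t B, std::uint64_t H,
                                        std::uint64_t N) {
    return B * H * N * sizeof(float);
}

// Hypothetical shape: B = 1, H = 8, N = M = 4096, 2-byte (FP16) scores.
static_assert(full_scores_bytes(1, 8, 4096, 4096, 2) == 268435456ull,
              "full score matrix: 256 MiB");
static_assert(logsumexp_bytes(1, 8, 4096) == 131072ull,
              "logsumexp buffer: 128 KiB");
```

Doubling N and M quadruples the first figure but only doubles the second.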
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/deepspeed4science/evoformer_attn/kernel_forward.h
Signature
template <
typename scalar_t_, // Element type (half_t, bfloat16_t)
typename ArchTag, // Architecture (Sm70, Sm75, Sm80)
bool isAligned_, // Memory alignment
int kQueriesPerBlock, // Queries per threadblock
int kKeysPerBlock_, // Keys per threadblock
bool kSingleValueIteration_, // Value dimension fits in RF
bool kSupportsBias_ = false, // Support bias terms
template <typename, typename, typename> class Broadcast1_ = BroadcastNoLoad,
template <typename, typename, typename> class Broadcast2_ = BroadcastNoLoad>
struct AttentionKernel {
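using scalar_t = scalar_t_;
using accum_t = float;
using lse_scalar_t = float; // Logsumexp element type
using output_t = scalar_t;
using output_accum_t = accum_t; // Accumulator carried between key-block iterations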
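static constexpr int kKeysPerBlock = kKeysPerBlock_;
static constexpr int kWarpSize = 32;
// One warp per 1024 (32 x 32) elements of the queries-by-keys tile
static constexpr int kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize);
static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
static constexpr int kMinBlocksPerSm = getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;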
struct Params {
// Input tensors
scalar_t* query_ptr; // [B, H, N, D]
scalar_t* key_ptr; // [B, H, M, D]
scalar_t* value_ptr; // [B, H, M, Dv]
// Output tensors
output_t* output_ptr; // [B, H, N, Dv]
output_accum_t* output_accum_ptr; // Intermediate accumulator
lse_scalar_t* logsumexp_ptr; // [B, H, N] logsumexp values
// Scale and dimensions
accum_t scale;
int32_t head_dim, head_dim_value;
int32_t num_queries, num_keys;
// Strides
int32_t q_strideM, k_strideM, v_strideM, o_strideM;
int32_t q_strideH, k_strideH, v_strideH;
int64_t q_strideB, k_strideB, v_strideB;
// Bias pointers (optional)
scalar_t *bias1_ptr, *bias2_ptr;
CUTLASS_DEVICE bool advance_to_block();
};
static CUTLASS_DEVICE void attention_kernel(Params& p);
};
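The stride fields in Params describe how a logical index maps to an element offset. The sketch below shows that mapping for the query tensor, assuming a contiguous [B, H, N, D] layout with unit stride along D. `QueryStrides`, `contiguous_bhnd`, and `query_offset` are hypothetical helpers written for this illustration and are not part of the kernel.

```cpp
#include <cstdint>

// Hypothetical subset of the stride fields, in units of elements.
struct QueryStrides {
    std::int64_t strideB;  // distance between consecutive batches
    std::int32_t strideH;  // distance between consecutive heads
    std::int32_t strideM;  // distance between consecutive query rows
};

// Strides of a contiguous [B, H, N, D] tensor (assumed layout).
constexpr QueryStrides contiguous_bhnd(std::int32_t H, std::int32_t N,
                                       std::int32_t D) {
    return QueryStrides{static_cast<std::int64_t>(H) * N * D, N * D, D};
}

// Element offset of query[b][h][n][d]; the last dimension has stride 1.
constexpr std::int64_t query_offset(const QueryStrides& s, std::int64_t b,
                                    std::int64_t h, std::int64_t n,
                                    std::int64_t d) {
    return b * s.strideB + h * s.strideH + n * s.strideM + d;
}

// H = 4, N = 128, D = 32: strides are 16384, 4096, and 32 elements.
static_assert(query_offset(contiguous_bhnd(4, 128, 32), 1, 2, 3, 5) == 24677,
              "1*16384 + 2*4096 + 3*32 + 5");
```

Because the strides are passed explicitly, the same arithmetic also covers non-contiguous views, for example a tensor stored with the head dimension after the sequence dimension.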
Import
#include "csrc/deepspeed4science/evoformer_attn/kernel_forward.h"
I/O Contract
| Name | Shape | Type | Description |
|---|---|---|---|
| Inputs | | | |
| query | [B, H, N, D] | scalar_t | Query tensor |
| key | [B, H, M, D] | scalar_t | Key tensor |
| value | [B, H, M, Dv] | scalar_t | Value tensor |
| bias1 | [B, N, 1, 1, M] | scalar_t | Optional row-broadcast bias |
| bias2 | [B, 1, H, M, M] | scalar_t | Optional matrix bias |
| Outputs | | | |
| output | [B, H, N, Dv] | scalar_t | Attention output |
| logsumexp | [B, H, N] | float | Logsumexp statistics for backward |
| Configuration | | | |
| scale | scalar | float | Scaling factor (typically 1/√D) |
| kQueriesPerBlock | compile-time constant | int | Tile size for queries (32-128) |
| kKeysPerBlock | compile-time constant | int | Tile size for keys (32-128) |
Usage Examples
// Configure forward kernel for Ampere with 64x64 tiles
using ForwardKernel = AttentionKernel<
cutlass::half_t, // FP16
cutlass::arch::Sm80, // Ampere
true, // Aligned
64, // Queries per block
64, // Keys per block
false, // Multi-iteration for large Dv
true, // Support bias
BroadcastA, // Row-wise bias
BroadcastB // Matrix bias
>;
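// attention_kernel is a device function, so it cannot be launched directly.
// Declare a __global__ wrapper at namespace scope. The name forward_entry is
// hypothetical; prefer the wrapper that ships with the source if one exists.
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
forward_entry(typename AK::Params p) {
if (!p.advance_to_block()) { return; } // Skip blocks with no work
AK::attention_kernel(p);
}
// Launch configuration
dim3 grid(
(num_queries + 64 - 1) / 64, // Query blocks
num_heads,
batch_size
);
// Thread-block shape: confirm against the source, which may expect a
// two-dimensional (warp size, warps per block) layout instead.
dim3 block(ForwardKernel::kNumThreads);
typename ForwardKernel::Params params;
params.query_ptr = query;
params.key_ptr = key;
params.value_ptr = value;
params.output_ptr = output;
params.logsumexp_ptr = logsumexp;
params.scale = 1.0f / sqrtf(static_cast<float>(head_dim));
params.head_dim = head_dim;
params.num_queries = seq_len;
params.num_keys = seq_len;
// ... set head_dim_value, strides, and (when bias is enabled) bias pointers
// The third launch argument is the dynamic shared-memory size in bytes;
// 0 is a placeholder and must be replaced by the kernel's actual requirement.
forward_entry<ForwardKernel><<<grid, block, 0, stream>>>(params);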
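The query-block count in the grid above is a ceiling division, which means the last block can cover fewer than 64 valid rows when the sequence length is not a multiple of the tile size. The helpers below are a small sketch of that arithmetic; `ceil_div` and `rows_in_block` are illustrative names, not functions from the kernel.

```cpp
// Number of blocks of size d needed to cover n items (n >= 0, d > 0).
constexpr int ceil_div(int n, int d) { return (n + d - 1) / d; }

// Number of valid query rows handled by block `block_idx`.
constexpr int rows_in_block(int num_queries, int per_block, int block_idx) {
    return (num_queries - block_idx * per_block) < per_block
               ? (num_queries - block_idx * per_block)
               : per_block;
}

// 1000 queries with 64-row tiles: 16 blocks, the last one holding 40 rows.
static_assert(ceil_div(1000, 64) == 16, "15 full blocks plus one partial");
static_assert(rows_in_block(1000, 64, 0) == 64, "full block");
static_assert(rows_in_block(1000, 64, 15) == 40, "partial last block");
```

A kernel launched with this grid has to mask or bounds-check the rows past `num_queries` in the final block.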