Implementation:Deepspeedai DeepSpeed Evoformer Kernel Forward

From Leeroopedia
Revision as of 14:46, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Deepspeedai_DeepSpeed_Evoformer_Kernel_Forward.md)


Knowledge Sources
Domains: Attention, CUTLASS_Kernels, DeepSpeed4Science
Last Updated: 2026-02-09 00:00 GMT

Overview

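The CUDA kernel implementing the memory-efficient forward pass of Evoformer attention. It computes softmax(QK^T/√D)V, with optional bias terms added inside the softmax, without materializing the attention matrix. Its extra memory therefore grows linearly with sequence length (one logsumexp value per query row) rather than quadratically.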
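For reference, the computation the kernel fuses can be written naively. The C++ sketch below is host code for illustration, not the DeepSpeed kernel. It assumes row-major float tensors for a single (batch, head) slice and omits the bias terms. It stores the whole N×M score matrix, which is exactly the memory cost the tiled kernel removes, and it also produces the per-row logsumexp that the kernel saves. The names reference_attention and RefResult are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive reference for one (batch, head) slice. It materializes the full N x M
// score matrix S, which is exactly what the fused kernel avoids.
// Assumption for clarity: Q is N x D, K is M x D, V is M x Dv, all row-major
// float; the real kernel reads half/bfloat16 data through strided pointers.
struct RefResult {
    std::vector<float> out;  // N x Dv attention output
    std::vector<float> lse;  // N logsumexp values of the scaled score rows
};

RefResult reference_attention(const std::vector<float>& Q, const std::vector<float>& K,
                              const std::vector<float>& V, int N, int M, int D, int Dv) {
    assert(N > 0 && M > 0 && D > 0 && Dv > 0);
    const float scale = 1.0f / std::sqrt(static_cast<float>(D));
    std::vector<float> S(static_cast<std::size_t>(N) * M);  // O(N*M) memory
    RefResult r{std::vector<float>(static_cast<std::size_t>(N) * Dv, 0.0f),
                std::vector<float>(N)};
    for (int i = 0; i < N; ++i) {
        float row_max = -INFINITY;
        for (int j = 0; j < M; ++j) {
            float dot = 0.0f;
            for (int d = 0; d < D; ++d) dot += Q[i * D + d] * K[j * D + d];
            S[i * M + j] = dot * scale;
            row_max = std::fmax(row_max, S[i * M + j]);
        }
        float sum = 0.0f;
        for (int j = 0; j < M; ++j) sum += std::exp(S[i * M + j] - row_max);
        r.lse[i] = row_max + std::log(sum);
        for (int j = 0; j < M; ++j) {
            const float p = std::exp(S[i * M + j] - r.lse[i]);  // softmax weight
            for (int e = 0; e < Dv; ++e) r.out[i * Dv + e] += p * V[j * Dv + e];
        }
    }
    return r;
}
```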

Description

AttentionKernel implements the fused attention forward pass with a tiled algorithm that never materializes the full N×M attention matrix. Each threadblock handles one tile of kQueriesPerBlock queries and iterates over the keys in tiles of kKeysPerBlock. It computes attention scores on the fly and applies them to the values immediately. For each key tile it performs four steps:
(1) computes the scaled scores S = scale · Q × K^T, plus the optional bias terms;
(2) updates the running row maximum m_prime and rescales the running sum s_prime and the partial output by exp(m_old - m_new), which is the online-softmax step;
(3) computes P = exp(S - m_new) and adds its row sums to s_prime;
(4) accumulates O += P × V.
After the last key tile, O is normalized by s_prime, and the per-row logsumexp m_prime + log(s_prime) is stored for the backward pass. Architecture-specific optimizations include three features. The kernel preloads V tiles into shared memory on Ampere (Sm80). A single-value-iteration mode (kSingleValueIteration) keeps the output accumulator in registers when the value head dimension Dv is small enough. Optional bias broadcasting (Broadcast1_, Broadcast2_) supports position/pair embeddings.
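The tiled loop described above can be sketched in plain host C++. This illustrates the online-softmax technique under simplifying assumptions and is not the CUTLASS implementation. The assumptions are float data, one (batch, head) slice, no bias, one query row at a time, and no warp-level parallelism. The variables m_prime and s_prime follow the description, while tiled_attention itself is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Host-side sketch of the tiled forward pass with online softmax for one
// (batch, head) slice. Assumptions: row-major float Q (N x D), K (M x D),
// V (M x Dv); no bias; one query row at a time instead of kQueriesPerBlock rows.
// Only a kKeysPerBlock-wide strip of scores exists at any moment.
std::vector<float> tiled_attention(const std::vector<float>& Q, const std::vector<float>& K,
                                   const std::vector<float>& V, int N, int M, int D, int Dv,
                                   int kKeysPerBlock, std::vector<float>* lse_out) {
    assert(N > 0 && M > 0 && D > 0 && Dv > 0 && kKeysPerBlock > 0);
    const float scale = 1.0f / std::sqrt(static_cast<float>(D));
    std::vector<float> O(static_cast<std::size_t>(N) * Dv, 0.0f);
    std::vector<float> lse(N);
    std::vector<float> s_tile(kKeysPerBlock);  // scores for the current key block
    for (int i = 0; i < N; ++i) {
        float m_prime = -INFINITY;  // running row maximum
        float s_prime = 0.0f;       // running sum of exp(score - m_prime)
        float* o = &O[static_cast<std::size_t>(i) * Dv];
        for (int k0 = 0; k0 < M; k0 += kKeysPerBlock) {
            const int kn = std::min(kKeysPerBlock, M - k0);
            // (1) scaled scores for this key block, tracking the new row max
            float m_new = m_prime;
            for (int j = 0; j < kn; ++j) {
                float dot = 0.0f;
                for (int d = 0; d < D; ++d) dot += Q[i * D + d] * K[(k0 + j) * D + d];
                s_tile[j] = dot * scale;
                m_new = std::max(m_new, s_tile[j]);
            }
            // (2) rescale the previous sum and partial output to the new max;
            // on the first block m_prime is -inf, so the correction is 0 and
            // the (still zero) sum and output stay zero
            const float correction = std::exp(m_prime - m_new);
            s_prime *= correction;
            for (int e = 0; e < Dv; ++e) o[e] *= correction;
            // (3) P = exp(S - m_new) and (4) O += P V
            for (int j = 0; j < kn; ++j) {
                const float p = std::exp(s_tile[j] - m_new);
                s_prime += p;
                for (int e = 0; e < Dv; ++e) o[e] += p * V[(k0 + j) * Dv + e];
            }
            m_prime = m_new;
        }
        for (int e = 0; e < Dv; ++e) o[e] /= s_prime;  // final softmax normalization
        lse[i] = m_prime + std::log(s_prime);         // saved for the backward pass
    }
    if (lse_out != nullptr) *lse_out = lse;
    return O;
}
```

Because each rescaling is exact up to floating-point rounding, the result should not depend on kKeysPerBlock. This makes it a convenient property to check when experimenting with tile sizes.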

Usage

This kernel is the primary forward pass implementation for Evoformer attention in DeepSpeed4Science, enabling training of AlphaFold2-style models on long sequences by reducing attention memory from O(N²) to O(N).
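The hypothetical C++ helper below puts numbers on the O(N²) versus O(N) difference. It counts the bytes of the float32 score matrix that a naive implementation would hold and compares them with the float logsumexp vector the fused kernel keeps. It assumes 4-byte floats and ignores Q, K, V, the output, and per-block shared memory.

```cpp
#include <cstdint>

// Back-of-the-envelope memory comparison for one forward pass, assuming a
// float32 score matrix would otherwise be materialized. B, H, N, M follow the
// tensor shapes in the I/O contract; Q, K, V and the output are not counted.
struct AttnMemory {
    std::uint64_t score_matrix_bytes;  // full N x M scores for every (batch, head)
    std::uint64_t logsumexp_bytes;     // one float per query row, kept for backward
};

constexpr AttnMemory attention_memory(std::uint64_t B, std::uint64_t H,
                                      std::uint64_t N, std::uint64_t M) {
    return AttnMemory{B * H * N * M * sizeof(float), B * H * N * sizeof(float)};
}
```

For example, attention_memory(1, 8, 2048, 2048) gives 134,217,728 bytes (128 MiB) for the scores versus 65,536 bytes (64 KiB) of logsumexp values.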

Code Reference

Source Location

csrc/deepspeed4science/evoformer_attn/kernel_forward.h (DeepSpeed repository)

Signature

template <
    typename scalar_t_,                  // Element type (half_t, bfloat16_t)
    typename ArchTag,                    // Architecture (Sm70, Sm75, Sm80)
    bool isAligned_,                     // Memory alignment
    int kQueriesPerBlock,                // Queries per threadblock
    int kKeysPerBlock_,                  // Keys per threadblock
    bool kSingleValueIteration_,         // Value dimension fits in RF
    bool kSupportsBias_ = false,         // Support bias terms
    template <typename, typename, typename> class Broadcast1_ = BroadcastNoLoad,
    template <typename, typename, typename> class Broadcast2_ = BroadcastNoLoad>
struct AttentionKernel {
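    using scalar_t = scalar_t_;
    using accum_t = float;
    using output_t = scalar_t;
    using output_accum_t = float;  // type of the intermediate output accumulator
    using lse_scalar_t = float;    // type of the stored logsumexp values

    // One warp per 32x32 sub-tile of the kQueriesPerBlock x kKeysPerBlock tile
    static constexpr int kWarpSize = 32;
    static constexpr int kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (32 * 32);
    static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
    static constexpr int kMinBlocksPerSm = getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;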

    struct Params {
        // Input tensors
        scalar_t* query_ptr;   // [B, H, N, D]
        scalar_t* key_ptr;     // [B, H, M, D]
        scalar_t* value_ptr;   // [B, H, M, Dv]

        // Output tensors
        output_t* output_ptr;              // [B, H, N, Dv]
        output_accum_t* output_accum_ptr;  // Float scratch accumulator, [B, H, N, Dv]; needed when
                                           // !kSingleValueIteration and output_t is not float
        lse_scalar_t* logsumexp_ptr;       // [B, H, N] logsumexp values

        // Scale and dimensions
        accum_t scale;
        int32_t head_dim, head_dim_value;
        int32_t num_queries, num_keys;

        // Strides
        int32_t q_strideM, k_strideM, v_strideM, o_strideM;
        int32_t q_strideH, k_strideH, v_strideH;
        int64_t q_strideB, k_strideB, v_strideB;

        // Bias pointers (optional)
        scalar_t *bias1_ptr, *bias2_ptr;

        CUTLASS_DEVICE bool advance_to_block();
    };

    static CUTLASS_DEVICE void attention_kernel(Params& p);
};
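The thread-count constants follow from the tile sizes. The warp count treats each warp as owning a 32×32 sub-tile of the kQueriesPerBlock × kKeysPerBlock score tile. The compile-time sketch below reproduces that arithmetic. LaunchGeometry and query_blocks are hypothetical names. kWarpsPerSm stands in for getWarpsPerSm<scalar_t, ArchTag>(), and its per-architecture value is passed in as an assumption rather than taken from DeepSpeed.

```cpp
// Compile-time sketch of the launch geometry implied by the tile sizes.
// kWarpsPerSm is an assumed input standing in for getWarpsPerSm<>().
template <int kQueriesPerBlock, int kKeysPerBlock, int kWarpsPerSm>
struct LaunchGeometry {
    static constexpr int kWarpSize = 32;
    // one warp per 32x32 sub-tile of the score tile
    static constexpr int kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (32 * 32);
    static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
    static constexpr int kMinBlocksPerSm = kWarpsPerSm / kNumWarpsPerBlock;
    static_assert(kNumWarpsPerBlock > 0, "tile must cover at least one 32x32 warp tile");
};

// Threadblocks along grid.x: one per tile of queries_per_block queries (ceil division).
constexpr int query_blocks(int num_queries, int queries_per_block) {
    return (num_queries + queries_per_block - 1) / queries_per_block;
}
```

With 64×64 tiles this gives 4 warps (128 threads) per block. If 16 warps per SM are assumed, it also gives at least 4 resident blocks per SM.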

Import

#include "csrc/deepspeed4science/evoformer_attn/kernel_forward.h"

I/O Contract

Name               Shape             Type       Description
Inputs
  query            [B, H, N, D]      scalar_t   Query tensor
  key              [B, H, M, D]      scalar_t   Key tensor
  value            [B, H, M, Dv]     scalar_t   Value tensor
  bias1            [B, N, 1, 1, M]   scalar_t   Optional row-broadcast bias (Broadcast1_)
  bias2            [B, 1, H, M, M]   scalar_t   Optional matrix bias (Broadcast2_)
Outputs
  output           [B, H, N, Dv]     scalar_t   Attention output
  logsumexp        [B, H, N]         float      Logsumexp statistics for backward
Configuration
  scale            -                 float      Scaling factor (typically 1/√D)
  kQueriesPerBlock -                 int        Tile size for queries (32-128)
  kKeysPerBlock    -                 int        Tile size for keys (32-128)

Note: bias1 and bias2 use the 5-D Evoformer convention of DeepSpeed4Science, in which the second dimension of bias1 counts MSA rows rather than query positions.
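The logsumexp output lets a backward pass avoid storing P. With lse_i saved per query row, each probability can be recomputed from the scaled scores alone as P_ij = exp(S_ij - lse_i). The hypothetical host-side helper below sketches that single step. It assumes the scores are already scaled and biased exactly as in the forward pass.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Recompute one row of softmax probabilities from the scaled scores and the
// logsumexp value saved by the forward pass: P_ij = exp(S_ij - lse_i).
// No max/sum pass over the row is needed, so N floats of lse replace N x M of P.
std::vector<float> probs_from_lse(const std::vector<float>& scaled_scores, float lse) {
    assert(!scaled_scores.empty());
    std::vector<float> p(scaled_scores.size());
    for (std::size_t j = 0; j < p.size(); ++j) p[j] = std::exp(scaled_scores[j] - lse);
    return p;
}
```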

Usage Examples

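// Configure forward kernel for Ampere with 64x64 tiles
using ForwardKernel = AttentionKernel<
    cutlass::half_t,           // FP16
    cutlass::arch::Sm80,       // Ampere
    true,                      // Aligned
    64,                        // Queries per block
    64,                        // Keys per block
    false,                     // Multi-iteration for large Dv
    true,                      // Support bias
    BroadcastA,                // Row-wise bias
    BroadcastB                 // Matrix bias
>;

typename ForwardKernel::Params params;
params.query_ptr = query;
params.key_ptr = key;
params.value_ptr = value;
params.output_ptr = output;
// Needed because kSingleValueIteration is false and output_t (half) is not float:
// a float scratch buffer of B * H * N * Dv elements holds the partial outputs
params.output_accum_ptr = output_accum;
params.logsumexp_ptr = logsumexp;
params.bias1_ptr = bias1;
params.bias2_ptr = bias2;
params.scale = 1.0f / sqrtf(static_cast<float>(head_dim));
params.head_dim = head_dim;
params.head_dim_value = head_dim;  // Dv == D in this example
params.num_queries = seq_len;
params.num_keys = seq_len;
// ... set strides

// Launch configuration: one threadblock per (query tile, head, batch).
// The kernel reads threadIdx.x as the lane and threadIdx.y as the warp index,
// so the block is 32 x (warps per block), e.g. 32 x 4 for 64x64 tiles.
dim3 grid(
    (seq_len + 64 - 1) / 64,  // Query blocks
    num_heads,
    batch_size
);
dim3 block(32, ForwardKernel::kNumThreads / 32);

// attention_kernel() is a device function, not a __global__ entry point; launch
// it through the attention_kernel_batched_impl wrapper, which calls
// params.advance_to_block() and then ForwardKernel::attention_kernel(params).
// The kernel keeps its SharedStorage in dynamic shared memory.
constexpr auto kernel_fn = attention_kernel_batched_impl<ForwardKernel>;
int smem_bytes = sizeof(typename ForwardKernel::SharedStorage);
if (smem_bytes > 0xc000) {  // more than 48 KB requires an explicit opt-in
    cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
}
kernel_fn<<<grid, block, smem_bytes, stream>>>(params);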
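The example leaves the stride fields elided. For Q, K and V stored contiguously in the [B, H, seq, dim] layout of the I/O contract, the hypothetical helper below derives the strides. It assumes strides are counted in elements and that the innermost (head) dimension is contiguous. Other layouts, or the output tensor, may need different values.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of DeepSpeed): element strides for a tensor
// stored contiguously as [B, H, seq, dim], matching the Params stride fields.
struct Strides {
    std::int32_t strideM;  // between consecutive sequence positions
    std::int32_t strideH;  // between heads
    std::int64_t strideB;  // between batch entries (64-bit, as in Params)
};

Strides contiguous_bhsd_strides(std::int32_t num_heads, std::int32_t seq_len, std::int32_t dim) {
    assert(num_heads > 0 && seq_len > 0 && dim > 0);
    Strides s;
    s.strideM = dim;
    s.strideH = seq_len * dim;
    s.strideB = static_cast<std::int64_t>(num_heads) * seq_len * dim;
    return s;
}
```

With such a helper, the q result would be copied into params.q_strideM, params.q_strideH and params.q_strideB, and likewise for K and V.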
