
Implementation:Deepspeedai DeepSpeed Evoformer Kernel Backward

From Leeroopedia


Knowledge Sources
Domains Attention, CUTLASS_Kernels, DeepSpeed4Science
Last Updated 2026-02-09 00:00 GMT

Overview

AttentionBackwardKernel is the CUDA kernel implementing the backward-pass gradient computation for memory-efficient Evoformer attention. It computes dQ, dK, dV, and, when bias terms are present, the optional bias gradients.

Description

AttentionBackwardKernel implements the full backward pass for fused attention as a series of coordinated GEMM operations with custom epilogues. Rather than storing the full N×M attention matrix, the kernel recomputes attention scores from the saved Q/K tensors and the per-row logsumexp, reducing the extra memory needed for backward from quadratic (O(N×M)) to linear (O(N)) in sequence length. It processes gradients in tiled blocks, computing: (1) dV = P^T × dO, where P is the recomputed attention-weight matrix; (2) dP = dO × V^T; (3) dS = softmax_backward(dP, P); (4) dQ = dS × K; and (5) dK = dS^T × Q. The implementation uses GmemTile for efficient register-to-memory transfers without epilogue overhead, stages V and dO in shared memory to enable reuse, and optionally accumulates bias gradients with atomics. Architecture-specific specializations select the appropriate MMA operators (Volta vs. Ampere) and configure block sizes against the available shared-memory and register-file budgets.
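The five gradient steps above can be sketched as a plain CPU reference, the kind of known-good baseline a fused kernel is typically validated against. Everything below (the function names, the single-head row-major layout, the use of plain float) is illustrative and not taken from the DeepSpeed source; only the math follows the dV/dP/dS/dQ/dK sequence described here.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Reference (CPU) version of the backward steps for one attention head.
// Row-major shapes: Q[N][D], K[M][D], V[M][Dv], dO[N][Dv].
struct Grads { std::vector<float> dQ, dK, dV; };

static void softmax_rows(std::vector<float>& s, int n, int m) {
    for (int i = 0; i < n; ++i) {
        float mx = s[i*m];
        for (int j = 1; j < m; ++j) mx = std::max(mx, s[i*m+j]);
        float sum = 0.f;
        for (int j = 0; j < m; ++j) { s[i*m+j] = std::exp(s[i*m+j] - mx); sum += s[i*m+j]; }
        for (int j = 0; j < m; ++j) s[i*m+j] /= sum;
    }
}

Grads attention_backward_ref(const std::vector<float>& Q, const std::vector<float>& K,
                             const std::vector<float>& V, const std::vector<float>& dO,
                             int N, int M, int D, int Dv, float scale) {
    // Recompute P = softmax(scale * Q K^T), as the kernel does from saved Q/K
    std::vector<float> P(N*M, 0.f);
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < M; ++j)
            for (int d = 0; d < D; ++d) P[i*M+j] += scale * Q[i*D+d] * K[j*D+d];
    softmax_rows(P, N, M);

    Grads g{std::vector<float>(N*D, 0.f), std::vector<float>(M*D, 0.f),
            std::vector<float>(M*Dv, 0.f)};
    // (1) dV = P^T x dO
    for (int j = 0; j < M; ++j)
        for (int v = 0; v < Dv; ++v)
            for (int i = 0; i < N; ++i) g.dV[j*Dv+v] += P[i*M+j] * dO[i*Dv+v];
    // (2) dP = dO x V^T
    std::vector<float> dP(N*M, 0.f);
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < M; ++j)
            for (int v = 0; v < Dv; ++v) dP[i*M+j] += dO[i*Dv+v] * V[j*Dv+v];
    // (3) dS = softmax_backward(dP, P) = P * (dP - rowsum(dP * P))
    std::vector<float> dS(N*M);
    for (int i = 0; i < N; ++i) {
        float c = 0.f;
        for (int j = 0; j < M; ++j) c += dP[i*M+j] * P[i*M+j];
        for (int j = 0; j < M; ++j) dS[i*M+j] = P[i*M+j] * (dP[i*M+j] - c);
    }
    // (4) dQ = scale * dS x K    (5) dK = scale * dS^T x Q
    for (int i = 0; i < N; ++i)
        for (int d = 0; d < D; ++d)
            for (int j = 0; j < M; ++j) {
                g.dQ[i*D+d] += scale * dS[i*M+j] * K[j*D+d];
                g.dK[j*D+d] += scale * dS[i*M+j] * Q[i*D+d];
            }
    return g;
}
```

A useful sanity check: with scale = 0 the softmax is uniform, so dQ and dK vanish and dV reduces to the per-column mean of dO.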

Usage

This kernel is launched during the backward pass of Evoformer attention layers in DeepSpeed4Science, computing all necessary gradients for training AlphaFold2-style models with memory efficiency suitable for long sequence lengths.

Code Reference

Source Location

csrc/deepspeed4science/evoformer_attn/kernel_backward.h

Signature

template <
    typename ArchTag_,              // Architecture (Sm70, Sm75, Sm80)
    typename scalar_t_,             // Element type (half_t, bfloat16_t)
    bool kIsAligned_,               // Memory alignment
    bool kApplyDropout_,            // Apply dropout
    bool kPreload_,                 // Preload next GEMM operands
    int kBlockSizeI_,               // Block size in query dimension
    int kBlockSizeJ_,               // Block size in key dimension
    int kMaxK_,                     // Max head dimension
    template <typename, typename, typename> class Broadcast1_,
    template <typename, typename, typename> class Broadcast2_>
struct AttentionBackwardKernel {
    using scalar_t = scalar_t_;
    using accum_t = float;
    using ArchTag = ArchTag_;
    static constexpr int kBlockSizeI = kBlockSizeI_;
    static constexpr int kBlockSizeJ = kBlockSizeJ_;

    static constexpr int kNumThreads = kBlockSizeI * kBlockSizeJ / 64;
    static constexpr int kMinBlocksPerSm = getWarpsPerSm<scalar_t, ArchTag>() / (kNumThreads / 32);

    struct Params {
        // Input tensors
        scalar_t *query_ptr, *key_ptr, *value_ptr;
        scalar_t *grad_output_ptr;
        accum_t *logsumexp_ptr;

        // Output gradient tensors
        scalar_t *grad_query_ptr, *grad_key_ptr, *grad_value_ptr;
        scalar_t *grad_bias1_ptr, *grad_bias2_ptr;

        // Dimensions and strides
        int32_t head_dim, head_dim_value;
        int32_t num_queries, num_keys;
        int32_t q_strideM, k_strideM, v_strideM;
        // ... additional strides

        accum_t scale;
    };

    static CUTLASS_DEVICE void attention_kernel(Params const& p);
};
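As a quick check on the occupancy arithmetic in this signature, the snippet below evaluates kNumThreads for the 64×64 tile configuration used in the usage example further down, and derives the query-dimension grid size. The ceiling-division mapping of queries to blocks is an assumption of this sketch, not code from DeepSpeed.

```cpp
// Host-side sketch of the tiling arithmetic implied by the template
// parameters above; constants mirror the 64x64 example configuration.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

constexpr int kBlockSizeI = 64;  // queries covered per thread block
constexpr int kBlockSizeJ = 64;  // keys covered per thread block

// Per the signature: one thread per 64 tile elements -> 64 threads (2 warps)
constexpr int kNumThreads = kBlockSizeI * kBlockSizeJ / 64;

// Query-dimension grid size (assumed ceiling division over block size)
constexpr int num_query_blocks(int num_queries) {
    return ceil_div(num_queries, kBlockSizeI);
}
```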

Import

#include "csrc/deepspeed4science/evoformer_attn/kernel_backward.h"

I/O Contract

Tensor Shape Type Description
Inputs
query [B, H, N, D] scalar_t Query tensor from forward pass
key [B, H, M, D] scalar_t Key tensor from forward pass
value [B, H, M, Dv] scalar_t Value tensor from forward pass
grad_output [B, H, N, Dv] scalar_t Gradient w.r.t. output
logsumexp [B, H, N] float Saved logsumexp from forward
Outputs
grad_query [B, H, N, D] scalar_t Gradient w.r.t. queries
grad_key [B, H, M, D] scalar_t Gradient w.r.t. keys
grad_value [B, H, M, Dv] scalar_t Gradient w.r.t. values
grad_bias1/2 varies scalar_t Optional bias gradients (atomic)
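To make the memory-efficiency claim behind this contract concrete, the sketch below compares the storage the kernel avoids (a materialized [B, H, N, M] fp16 attention matrix) against the [B, H, N] fp32 logsumexp it actually reads. The tensor sizes in the test are hypothetical long-sequence examples, not DeepSpeed defaults.

```cpp
#include <cstdint>

// Bytes to materialize the full attention matrix the backward pass would
// otherwise need (scalar_t = fp16, 2 bytes per element)
int64_t attn_matrix_bytes(int64_t B, int64_t H, int64_t N, int64_t M) {
    return B * H * N * M * 2;
}

// Bytes of the fp32 logsumexp the forward pass actually saves
int64_t logsumexp_bytes(int64_t B, int64_t H, int64_t N) {
    return B * H * N * 4;
}
```

For B = 1, H = 4, N = M = 2048 this is 32 MiB of avoided attention weights versus 32 KiB of logsumexp, a 1024× reduction that grows linearly with the key length M.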

Usage Examples

// Configure backward kernel for Ampere with FP16
using BackwardKernel = AttentionBackwardKernel<
    cutlass::arch::Sm80,        // Ampere architecture
    cutlass::half_t,             // FP16
    true,                        // Aligned accesses
    false,                       // No dropout
    true,                        // Preload enabled
    64,                          // kBlockSizeI (queries per block)
    64,                          // kBlockSizeJ (keys per block)
    128,                         // kMaxK (max head dimension)
    BroadcastNoLoad,             // No bias gradients
    BroadcastNoLoad
>;

// Launch kernel: attention_kernel is a CUTLASS_DEVICE function, so it
// cannot be launched directly; wrap it in a __global__ entry point
// (wrapper name here is illustrative)
__global__ void backward_kernel_entry(typename BackwardKernel::Params p) {
    BackwardKernel::attention_kernel(p);
}

dim3 grid(num_query_blocks, num_heads, batch_size);
dim3 block(BackwardKernel::kNumThreads);

typename BackwardKernel::Params params;
// ... initialize params

backward_kernel_entry<<<grid, block, 0, stream>>>(params);
