Implementation:Deepspeedai DeepSpeed Evoformer Kernel Backward
| Knowledge Sources | |
|---|---|
| Domains | Attention, CUTLASS_Kernels, DeepSpeed4Science |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
The CUDA kernel that implements the backward-pass gradient computation for memory-efficient Evoformer attention, producing dQ, dK, dV, and, optionally, bias gradients.
Description
AttentionBackwardKernel implements the full backward pass for fused attention as a series of coordinated GEMM operations with custom epilogues. The kernel recomputes attention scores from the saved Q/K (and per-row logsumexp) rather than storing the full N×M attention matrix, so the extra memory it needs is constant in sequence length. It processes gradients in tiled blocks, computing:
1. dV = P^T × dO, where P is the recomputed attention-weight tile
2. dP = dO × V^T
3. dS = softmax_backward(dP, P)
4. dQ = dS × K
5. dK = dS^T × Q
The implementation uses GmemTile for efficient register-to-memory transfers without epilogue overhead, stages V and dO in shared memory so they can be reused across GEMMs, and optionally accumulates bias gradients with atomic adds. Architecture-specific specializations select the appropriate MMA operators (Volta vs. Ampere) and configure block sizes against the available shared-memory and register-file budgets.
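The five steps above follow the standard attention-backward identities. A minimal host-side C++ sketch of the same math (hypothetical helper names, scalar loops in place of the kernel's tiled GEMMs, no dropout or bias):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical host-side reference for the five-step backward math the
// kernel tiles (scalar loops instead of GEMMs; no dropout, no bias).
// Row-major single-head tensors: Q[N][D], K[M][D], V[M][Dv], dO[N][Dv].
struct AttnGrads {
    std::vector<float> dQ, dK, dV;  // shapes of Q, K, V respectively
};

AttnGrads attention_backward_ref(const std::vector<float>& Q,
                                 const std::vector<float>& K,
                                 const std::vector<float>& V,
                                 const std::vector<float>& dO,
                                 int N, int M, int D, int Dv, float scale) {
    // Recompute P = softmax(scale * Q K^T) row by row, as the kernel does
    // instead of storing the N x M attention matrix.
    std::vector<float> P(N * M);
    for (int i = 0; i < N; ++i) {
        float mx = -1e30f;
        for (int j = 0; j < M; ++j) {
            float s = 0.f;
            for (int d = 0; d < D; ++d) s += Q[i*D+d] * K[j*D+d];
            P[i*M+j] = scale * s;
            if (P[i*M+j] > mx) mx = P[i*M+j];
        }
        float denom = 0.f;
        for (int j = 0; j < M; ++j) { P[i*M+j] = std::exp(P[i*M+j] - mx); denom += P[i*M+j]; }
        for (int j = 0; j < M; ++j) P[i*M+j] /= denom;
    }
    AttnGrads g{std::vector<float>(N*D, 0.f),
                std::vector<float>(M*D, 0.f),
                std::vector<float>(M*Dv, 0.f)};
    // (1) dV = P^T x dO
    for (int j = 0; j < M; ++j)
        for (int v = 0; v < Dv; ++v)
            for (int i = 0; i < N; ++i) g.dV[j*Dv+v] += P[i*M+j] * dO[i*Dv+v];
    // (2) dP = dO x V^T and (3) dS = P o (dP - rowsum(dP o P))
    std::vector<float> dS(N * M);
    for (int i = 0; i < N; ++i) {
        float row = 0.f;
        for (int j = 0; j < M; ++j) {
            float dp = 0.f;
            for (int v = 0; v < Dv; ++v) dp += dO[i*Dv+v] * V[j*Dv+v];
            dS[i*M+j] = dp;
            row += dp * P[i*M+j];
        }
        for (int j = 0; j < M; ++j) dS[i*M+j] = P[i*M+j] * (dS[i*M+j] - row);
    }
    // (4) dQ = dS x K and (5) dK = dS^T x Q, with scale folded in
    // because S = scale * Q K^T.
    for (int i = 0; i < N; ++i)
        for (int d = 0; d < D; ++d)
            for (int j = 0; j < M; ++j) g.dQ[i*D+d] += scale * dS[i*M+j] * K[j*D+d];
    for (int j = 0; j < M; ++j)
        for (int d = 0; d < D; ++d)
            for (int i = 0; i < N; ++i) g.dK[j*D+d] += scale * dS[i*M+j] * Q[i*D+d];
    return g;
}
```

The kernel performs these same contractions tile by tile, which is why only Q, K, V, dO, and logsumexp need to be resident.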
Usage
This kernel is launched during the backward pass of Evoformer attention layers in DeepSpeed4Science, computing all necessary gradients for training AlphaFold2-style models with memory efficiency suitable for long sequence lengths.
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/deepspeed4science/evoformer_attn/kernel_backward.h
Signature
template <
typename ArchTag_, // Architecture (Sm70, Sm75, Sm80)
typename scalar_t_, // Element type (half_t, bfloat16_t)
bool kIsAligned_, // Memory alignment
bool kApplyDropout_, // Apply dropout
bool kPreload_, // Preload next GEMM operands
int kBlockSizeI_, // Block size in query dimension
int kBlockSizeJ_, // Block size in key dimension
int kMaxK_, // Max head dimension
template <typename, typename, typename> class Broadcast1_,
template <typename, typename, typename> class Broadcast2_>
struct AttentionBackwardKernel {
using scalar_t = scalar_t_;
using accum_t = float;
static constexpr int kNumThreads = kBlockSizeI * kBlockSizeJ / 64;
static constexpr int kMinBlocksPerSm = getWarpsPerSm<scalar_t, ArchTag>() / (kNumThreads / 32);
struct Params {
// Input tensors
scalar_t *query_ptr, *key_ptr, *value_ptr;
scalar_t *grad_output_ptr;
accum_t *logsumexp_ptr;
// Output gradient tensors
scalar_t *grad_query_ptr, *grad_key_ptr, *grad_value_ptr;
scalar_t *grad_bias1_ptr, *grad_bias2_ptr;
// Dimensions and strides
int32_t head_dim, head_dim_value;
int32_t num_queries, num_keys;
int32_t q_strideM, k_strideM, v_strideM;
// ... additional strides
accum_t scale;
};
static CUTLASS_DEVICE void attention_kernel(Params const& p);
};
Import
#include "csrc/deepspeed4science/evoformer_attn/kernel_backward.h"
I/O Contract
| Tensor | Shape | Type | Description |
|---|---|---|---|
| Inputs | |||
| query | [B, H, N, D] | scalar_t | Query tensor from forward pass |
| key | [B, H, M, D] | scalar_t | Key tensor from forward pass |
| value | [B, H, M, Dv] | scalar_t | Value tensor from forward pass |
| grad_output | [B, H, N, Dv] | scalar_t | Gradient w.r.t. output |
| logsumexp | [B, H, N] | float | Saved logsumexp from forward |
| Outputs | |||
| grad_query | [B, H, N, D] | scalar_t | Gradient w.r.t. queries |
| grad_key | [B, H, M, D] | scalar_t | Gradient w.r.t. keys |
| grad_value | [B, H, M, Dv] | scalar_t | Gradient w.r.t. values |
| grad_bias1/2 | varies | scalar_t | Optional bias gradients (atomic) |
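The logsumexp input above is what makes recomputation cheap: with lse_i = log Σ_j exp(scale·⟨q_i, k_j⟩) saved from the forward, the backward can rebuild any attention tile in a single pass, already normalized. A minimal C++ sketch (hypothetical helper names):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical sketch of the recomputation trick behind the saved
// logsumexp: P[i][j] = exp(scale * <q_i, k_j> - lse_i) needs no
// per-row max/sum passes in the backward.
float dot(const float* a, const float* b, int D) {
    float s = 0.f;
    for (int d = 0; d < D; ++d) s += a[d] * b[d];
    return s;
}

// Forward-side: numerically stable logsumexp over one query row.
float row_logsumexp(const float* q, const float* K, int M, int D, float scale) {
    float mx = -1e30f;
    std::vector<float> s(M);
    for (int j = 0; j < M; ++j) {
        s[j] = scale * dot(q, K + j * D, D);
        if (s[j] > mx) mx = s[j];
    }
    float acc = 0.f;
    for (int j = 0; j < M; ++j) acc += std::exp(s[j] - mx);
    return mx + std::log(acc);
}

// Backward-side: one pass, already normalized, no second reduction.
void recompute_row(const float* q, const float* K, int M, int D,
                   float scale, float lse, float* p_out) {
    for (int j = 0; j < M; ++j)
        p_out[j] = std::exp(scale * dot(q, K + j * D, D) - lse);
}
```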
Usage Examples
// Configure backward kernel for Ampere with FP16
using BackwardKernel = AttentionBackwardKernel<
cutlass::arch::Sm80, // Ampere architecture
cutlass::half_t, // FP16
true, // Aligned accesses
false, // No dropout
true, // Preload enabled
64, // kBlockSizeI (queries per block)
64, // kBlockSizeJ (keys per block)
128, // kMaxK (max head dimension)
BroadcastNoLoad, // No bias gradients
BroadcastNoLoad
>;
// Launch via a __global__ wrapper: attention_kernel is declared
// CUTLASS_DEVICE above, so it cannot be passed to <<<...>>> directly.
// (backward_launcher is a hypothetical wrapper name:)
// __global__ void backward_launcher(typename BackwardKernel::Params p) {
//     BackwardKernel::attention_kernel(p);
// }
dim3 grid(num_query_blocks, num_heads, batch_size);
dim3 block(BackwardKernel::kNumThreads);
typename BackwardKernel::Params params;
// ... initialize params
backward_launcher<<<grid, block, 0, stream>>>(params);
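Choosing the ArchTag_ template argument is a host-side decision driven by the device's compute capability (e.g. from cudaGetDeviceProperties). A hypothetical mapping sketch, with the Sm75 path included since the signature lists it:

```cpp
#include <cassert>
#include <string>

// Hypothetical host-side helper for the ArchTag dispatch implied by the
// "architecture-specific optimizations": map a compute capability to the
// CUTLASS arch tag to instantiate. Newer Ampere-family devices (SM 8.6,
// 8.9) compile against the Sm80 tag.
const char* pick_arch_tag(int major, int minor) {
    int sm = major * 10 + minor;
    if (sm >= 80) return "cutlass::arch::Sm80";  // Ampere and newer
    if (sm >= 75) return "cutlass::arch::Sm75";  // Turing
    if (sm >= 70) return "cutlass::arch::Sm70";  // Volta
    return nullptr;                              // pre-Volta: no tensor cores
}
```

Real dispatch code would branch to kernels instantiated per tag at compile time; the helper only illustrates the selection rule.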