Implementation:Deepspeedai DeepSpeed Evoformer GEMM Utils
| Knowledge Sources | |
|---|---|
| Domains | Attention, CUTLASS_Kernels, DeepSpeed4Science |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Utility functions, macros, and type traits for configuring CUTLASS GEMM kernels across different GPU architectures and data types in the Evoformer attention implementation.
Description
gemm_kernel_utils.h provides the architecture-aware building blocks used to configure CUTLASS GEMMs in Evoformer attention.
DefaultGemmType is a template metafunction keyed on the architecture tag (Volta/Turing/Ampere) and the scalar type (FP32/FP16/BF16). The primary template falls back to SIMT (CUDA-core) math with ThreadK = WarpK = 8 and no vector-alignment requirement. Partial specializations switch to TensorOp (tensor-core) math and set the instruction shape, the ThreadK/WarpK tile depth along K, and the minimum operand alignment (in elements) for their architecture/type combination. CheckArch reports at compile time whether a given architecture can run kernels for a given scalar type.
The helpers ceil_div and align_up handle tile-count and padding arithmetic, and warp_uniform broadcasts a value from lane 0 to the whole warp so the compiler can treat it as uniform across the warp. Finally, the preprocessor macros DISPATCH_ARCHTAG, DISPATCH_TYPES, and DISPATCH_BOOL turn runtime values (compute capability, tensor dtype, a boolean flag) into compile-time types or constants, so each combination is compiled as its own kernel specialization.
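The selection mechanism behind DefaultGemmType and CheckArch can be sketched in host-only C++: a primary template supplies the SIMT defaults, and an `enable_if`-guarded partial specialization takes over once the architecture tag is new enough. `GemmTypeSketch`, `CheckArchSketch`, the `Sm*Tag` structs, `OpClassSimt`/`OpClassTensorOp` and `bf16_stub` are hypothetical stand-ins for the CUTLASS types, and the compatibility rule used here (BF16 only on SM80+) is a simplified assumption; the real traits carry more members and more specializations.

```cpp
#include <type_traits>

// Hypothetical stand-ins for cutlass::arch::Sm70/Sm75/Sm80. The selector only
// reads the compile-time kMinComputeCapability member.
struct Sm70Tag { static constexpr int kMinComputeCapability = 70; };
struct Sm75Tag { static constexpr int kMinComputeCapability = 75; };
struct Sm80Tag { static constexpr int kMinComputeCapability = 80; };

// Hypothetical stand-ins for the operation classes and a 16-bit BF16 type.
struct OpClassSimt {};
struct OpClassTensorOp {};
struct bf16_stub { unsigned short bits; };

// Primary template: SIMT (CUDA-core) fallback used when no specialization matches.
template <typename ArchTag, typename Scalar, typename Enable = void>
struct GemmTypeSketch {
    static constexpr int ThreadK = 8;
    static constexpr int kMinimumAlignment = 1;
    using OpClass = OpClassSimt;
};

// Partial specialization that only matches float on SM80 or newer. For older
// tags std::enable_if<false> has no ::type, so matching this specialization
// fails (SFINAE) and the primary template is used instead.
template <typename ArchTag>
struct GemmTypeSketch<ArchTag, float,
                      std::enable_if_t<(ArchTag::kMinComputeCapability >= 80)>> {
    static constexpr int ThreadK = 32;
    static constexpr int kMinimumAlignment = 4;
    using OpClass = OpClassTensorOp;
};

// CheckArch-style gate. Assumption for this sketch: BF16 requires SM80+,
// every other scalar type is accepted on any tag.
template <typename ArchTag, typename Scalar>
struct CheckArchSketch {
    static constexpr bool value =
        !std::is_same_v<Scalar, bf16_stub> || ArchTag::kMinComputeCapability >= 80;
};

// Compile-time checks of the selection logic.
static_assert(std::is_same_v<GemmTypeSketch<Sm75Tag, float>::OpClass, OpClassSimt>);
static_assert(GemmTypeSketch<Sm80Tag, float>::ThreadK == 32);
static_assert(!CheckArchSketch<Sm75Tag, bf16_stub>::value);
static_assert(CheckArchSketch<Sm80Tag, bf16_stub>::value);
```

Keeping the `Enable = void` parameter on the primary template is what lets a specialization be switched on or off by a compile-time predicate without listing every architecture explicitly.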
Usage
This header is included by all Evoformer attention kernel implementations to configure architecture-specific parameters, perform type dispatching, and access utility functions for dimension calculations and alignment operations.
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h
Signature
```cpp
namespace gemm_kernel_utils {
// Default GEMM configuration selector (SIMT fallback)
template <typename ArchTag, typename scalar_t_, typename Enable = void>
struct DefaultGemmType {
    static constexpr int ThreadK = 8;
    static constexpr int WarpK = 8;
    static constexpr int kMinimumAlignment = 1;
    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
    using OpClass = cutlass::arch::OpClassSimt;
    using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specialization for Ampere FP32 tensor cores
template <typename ArchTag>
struct DefaultGemmType<
    ArchTag,
    float,
    typename cutlass::platform::enable_if<ArchTag::kMinComputeCapability >= 80>::type> {
    static constexpr int ThreadK = 32;
    static constexpr int WarpK = 32;
    static constexpr int kMinimumAlignment = 4;
    using OpClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
    using Operator = cutlass::arch::OpMultiplyAddFastF32;
};
// Further specializations (omitted here) select tensor-core settings for
// 16-bit types on Sm75+ and for half_t on Sm70.
// Architecture compatibility checker
template <typename arch, typename scalar_t>
struct CheckArch {
    static constexpr bool value = /* true if arch can run kernels for scalar_t */;
};
// Utility functions
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m);
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m);
CUTLASS_DEVICE int32_t warp_uniform(int32_t value);
template <typename T>
CUTLASS_DEVICE T* warp_uniform(T* ptr);
} // namespace gemm_kernel_utils
```
Import
```cpp
#include "csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h"
```
I/O Contract
| Component | Input | Output | Description |
|---|---|---|---|
| DefaultGemmType | ArchTag, scalar_t | Type traits (OpClass, InstructionShape, Operator, ThreadK, WarpK, kMinimumAlignment) | Selects the default GEMM configuration for an architecture/type pair |
| ceil_div | n, m (integers) | (n + m - 1) / m | Ceiling division |
| align_up | n, m (integers) | ceil_div(n, m) * m | Round n up to the nearest multiple of m |
| warp_uniform | value/ptr | Broadcast value | Mark as warp-uniform (broadcast from lane 0) |
| DISPATCH_ARCHTAG | CC (compute capability) | ArchTag typedef | Dispatch to architecture-specific code path |
| DISPATCH_TYPES | tensor (at::Tensor) | scalar_t typedef | Dispatch to FP16/BF16 code path |
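The two arithmetic helpers in the table can be sketched as plain `constexpr` templates, assuming non-negative `n` and positive `m`. `CUTLASS_HOST_DEVICE` is dropped so the snippet builds as ordinary host C++, and the tile and head sizes in the checks are illustrative values, not Evoformer defaults.

```cpp
// Sketch of ceil_div / align_up without the CUTLASS_HOST_DEVICE qualifier.
template <typename integer>
constexpr integer ceil_div(integer n, integer m) {
    // Number of m-sized chunks needed to cover n (assumes n >= 0, m > 0).
    return (n + m - 1) / m;
}

template <typename integer>
constexpr integer align_up(integer n, integer m) {
    // Smallest multiple of m that is >= n.
    return ceil_div(n, m) * m;
}

// Worked values: 100 rows in 64-row tiles need 2 tiles; a head dimension
// of 30 padded to a multiple of 8 becomes 32; aligned values are unchanged.
static_assert(ceil_div(100, 64) == 2, "two 64-row tiles cover 100 rows");
static_assert(align_up(30, 8) == 32, "30 rounds up to 32");
static_assert(align_up(32, 8) == 32, "already-aligned values are unchanged");
```

Because `n` and `m` share the single template parameter `integer` (as in the signature above), a call such as `ceil_div(int64_t{n}, 64)` fails template deduction; spelling the type explicitly, e.g. `ceil_div<int64_t>(n, 64)`, resolves it.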
Usage Examples
```cpp
// Select GEMM configuration based on architecture.
// Sm80 + half_t is served by the 16-bit tensor-core specialization.
using GemmConfig = gemm_kernel_utils::DefaultGemmType<cutlass::arch::Sm80, cutlass::half_t>;
static_assert(GemmConfig::ThreadK == 32, "tensor-core path uses ThreadK = 32");
static_assert(GemmConfig::kMinimumAlignment == 4, "4-element (64-bit for half_t) alignment");
using InstructionShape = typename GemmConfig::InstructionShape;  // 16x8x8
using OpClass = typename GemmConfig::OpClass;                    // TensorOp
// Use utility functions (seq_len, kTileSize, head_dim are placeholders;
// both arguments must have the same integer type for template deduction)
int num_tiles = gemm_kernel_utils::ceil_div(seq_len, kTileSize);
int aligned_dim = gemm_kernel_utils::align_up(head_dim, 8);
// Inside device code: mark a loop-invariant pointer as warp-uniform
// (the value is broadcast from lane 0)
__shared__ float smem_buffer[1024];
float* uniform_ptr = gemm_kernel_utils::warp_uniform(smem_buffer);
// Dispatch to architecture-specific kernel
DISPATCH_ARCHTAG(compute_capability, {
    // ArchTag is now defined (Sm70, Sm75, or Sm80)
    launch_kernel<ArchTag>(args);
});
// Dispatch based on tensor data type
DISPATCH_TYPES(query_tensor, {
    // scalar_t is now cutlass::half_t or cutlass::bfloat16_t
    run_attention<scalar_t>(args);
});
```
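The dispatch pattern used by these macros, converting a runtime value into a compile-time alias, can be sketched in plain C++. `SKETCH_DISPATCH_ARCHTAG`, the `Sm*Tag` structs, `launch_kernel_sketch` and `selected_arch` are hypothetical names; the thresholds follow the Sm80/Sm75/Sm70 tags named in the usage comment, and throwing `std::runtime_error` for older devices is an assumption about how unsupported GPUs are rejected.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical stand-ins for cutlass::arch::Sm70/Sm75/Sm80.
struct Sm70Tag { static constexpr int kMinComputeCapability = 70; };
struct Sm75Tag { static constexpr int kMinComputeCapability = 75; };
struct Sm80Tag { static constexpr int kMinComputeCapability = 80; };

// Each branch introduces a local alias named ArchTag and then pastes the
// caller's block, so that block is compiled once per tag. The fallback
// branch (throwing) is an assumption of this sketch.
#define SKETCH_DISPATCH_ARCHTAG(CC, func)                              \
    {                                                                  \
        if ((CC) >= 80) {                                              \
            using ArchTag = Sm80Tag;                                   \
            func;                                                      \
        } else if ((CC) >= 75) {                                       \
            using ArchTag = Sm75Tag;                                   \
            func;                                                      \
        } else if ((CC) >= 70) {                                       \
            using ArchTag = Sm70Tag;                                   \
            func;                                                      \
        } else {                                                       \
            throw std::runtime_error("compute capability below 7.0");  \
        }                                                              \
    }

// Hypothetical launcher: instead of launching, it reports the tag it received.
template <typename Tag>
int launch_kernel_sketch() { return Tag::kMinComputeCapability; }

// Maps a runtime compute capability (e.g. 86) to the tag the dispatcher picked.
int selected_arch(int cc) {
    int chosen = 0;
    SKETCH_DISPATCH_ARCHTAG(cc, { chosen = launch_kernel_sketch<ArchTag>(); });
    assert(chosen >= 70);  // a tensor-core tag was selected
    return chosen;
}
```

With this sketch, a compute capability of 86 selects `Sm80Tag`, 75 selects `Sm75Tag`, and 61 throws. The preprocessor does not treat braces as grouping, so an expression with a top-level comma inside the block, such as `f<A, B>()`, must be wrapped in parentheses before being passed to a macro like this.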