

Implementation:Deepspeedai DeepSpeed Evoformer GEMM Utils

From Leeroopedia


Knowledge Sources
Domains: Attention, CUTLASS_Kernels, DeepSpeed4Science
Last Updated: 2026-02-09 00:00 GMT

Overview

Utility functions, macros, and type traits for configuring CUTLASS GEMM kernels across different GPU architectures and data types in the Evoformer attention implementation.

Description

gemm_kernel_utils.h provides the architecture-aware building blocks for configuring GEMMs in Evoformer attention.

DefaultGemmType is a template metafunction that takes an architecture tag (Volta, Turing or Ampere) and a data type (FP32, FP16 or BF16) and selects the instruction shape, the ThreadK/WarpK tile depths, the operation class and the minimum alignment. The operation class decides between SIMT (CUDA cores) and TensorOp (tensor cores); the primary template falls back to SIMT with ThreadK = WarpK = 8 and an alignment of one element. CheckArch reports at compile time whether an architecture supports a given data type.

The helper functions are ceil_div and align_up for dimension arithmetic, and warp_uniform, which broadcasts a value from lane 0 so the compiler can treat it as uniform across the warp. The preprocessor macros DISPATCH_ARCHTAG, DISPATCH_TYPES and DISPATCH_BOOL turn runtime values (compute capability, tensor dtype, a boolean flag) into compile-time kernel specializations.
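The selection mechanism can be illustrated without CUDA or CUTLASS. This is a minimal sketch, assuming hypothetical stand-in tags (TraitSm70, TraitSm80) that mimic the kMinComputeCapability constant carried by CUTLASS architecture tags, and a simplified hypothetical trait (GemmTraitsSketch) that records only the SIMT/TensorOp choice, ThreadK, WarpK and alignment. It mirrors the two cases shown in the signature, not the full header.

```cpp
#include <type_traits>

// Hypothetical stand-ins for CUTLASS architecture tags.
struct TraitSm70 { static constexpr int kMinComputeCapability = 70; };
struct TraitSm80 { static constexpr int kMinComputeCapability = 80; };

// Primary template: SIMT fallback (CUDA cores), narrow K, element alignment of 1.
template <typename ArchTag, typename scalar_t, typename Enable = void>
struct GemmTraitsSketch {
    static constexpr bool kUsesTensorCores = false;
    static constexpr int ThreadK = 8;
    static constexpr int WarpK = 8;
    static constexpr int kMinimumAlignment = 1;  // counted in elements
};

// Partial specialization that only matches when the enable_if condition holds:
// FP32 inputs on SM80+ take the tensor-core path.
template <typename ArchTag>
struct GemmTraitsSketch<ArchTag, float,
                        std::enable_if_t<(ArchTag::kMinComputeCapability >= 80)>> {
    static constexpr bool kUsesTensorCores = true;
    static constexpr int ThreadK = 32;
    static constexpr int WarpK = 32;
    static constexpr int kMinimumAlignment = 4;  // counted in elements
};

// Alignment in bytes = alignment in elements * element size.
template <typename Traits, typename scalar_t>
constexpr int alignment_bytes() {
    return Traits::kMinimumAlignment * static_cast<int>(sizeof(scalar_t));
}

// Which specialization is picked is decided entirely at compile time.
static_assert(!GemmTraitsSketch<TraitSm70, float>::kUsesTensorCores, "SM70 + FP32 falls back to SIMT");
static_assert(GemmTraitsSketch<TraitSm80, float>::kUsesTensorCores, "SM80 + FP32 uses tensor cores");
static_assert(alignment_bytes<GemmTraitsSketch<TraitSm80, float>, float>() == 16, "4 floats = 16 bytes");
```

Because kMinimumAlignment counts elements, its size in bytes depends on the element type: 4 FP32 elements are 16 bytes (128 bits), while 4 FP16 elements are 8 bytes (64 bits).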

Usage

This header is included by all Evoformer attention kernel implementations to configure architecture-specific parameters, perform type dispatching, and access utility functions for dimension calculations and alignment operations.

Code Reference

Source Location

Signature

namespace gemm_kernel_utils {

// Default GEMM configuration selector
template <typename ArchTag, typename scalar_t_, typename Enable = void>
struct DefaultGemmType {
    static constexpr int ThreadK = 8;
    static constexpr int WarpK = 8;
    static constexpr int kMinimumAlignment = 1;
    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
    using OpClass = cutlass::arch::OpClassSimt;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specialization for Ampere FP32 tensor cores
template <typename ArchTag>
struct DefaultGemmType<ArchTag, float, /* enable_if SM80+ */> {
    static constexpr int ThreadK = 32;
    static constexpr int WarpK = 32;
    static constexpr int kMinimumAlignment = 4;
    using OpClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
    using Operator = cutlass::arch::OpMultiplyAddFastF32;
};

// Further specializations (e.g. the FP16/BF16 tensor-core paths) are not shown.

// Architecture compatibility checker
template <typename arch, typename scalar_t>
struct CheckArch {
    static constexpr bool value = /* architecture supports scalar_t */;
};

// Utility functions
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m);

template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m);

CUTLASS_DEVICE int32_t warp_uniform(int32_t value);
template <typename T>
CUTLASS_DEVICE T* warp_uniform(T* ptr);

}  // namespace gemm_kernel_utils
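The two integer helpers are only declared above. A host-only sketch, assuming ceil_div uses the (n + m - 1) / m formula listed in the I/O contract and that align_up is ceil_div(n, m) * m (equivalent to "round up to the next multiple of m", though the header's exact text may differ). It drops CUTLASS_HOST_DEVICE so that any C++ compiler accepts it.

```cpp
#include <cstdint>

// Smallest q with q * m >= n, for non-negative n and positive m.
// Note: n + m - 1 can overflow when n is close to the type's maximum.
template <typename integer>
constexpr integer ceil_div(integer n, integer m) {
    return (n + m - 1) / m;
}

// Round n up to the next multiple of m.
template <typename integer>
constexpr integer align_up(integer n, integer m) {
    return ceil_div(n, m) * m;
}

// Tiling a 100-token sequence with 64-wide tiles needs 2 tiles;
// padding a head dimension of 100 to a multiple of 8 gives 104.
static_assert(ceil_div<int32_t>(100, 64) == 2, "two tiles");
static_assert(align_up<int32_t>(100, 8) == 104, "padded to 104");
static_assert(align_up<int32_t>(128, 8) == 128, "already aligned");
```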

Import

#include "csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h"

I/O Contract

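| Component | Input | Output | Description |
|---|---|---|---|
| DefaultGemmType | ArchTag, scalar_t | Type traits (ThreadK, WarpK, kMinimumAlignment, InstructionShape, OpClass, Operator) | Selects the GEMM configuration for an architecture/type pair |
| CheckArch | arch, scalar_t | `value` (bool) | Whether the architecture supports the scalar type |
| ceil_div | n, m (integers) | (n + m - 1) / m | Ceiling division |
| align_up | n, m (integers) | Smallest multiple of m that is >= n | Round n up to the next multiple of m |
| warp_uniform | value or pointer | Lane 0's value | Broadcast from lane 0 so the value is warp-uniform |
| DISPATCH_ARCHTAG | CC (compute capability) | ArchTag type alias | Dispatch to an architecture-specific code path |
| DISPATCH_TYPES | tensor (at::Tensor) | scalar_t type alias | Dispatch to the FP16 or BF16 code path |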
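The description lists DISPATCH_BOOL without showing its use. The sketch below shows one common way such a macro can be written, under stated assumptions: DISPATCH_BOOL_SKETCH, run_sketch and dispatch_example are hypothetical names, and the argument order (runtime value, constant name, callable) is an assumption, not DeepSpeed's exact signature. The idea is that the callable is pasted textually into each branch after a constexpr bool of the chosen name is declared, so it can use that name as a template argument.

```cpp
// Hypothetical macro, not DeepSpeed's definition. F is expanded inside both
// branches, after a constexpr bool named BOOL_NAME is declared, so F can use
// BOOL_NAME as a template argument. Both instantiations compile; one runs.
#define DISPATCH_BOOL_SKETCH(BOOL_V, BOOL_NAME, F) \
    do {                                           \
        if (BOOL_V) {                              \
            constexpr bool BOOL_NAME = true;       \
            F();                                   \
        } else {                                   \
            constexpr bool BOOL_NAME = false;      \
            F();                                   \
        }                                          \
    } while (0)

// Hypothetical "kernel" specialized on a compile-time flag.
template <bool kHasBias>
int run_sketch(int x) {
    if constexpr (kHasBias) {
        return x + 1;  // taken only in the kHasBias == true instantiation
    } else {
        return x;      // taken only in the kHasBias == false instantiation
    }
}

// Runtime flag in, compile-time specialization out.
int dispatch_example(bool has_bias, int x) {
    int result = 0;
    // The lambda is wrapped in parentheses so commas inside it cannot split macro arguments.
    DISPATCH_BOOL_SKETCH(has_bias, kHasBias, ([&]() {
        result = run_sketch<kHasBias>(x);
    }));
    return result;
}
```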

Usage Examples

// Select GEMM configuration based on architecture
using GemmConfig = gemm_kernel_utils::DefaultGemmType<cutlass::arch::Sm80, cutlass::half_t>;

static_assert(GemmConfig::ThreadK == 32, "Ampere uses 32-wide K");
static_assert(GemmConfig::kMinimumAlignment == 4, "4 elements = 64 bits for 16-bit half_t");

using InstructionShape = typename GemmConfig::InstructionShape;  // 16x8x8
using OpClass = typename GemmConfig::OpClass;                    // TensorOp

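// Use utility functions (seq_len, kTileSize and head_dim are assumed to be in scope)
int num_tiles = gemm_kernel_utils::ceil_div(seq_len, kTileSize);  // e.g. ceil_div(100, 64) == 2
int aligned_dim = gemm_kernel_utils::align_up(head_dim, 8);       // e.g. align_up(100, 8) == 104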

// Inside device code: mark a loop-invariant pointer as warp-uniform
// (warp_uniform broadcasts lane 0's value to the whole warp)
__shared__ float smem_buffer[1024];
float* uniform_ptr = gemm_kernel_utils::warp_uniform(smem_buffer);

// Dispatch to architecture-specific kernel
DISPATCH_ARCHTAG(compute_capability, {
    // ArchTag is now defined (Sm70, Sm75, or Sm80)
    launch_kernel<ArchTag>(args);
});

// Dispatch based on tensor data type
DISPATCH_TYPES(query_tensor, {
    // scalar_t is now cutlass::half_t or cutlass::bfloat16_t
    run_attention<scalar_t>(args);
});
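To see how a runtime compute capability becomes a compile-time ArchTag, here is a host-only sketch. It assumes thresholds of 80, 75 and 70 (matching the three tags named in the dispatch example above) and uses hypothetical stand-ins MockSm70, MockSm75 and MockSm80 for the CUTLASS tags. The macro name DISPATCH_ARCHTAG_SKETCH, the helper functions and the exception for older GPUs are illustrative assumptions, not DeepSpeed's code.

```cpp
#include <stdexcept>

// Hypothetical stand-ins for cutlass::arch::Sm70 / Sm75 / Sm80.
struct MockSm70 { static constexpr int kMinComputeCapability = 70; };
struct MockSm75 { static constexpr int kMinComputeCapability = 75; };
struct MockSm80 { static constexpr int kMinComputeCapability = 80; };

// Hypothetical macro: picks the newest tag whose minimum compute capability
// does not exceed CC, binds it to ArchTag, then runs BODY in that scope.
#define DISPATCH_ARCHTAG_SKETCH(CC, BODY)                              \
    do {                                                               \
        if ((CC) >= 80) {                                              \
            using ArchTag = MockSm80;                                  \
            BODY;                                                      \
        } else if ((CC) >= 75) {                                       \
            using ArchTag = MockSm75;                                  \
            BODY;                                                      \
        } else if ((CC) >= 70) {                                       \
            using ArchTag = MockSm70;                                  \
            BODY;                                                      \
        } else {                                                       \
            throw std::runtime_error("compute capability below 7.0"); \
        }                                                              \
    } while (0)

// Hypothetical "kernel launcher" that only reports which tag it received.
template <typename ArchTag>
int launch_kernel_sketch() {
    return ArchTag::kMinComputeCapability;
}

int selected_arch(int compute_capability) {
    int chosen = 0;
    DISPATCH_ARCHTAG_SKETCH(compute_capability, {
        chosen = launch_kernel_sketch<ArchTag>();
    });
    return chosen;
}
```

Because BODY is pasted into every branch, it is compiled once per architecture tag. A braced block passed as a macro argument must not contain top-level commas, since braces do not group macro arguments; only parentheses do.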
