Implementation:FMInference FlexLLMGen DeepSpeed GEMM Test
| Knowledge Sources | |
|---|---|
| Domains | CUDA Programming, Linear Algebra, Performance Testing |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
C++ template classes for benchmarking cuBLAS GEMM (General Matrix Multiply) operations. They select an algorithm automatically by timing every tensor-op algorithm variant from CUBLAS_GEMM_DEFAULT_TENSOR_OP through CUBLAS_GEMM_ALGO15_TENSOR_OP and keeping the fastest.
Description
This header defines two template classes: GemmTest for standard matrix multiplications and StridedGemmTest for strided batched matrix multiplications. Both classes provide a TestAlgo method that times every cuBLAS GEMM algorithm variant from CUBLAS_GEMM_DEFAULT_TENSOR_OP through CUBLAS_GEMM_ALGO15_TENSOR_OP (enum values 99 through 115) and returns the fastest one for the given matrix dimensions.
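To make the swept range concrete, the following is a minimal sketch of a helper that maps an algorithm index, as printed in the timing output, back to its cuBLAS enum name. The helper name `GemmAlgoName` and the constants are hypothetical and not part of the header. The numeric endpoints are the values published in `cublas_api.h`, repeated as plain constants so the sketch compiles without the CUDA toolkit. It does not apply to the HIP build, which uses `rocblas_gemm_algo` values instead.

```cpp
#include <string>

// Enum endpoints as defined in cublas_api.h, copied as plain integers.
constexpr int kDefaultTensorOp = 99;   // CUBLAS_GEMM_DEFAULT_TENSOR_OP
constexpr int kAlgo15TensorOp = 115;   // CUBLAS_GEMM_ALGO15_TENSOR_OP

// Number of candidates visited by an inclusive sweep over the range.
constexpr int kNumCandidates = kAlgo15TensorOp - kDefaultTensorOp + 1;

// Hypothetical helper: translate a swept index into its enum name.
std::string GemmAlgoName(int algo)
{
    if (algo == kDefaultTensorOp) return "CUBLAS_GEMM_DEFAULT_TENSOR_OP";
    if (algo > kDefaultTensorOp && algo <= kAlgo15TensorOp) {
        // Index 100 is ALGO0, 101 is ALGO1, and so on up to ALGO15.
        return "CUBLAS_GEMM_ALGO" + std::to_string(algo - kDefaultTensorOp - 1) +
               "_TENSOR_OP";
    }
    return "outside the swept range";
}
```

With this mapping, an output line beginning with `algo-99` refers to the default tensor-op heuristic, and the sweep covers 17 candidates in total.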
Each algorithm is tested through a three-phase process: a warm-up phase (5 iterations to fill caches and stabilize clocks), a timed phase (user-specified loop count with cudaDeviceSynchronize barriers for accurate timing), and result reporting (average latency in milliseconds per iteration). The testing covers three GEMM orientations that correspond to the forward pass, backward pass with respect to weights (bw1), and backward pass with respect to activations (bw2) in a neural network linear layer.
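The three-phase process can be sketched on the host as follows. This is an illustrative stand-in rather than the header's code: `PickFastest` and its `sync` parameter are hypothetical names, `std::chrono` replaces whatever timer the header uses, and `sync` plays the role of `cudaDeviceSynchronize` (on the host it can be a no-op). The sketch assumes `loops` is positive.

```cpp
#include <chrono>
#include <cstdio>
#include <limits>

// Times f(algo) for every candidate in [first_algo, last_algo] and returns
// the index with the lowest average latency.
template <typename Func, typename Sync>
int PickFastest(int first_algo, int last_algo, int loops, Func f, Sync sync)
{
    using Clock = std::chrono::steady_clock;
    constexpr int kWarmUp = 5;  // warm-up iterations, excluded from timing
    float best_ms = std::numeric_limits<float>::max();
    int best_algo = first_algo;

    for (int algo = first_algo; algo <= last_algo; ++algo) {
        // Phase 1: warm-up.
        for (int i = 0; i < kWarmUp; ++i) f(algo);

        // Phase 2: timed loop, bracketed by synchronization barriers so that
        // asynchronously launched work is fully counted.
        sync();
        const auto start = Clock::now();
        for (int i = 0; i < loops; ++i) f(algo);
        sync();
        const auto stop = Clock::now();

        // Phase 3: report the average latency per iteration in milliseconds.
        const float avg_ms =
            std::chrono::duration<float, std::milli>(stop - start).count() / loops;
        std::printf("algo-%d: %.3fms\n", algo, avg_ms);

        if (avg_ms < best_ms) {
            best_ms = avg_ms;
            best_algo = algo;
        }
    }
    return best_algo;
}
```

In a GPU build, `f` would launch one GEMM with the given algorithm and `sync` would wrap the device synchronization call. Without the barriers, the timer would mostly measure kernel launch overhead rather than execution time.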
The implementation supports both NVIDIA CUDA (using cublasGemmAlgo_t) and AMD HIP (using rocblas_gemm_algo) via conditional compilation. GPU memory for matrices A, B, and C is allocated in the constructor and freed in the destructor via RAII.
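The RAII ownership described above can be illustrated with a small host-only sketch. The names `HostAllocator`, `ScratchBuffer`, and `GemmBuffers` are hypothetical. The allocator uses `malloc`/`free` so the example compiles without a GPU toolchain, whereas a CUDA build would forward to `cudaMalloc`/`cudaFree`. The buffer sizes (M×K for A, K×N for B, M×N for C) are the usual GEMM operand sizes and are assumed here, not quoted from the header.

```cpp
#include <cstddef>
#include <cstdlib>
#include <new>

// Host stand-in for a device allocator (assumption for illustration only).
struct HostAllocator {
    static void* Allocate(std::size_t bytes) { return std::malloc(bytes); }
    static void Free(void* p) { std::free(p); }
};

// Owns one buffer of `count` elements: acquired in the constructor,
// released in the destructor, and never copied.
template <typename T, typename Allocator = HostAllocator>
class ScratchBuffer {
public:
    explicit ScratchBuffer(std::size_t count)
        : count_(count),
          data_(static_cast<T*>(Allocator::Allocate(sizeof(T) * count)))
    {
        if (data_ == nullptr && count_ != 0) throw std::bad_alloc();
    }
    ~ScratchBuffer() { Allocator::Free(data_); }

    ScratchBuffer(const ScratchBuffer&) = delete;
    ScratchBuffer& operator=(const ScratchBuffer&) = delete;

    T* get() const { return data_; }
    std::size_t size() const { return count_; }

private:
    std::size_t count_;
    T* data_;
};

// The three GEMM operands, sized from the problem dimensions.
template <typename T, typename Allocator = HostAllocator>
struct GemmBuffers {
    GemmBuffers(int m, int n, int k)
        : A(static_cast<std::size_t>(m) * k),
          B(static_cast<std::size_t>(k) * n),
          C(static_cast<std::size_t>(m) * n)
    {
    }
    ScratchBuffer<T, Allocator> A, B, C;
};
```

Deleting the copy operations prevents two objects from freeing the same pointer, which is the main hazard when a class stores raw pointers to buffers it owns.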
Usage
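Use these classes to find the fastest cuBLAS algorithm for the specific matrix dimensions encountered in transformer model inference. The selected algorithm indices can then be hard-coded into production kernels for maximum throughput. Because the ranking depends on the GPU, data type, and matrix dimensions, rerun the test whenever any of these change.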
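One possible way to carry the selected indices into production code is a small lookup table keyed by matrix dimensions. The sketch below is an assumption about how such a table could look. `AlgoTable`, `GemmShape`, and `AlgoTriple` are hypothetical names, and nothing like this is defined in the header.

```cpp
#include <array>
#include <map>
#include <tuple>

using GemmShape = std::tuple<int, int, int>;  // (m, n, k)
using AlgoTriple = std::array<int, 3>;        // [forward, bw1, bw2]

// Maps problem dimensions to previously benchmarked algorithm indices and
// falls back to a default index for shapes that were never benchmarked.
class AlgoTable {
public:
    explicit AlgoTable(int fallback_algo)
        : fallback_{{fallback_algo, fallback_algo, fallback_algo}}
    {
    }

    void Record(int m, int n, int k, const AlgoTriple& best)
    {
        table_[GemmShape{m, n, k}] = best;
    }

    AlgoTriple Lookup(int m, int n, int k) const
    {
        const auto it = table_.find(GemmShape{m, n, k});
        return it == table_.end() ? fallback_ : it->second;
    }

private:
    AlgoTriple fallback_;
    std::map<GemmShape, AlgoTriple> table_;
};
```

A caller would record the array returned by TestAlgo once per shape, then query the table at GEMM launch time. A sensible fallback is the default tensor-op index (99), so that unseen shapes still use the library heuristic.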
Code Reference
Source Location
- Repository: FMInference_FlexLLMGen
- File: benchmark/third_party/DeepSpeed/csrc/includes/gemm_test.h
- Lines: 1-327
Signature
template <typename T>
class GemmTest {
public:
    GemmTest(int m, int n, int k,
             cublasOperation_t ta, cublasOperation_t tb,
             cublasHandle_t h);
    ~GemmTest();

    std::array<int, 3> TestAlgo(int loops);

    template <typename Func>
    int Run(int loops, Func f);

private:
    int M, N, K;
    cublasHandle_t handle;
    cublasOperation_t transa, transb;
    T *A, *B, *C;
};

template <typename T>
class StridedGemmTest {
public:
    StridedGemmTest(int b, int m, int n, int k,
                    cublasOperation_t ta, cublasOperation_t tb,
                    cublasHandle_t h);
    ~StridedGemmTest();

    std::array<int, 3> TestAlgo(int loops);

    template <typename Func>
    int Run(int loops, Func f);

private:
    int bsz, M, N, K;
    cublasHandle_t handle;
    cublasOperation_t transa, transb;
    T *A, *B, *C;
};
Import
#include "gemm_test.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| m | int | Yes | Number of rows in matrix C (and A if not transposed) |
| n | int | Yes | Number of columns in matrix C (and B if not transposed) |
| k | int | Yes | Shared/inner dimension of the matrix multiplication |
| b | int | Yes (StridedGemmTest only) | Batch size for strided batched GEMM |
| ta | cublasOperation_t | Yes | Transpose operation for matrix A |
| tb | cublasOperation_t | Yes | Transpose operation for matrix B |
| h | cublasHandle_t | Yes | Active cuBLAS handle |
| loops | int | Yes | Number of timed iterations per algorithm, passed to TestAlgo (excludes the 5 warm-up iterations) |
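For the strided batched case, the inputs above determine how far apart consecutive matrices sit in memory. The following sketch computes those strides under the assumption that each batch entry is densely packed directly after the previous one. `StridedLayout` and `MakeStridedLayout` are hypothetical names, and the packing assumption is not taken from the header.

```cpp
#include <cstddef>

// Element strides between consecutive batch entries and total element
// counts for the three operands of a strided batched GEMM.
struct StridedLayout {
    std::size_t stride_a, stride_b, stride_c;
    std::size_t total_a, total_b, total_c;
};

StridedLayout MakeStridedLayout(int batch, int m, int n, int k)
{
    const std::size_t b = static_cast<std::size_t>(batch);
    StridedLayout layout{};
    layout.stride_a = static_cast<std::size_t>(m) * k;  // one A matrix
    layout.stride_b = static_cast<std::size_t>(k) * n;  // one B matrix
    layout.stride_c = static_cast<std::size_t>(m) * n;  // one C matrix
    layout.total_a = b * layout.stride_a;
    layout.total_b = b * layout.stride_b;
    layout.total_c = b * layout.stride_c;
    return layout;
}
```

For 12 heads with m = n = 512 and k = 64, each A and B matrix holds 32,768 elements and each C matrix holds 262,144 elements. The C buffer therefore dominates the memory footprint of that configuration.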
Outputs
| Name | Type | Description |
|---|---|---|
| return value | std::array<int, 3> | Best algorithm indices for [forward, backward_weights, backward_activations] GEMM operations |
| stdout | text | Per-algorithm timing output (e.g., "algo-99: 0.432ms") and best algorithm summary |
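The three entries of the returned array correspond to the three GEMMs of a linear layer. The sketch below lists the result shape and reduction dimension of each one, using standard linear-layer algebra. The convention is an assumption made for illustration and is not stated on this page: activations X are M×K, weights W are N×K, and the layer computes Y = X·Wᵀ. `GemmDims`, `LinearLayerGemms`, and `GemmFlops` are hypothetical names.

```cpp
#include <array>
#include <cstdint>

struct GemmDims {
    int rows;    // rows of the result matrix
    int cols;    // columns of the result matrix
    int reduce;  // dimension that is summed over
};

// Returns the dimensions of [forward, bw1, bw2] for a linear layer with
// M tokens, N output features, and K input features.
std::array<GemmDims, 3> LinearLayerGemms(int m, int n, int k)
{
    return {{
        {m, n, k},  // forward: Y  = X * W^T   (M x N, sum over K)
        {n, k, m},  // bw1:     dW = dY^T * X  (N x K, sum over M)
        {m, k, n},  // bw2:     dX = dY * W    (M x K, sum over N)
    }};
}

// Each output element needs `reduce` multiply-add pairs.
std::int64_t GemmFlops(const GemmDims& d)
{
    return 2 * static_cast<std::int64_t>(d.rows) * d.cols * d.reduce;
}
```

All three GEMMs perform the same number of floating-point operations (2·M·N·K), but their shapes and transpose patterns differ. This is why the fastest algorithm is measured separately for each orientation.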
Usage Examples
#include <cstdio>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include "gemm_test.h"

int main()
{
    // Find the fastest GEMM algorithms for a transformer layer with:
    // hidden_size=768, batch_tokens=512
    cublasHandle_t handle;
    if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "cublasCreate failed\n");
        return 1;
    }

    // Test standard GEMM for attention projection
    GemmTest<__half> test(/*m=*/768, /*n=*/768, /*k=*/512, CUBLAS_OP_N, CUBLAS_OP_N, handle);
    auto best_algos = test.TestAlgo(/*loops=*/100);
    printf("Forward: algo %d, BW1: algo %d, BW2: algo %d\n",
           best_algos[0], best_algos[1], best_algos[2]);

    // Test strided batched GEMM for multi-head attention
    // batch=12 heads, M=512 tokens, N=512 tokens, K=64 head_dim
    StridedGemmTest<__half> strided_test(/*b=*/12, /*m=*/512, /*n=*/512, /*k=*/64,
                                         CUBLAS_OP_T, CUBLAS_OP_N, handle);
    auto batched_algos = strided_test.TestAlgo(/*loops=*/100);
    printf("Batched forward: algo %d, BW1: algo %d, BW2: algo %d\n",
           batched_algos[0], batched_algos[1], batched_algos[2]);

    cublasDestroy(handle);
    return 0;
}