Implementation:FMInference FlexLLMGen DeepSpeed GEMM Test

From Leeroopedia


Knowledge Sources
Domains: CUDA Programming, Linear Algebra, Performance Testing
Last Updated: 2026-02-09 12:00 GMT

Overview

C++ template classes for benchmarking cuBLAS GEMM (General Matrix Multiply) operations. They select an algorithm automatically by timing every tensor-op cuBLAS algorithm variant and keeping the fastest.

Description

This header defines two template classes: GemmTest for standard matrix multiplications and StridedGemmTest for strided batched matrix multiplications. Both classes provide a TestAlgo method that times each cuBLAS tensor-op GEMM algorithm variant in the range CUBLAS_GEMM_DEFAULT_TENSOR_OP through CUBLAS_GEMM_ALGO15_TENSOR_OP and reports the fastest one for a given set of matrix dimensions.

Each algorithm is tested through a three-phase process: a warm-up phase (5 iterations to fill caches and stabilize clocks), a timed phase (user-specified loop count with cudaDeviceSynchronize barriers for accurate timing), and result reporting (average latency in milliseconds per iteration). The testing covers three GEMM orientations that correspond to the forward pass, backward pass with respect to weights (bw1), and backward pass with respect to activations (bw2) in a neural network linear layer.
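The three-phase process described above can be sketched on the host as follows. This is a minimal illustration, not the header's actual code: pick_fastest and its sync parameter are hypothetical names, sync stands in for cudaDeviceSynchronize, and std::chrono replaces whatever timer the header uses.

```cpp
#include <cassert>
#include <chrono>
#include <cstdio>
#include <limits>

// Hypothetical helper: times f(algo) for every algo in [first_algo, last_algo]
// and returns the id with the lowest average latency.
// `sync` stands in for cudaDeviceSynchronize(); pass a no-op for host-only work.
template <typename Func, typename Sync>
int pick_fastest(int first_algo, int last_algo, int loops, Func f, Sync sync)
{
    assert(loops > 0 && first_algo <= last_algo);
    const int warm_up = 5;  // warm-up count taken from the description
    float best_ms = std::numeric_limits<float>::max();
    int best_algo = first_algo;

    for (int algo = first_algo; algo <= last_algo; ++algo) {
        // Phase 1: warm-up iterations, not timed.
        for (int i = 0; i < warm_up; ++i) f(algo);
        sync();  // drain pending work before starting the clock

        // Phase 2: timed loop.
        const auto start = std::chrono::steady_clock::now();
        for (int i = 0; i < loops; ++i) f(algo);
        sync();  // wait for the timed work to finish
        const auto stop = std::chrono::steady_clock::now();

        // Phase 3: report the average latency per iteration in milliseconds.
        const float avg_ms =
            std::chrono::duration<float, std::milli>(stop - start).count() / loops;
        std::printf("algo-%d: %.3fms\n", algo, avg_ms);

        if (avg_ms < best_ms) {
            best_ms = avg_ms;
            best_algo = algo;
        }
    }
    return best_algo;
}
```

In a GPU build, f would launch one GEMM with the given algorithm id and sync would call cudaDeviceSynchronize, so that asynchronous kernel launches are not mistaken for completed work.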

The implementation supports both NVIDIA CUDA (using cublasGemmAlgo_t) and AMD HIP (using rocblas_gemm_algo) via conditional compilation. GPU memory for matrices A, B, and C is allocated in the constructor and freed in the destructor via RAII.
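The constructor/destructor ownership of A, B, and C can be illustrated with a small RAII sketch. Everything here is hypothetical and host-only so that it builds without a GPU toolchain: HostAllocator, ScopedBuffer, and GemmBuffers are illustrative names, and malloc/free stand in for the device allocation calls the header would make.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>

// Hypothetical allocation policy so the sketch builds without a GPU toolchain.
// A CUDA build would call cudaMalloc/cudaFree here, a HIP build hipMalloc/hipFree.
struct HostAllocator {
    static int& live() { static int n = 0; return n; }  // outstanding allocations
    static void* allocate(std::size_t bytes) { ++live(); return std::malloc(bytes); }
    static void release(void* p) { --live(); std::free(p); }
};

// Owns one buffer of `count` elements; released when the object goes out of scope.
template <typename T, typename Allocator = HostAllocator>
class ScopedBuffer {
public:
    explicit ScopedBuffer(std::size_t count)
        : count_(count), ptr_(static_cast<T*>(Allocator::allocate(count * sizeof(T))))
    {
        assert(ptr_ != nullptr);
    }
    ~ScopedBuffer() { Allocator::release(ptr_); }

    // Non-copyable: exactly one owner per allocation.
    ScopedBuffer(const ScopedBuffer&) = delete;
    ScopedBuffer& operator=(const ScopedBuffer&) = delete;

    T* get() const { return ptr_; }
    std::size_t size() const { return count_; }

private:
    std::size_t count_;
    T* ptr_;
};

// Buffer sizes follow the m/n/k roles in the I/O contract:
// A is m x k, B is k x n, C is m x n.
template <typename T>
struct GemmBuffers {
    GemmBuffers(int m, int n, int k)
        : A(static_cast<std::size_t>(m) * k),
          B(static_cast<std::size_t>(k) * n),
          C(static_cast<std::size_t>(m) * n)
    {
    }
    ScopedBuffer<T> A, B, C;
};
```

The classes on this page hold raw T* members and free them in their own destructors; a per-buffer wrapper like this is one alternative way to get the same guarantee.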

Usage

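Use these classes to find the fastest measured cuBLAS algorithm for the specific matrix dimensions encountered in transformer model inference. The ranking comes from timing, so it applies only to the GPU, cuBLAS version, data type, and shapes that were benchmarked. Under those conditions, the selected algorithm identifiers can be hard-coded into production GEMM calls for maximum throughput.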
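One way to reuse benchmark results without re-running them is a small lookup table keyed by shape. The sketch below is an assumption about how a caller might organize this, not part of the header: AlgoCache and its Tuner callback are hypothetical, and a real key would probably also need the data type, transpose flags, and device.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <tuple>
#include <utility>

// Hypothetical cache: runs the tuner once per (m, n, k) shape and reuses the result.
class AlgoCache {
public:
    using Shape = std::tuple<int, int, int>;            // (m, n, k)
    using Algos = std::array<int, 3>;                   // same layout TestAlgo returns
    using Tuner = std::function<Algos(int, int, int)>;  // benchmarks one shape

    explicit AlgoCache(Tuner tuner) : tuner_(std::move(tuner)) {}

    // Returns the cached choice, benchmarking the shape on first use only.
    const Algos& get(int m, int n, int k)
    {
        const Shape key(m, n, k);
        auto it = table_.find(key);
        if (it == table_.end()) {
            it = table_.emplace(key, tuner_(m, n, k)).first;
        }
        return it->second;
    }

    std::size_t size() const { return table_.size(); }

private:
    Tuner tuner_;
    std::map<Shape, Algos> table_;
};
```

In a GPU build, the tuner could construct a GemmTest for the requested shape and return the result of TestAlgo; the cached identifiers would then be passed to the production GEMM calls.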

Code Reference

Source Location

Signature

template <typename T>
class GemmTest {
public:
    GemmTest(int m, int n, int k,
             cublasOperation_t ta, cublasOperation_t tb,
             cublasHandle_t h);
    ~GemmTest();
    std::array<int, 3> TestAlgo(int loops);

    template <typename Func>
    int Run(int loops, Func f);

private:
    int M, N, K;
    cublasHandle_t handle;
    cublasOperation_t transa, transb;
    T *A, *B, *C;
};

template <typename T>
class StridedGemmTest {
public:
    StridedGemmTest(int b, int m, int n, int k,
                    cublasOperation_t ta, cublasOperation_t tb,
                    cublasHandle_t h);
    ~StridedGemmTest();
    std::array<int, 3> TestAlgo(int loops);

    template <typename Func>
    int Run(int loops, Func f);

private:
    int bsz, M, N, K;
    cublasHandle_t handle;
    cublasOperation_t transa, transb;
    T *A, *B, *C;
};

Import

#include "gemm_test.h"

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| m | int | Yes | Number of rows in matrix C (and A if not transposed) |
| n | int | Yes | Number of columns in matrix C (and B if not transposed) |
| k | int | Yes | Shared/inner dimension of the matrix multiplication |
| b | int | Yes (StridedGemmTest only) | Batch size for strided batched GEMM |
| ta | cublasOperation_t | Yes | Transpose operation for matrix A |
| tb | cublasOperation_t | Yes | Transpose operation for matrix B |
| h | cublasHandle_t | Yes | Active cuBLAS handle |
| loops | int | Yes | Number of timed iterations per algorithm |

Outputs

| Name | Type | Description |
|------|------|-------------|
| return value | std::array<int, 3> | Best algorithm indices for [forward, backward_weights, backward_activations] GEMM operations |
| stdout | text | Per-algorithm timing output (e.g., "algo-99: 0.432ms") and best algorithm summary |
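Because the return value is a positional array, a caller may find it clearer to give the three slots names. The helper below is a hypothetical convenience, assuming only the [forward, backward_weights, backward_activations] ordering stated in the Outputs table; GemmAlgoChoice and to_choice are not part of the header.

```cpp
#include <array>
#include <cassert>
#include <cstdio>

// Hypothetical named view of the array returned by TestAlgo.
struct GemmAlgoChoice {
    int forward;               // slot 0: forward-pass GEMM
    int backward_weights;      // slot 1: gradient with respect to weights (bw1)
    int backward_activations;  // slot 2: gradient with respect to activations (bw2)
};

// Copies the three positional entries into named fields.
inline GemmAlgoChoice to_choice(const std::array<int, 3>& algos)
{
    return GemmAlgoChoice{algos[0], algos[1], algos[2]};
}

// Prints the choice in one line for logs.
inline void print_choice(const GemmAlgoChoice& c)
{
    std::printf("forward=%d bw_weights=%d bw_activations=%d\n",
                c.forward, c.backward_weights, c.backward_activations);
}
```

A call such as to_choice(test.TestAlgo(100)) would then let downstream code refer to choice.forward rather than to an index.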

Usage Examples

#include <cstdio>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include "gemm_test.h"

int main() {
    // Find the fastest GEMM algorithms for a transformer layer with:
    // hidden_size=768, batch_tokens=512
    cublasHandle_t handle;
    if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "cublasCreate failed\n");
        return 1;
    }

    // Test standard GEMM for attention projection (m=768, n=768, k=512)
    GemmTest<__half> test(768, 768, 512, CUBLAS_OP_N, CUBLAS_OP_N, handle);
    auto best_algos = test.TestAlgo(/*loops=*/100);
    printf("Forward: algo %d, BW1: algo %d, BW2: algo %d\n",
           best_algos[0], best_algos[1], best_algos[2]);

    // Test strided batched GEMM for multi-head attention
    // batch=12 heads, M=512 tokens, N=512 tokens, K=64 head_dim
    StridedGemmTest<__half> strided_test(12, 512, 512, 64,
                                          CUBLAS_OP_T, CUBLAS_OP_N, handle);
    auto batched_algos = strided_test.TestAlgo(/*loops=*/100);
    printf("Batched forward: algo %d, BW1: algo %d, BW2: algo %d\n",
           batched_algos[0], batched_algos[1], batched_algos[2]);

    cublasDestroy(handle);
    return 0;
}
