

Implementation:FMInference FlexLLMGen DeepSpeed Inference CuBLAS Wrappers

From Leeroopedia


Knowledge Sources
Domains CUDA, cuBLAS, Linear Algebra, Deep Learning Inference
Last Updated 2026-02-09 12:00 GMT

Overview

A header-only wrapper library providing simplified C++ interfaces to cuBLAS GEMM and strided batched GEMM operations for both FP32 and FP16 data types, with cross-platform support for NVIDIA CUDA and AMD ROCm.

Description

This file provides four overloaded wrapper functions that abstract the verbose cuBLAS/rocBLAS extended GEMM APIs into a consistent, simplified interface:

  • cublas_gemm_ex (float version): Wraps cublasGemmEx for single-precision general matrix multiplication (C = alpha * op(A) * op(B) + beta * C) using CUDA_R_32F compute type.
  • cublas_gemm_ex (__half version): Wraps cublasGemmEx for half-precision GEMM using CUDA_R_16F data type with CUDA_R_32F compute type for numerical stability (mixed-precision accumulation).
  • cublas_strided_batched_gemm (float version): Wraps cublasGemmStridedBatchedEx for batched GEMM on strided tensors in single precision.
  • cublas_strided_batched_gemm (__half version): Wraps cublasGemmStridedBatchedEx for batched half-precision GEMM with FP32 accumulation.

Each function computes the leading dimensions automatically from the transpose flags, assuming densely packed column-major matrices: lda is m for CUBLAS_OP_N and k for CUBLAS_OP_T, ldb is k for CUBLAS_OP_N and n for CUBLAS_OP_T, and ldc is always m. On failure, each function prints an error message to stderr and returns EXIT_FAILURE. Dispatch between NVIDIA cuBLAS and AMD rocBLAS is selected at compile time with #ifdef __HIP_PLATFORM_HCC__ preprocessor guards.

The functions accept a GEMM algorithm selector parameter (cublasGemmAlgo_t on CUDA, rocblas_gemm_algo on ROCm) enabling callers to specify CUBLAS_GEMM_DEFAULT_TENSOR_OP for Tensor Core acceleration or other algorithm variants for tuning.

Usage

These wrappers are called from the inference PyTorch binding layer (pt_binding.cpp) and other DeepSpeed CUDA C++ code wherever matrix multiplication is needed, including QKV projections, attention score computation, feed-forward layers, and output projections.

Code Reference

Source Location

Signature

// Single GEMM: C = alpha * op(A) * op(B) + beta * C
int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa, cublasOperation_t transb,
                   int m, int n, int k,
                   const float* alpha, const float* beta,
                   const float* A, const float* B, float* C,
                   cublasGemmAlgo_t algo);

int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa, cublasOperation_t transb,
                   int m, int n, int k,
                   const float* alpha, const float* beta,
                   const __half* A, const __half* B, __half* C,
                   cublasGemmAlgo_t algo);

// Strided batched GEMM
int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m, int n, int k,
                                const float* alpha, const float* beta,
                                const float* A, const float* B, float* C,
                                cublasOperation_t op_A, cublasOperation_t op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, cublasGemmAlgo_t algo);

int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m, int n, int k,
                                const float* alpha, const float* beta,
                                const __half* A, const __half* B, __half* C,
                                cublasOperation_t op_A, cublasOperation_t op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, cublasGemmAlgo_t algo);

Import

#include "inference_cublas_wrappers.h"

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| handle | cublasHandle_t | Yes | cuBLAS library handle; must already be bound to the intended CUDA stream (e.g., via cublasSetStream). |
| transa / op_A | cublasOperation_t | Yes | Transpose operation for matrix A: CUBLAS_OP_N (no transpose) or CUBLAS_OP_T (transpose). |
| transb / op_B | cublasOperation_t | Yes | Transpose operation for matrix B (same values as for A). |
| m, n, k | int | Yes | GEMM dimensions in cuBLAS column-major terms: op(A) is m x k, op(B) is k x n, and C is m x n. |
| alpha, beta | const float* | Yes | Scalar multipliers, float for both the FP32 and FP16 overloads; host pointers under cuBLAS's default pointer mode. |
| A, B | const T* | Yes | Input matrix pointers in device memory (T is float or __half). |
| C | T* | Yes | Output matrix pointer in device memory. |
| algo | cublasGemmAlgo_t | Yes | GEMM algorithm selector (e.g., CUBLAS_GEMM_DEFAULT_TENSOR_OP). |
| stride_A, stride_B, stride_C | int | Batched only | Offsets, in elements, between consecutive matrices in the batch. |
| batch | int | Batched only | Number of matrices in the batch. |

Outputs

| Name | Type | Description |
|---|---|---|
| return value | int | 0 on success, EXIT_FAILURE on error. |
| C | T* | Output matrix, updated in place: C = alpha * op(A) * op(B) + beta * C. |

Usage Examples

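// FP16 GEMM for the QKV projection.
// Assumed layouts: row-major input [batch * seq_len, hidden_size] and row-major
// weight [hidden_size, 3 * hidden_size], which cuBLAS reads as column-major.
float alpha = 1.0f;
float beta = 0.0f;
cublas_gemm_ex(handle,
               CUBLAS_OP_N, CUBLAS_OP_N,
               3 * hidden_size,   // m: output features per token (Q, K and V)
               batch * seq_len,   // n: number of tokens
               hidden_size,       // k: input features (reduction dimension)
               &alpha, &beta,
               (__half*)qkv_weight,
               (__half*)input,
               (__half*)qkv_output,
               CUBLAS_GEMM_DEFAULT_TENSOR_OP);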
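The argument order in the QKV call relies on a standard identity: a row-major matrix read as column-major is its transpose, so the column-major product with A = weight, B = input, m = 3 * hidden_size, n = batch * seq_len and k = hidden_size lands in memory as the row-major result input x weight. The CPU sketch below checks that identity on small sizes. It assumes the row-major layouts noted in the example's comments (inferred from its dimensions, not confirmed by the source), and `colmajor_gemm_nn` and `rowmajor_matmul` are hypothetical helpers.

```cpp
#include <vector>

// Column-major GEMM without transposes: C(m x n) = A(m x k) * B(k x n),
// densely packed (lda = m, ldb = k, ldc = m).
void colmajor_gemm_nn(int m, int n, int k, const float* A, const float* B, float* C) {
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            float acc = 0.0f;
            for (int p = 0; p < k; ++p) acc += A[i + p * m] * B[p + j * k];
            C[i + j * m] = acc;
        }
    }
}

// Row-major reference: Y[t][o] = sum_h X[t][h] * W[h][o].
std::vector<float> rowmajor_matmul(const std::vector<float>& X, const std::vector<float>& W,
                                   int tokens, int hidden, int out) {
    std::vector<float> Y(tokens * out, 0.0f);
    for (int t = 0; t < tokens; ++t)
        for (int o = 0; o < out; ++o)
            for (int h = 0; h < hidden; ++h)
                Y[t * out + o] += X[t * hidden + h] * W[h * out + o];
    return Y;
}
```

Calling `colmajor_gemm_nn(out, tokens, hidden, W, X, C)` mirrors the wrapper call (weight first, input second) and produces the same buffer as `rowmajor_matmul(X, W, tokens, hidden, out)`, so no explicit transposes are needed on either side.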

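// Batched GEMM for attention scores: scale * Q * K^T for every (batch, head) pair.
// Assumed layouts: row-major Q and K [batch * num_heads, seq_len, head_dim],
// row-major scores [batch * num_heads, seq_len, seq_len].
float scale = 1.0f / sqrtf((float)head_dim);
cublas_strided_batched_gemm(handle,
                            seq_len, seq_len, head_dim,
                            &scale, &beta,
                            (__half*)key, (__half*)query, (__half*)attn_scores,
                            CUBLAS_OP_T, CUBLAS_OP_N,
                            seq_len * head_dim,   // stride_A: one key matrix
                            seq_len * head_dim,   // stride_B: one query matrix
                            seq_len * seq_len,    // stride_C: one score matrix
                            batch * num_heads,
                            CUBLAS_GEMM_DEFAULT_TENSOR_OP);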
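The attention-score call applies the same reinterpretation plus CUBLAS_OP_T on the key. For one (batch, head) slice with row-major Q and K of shape [seq_len, head_dim], the column-major product op(A) * B with A = key, op_A = T and B = query is a seq_len x seq_len matrix whose row-major reading is scale * Q * K^T. The CPU sketch below checks this for a single slice; the per-head [seq_len, head_dim] row-major layout is an assumption consistent with the strides in the example, and `colmajor_gemm_tn` and `scores_reference` are hypothetical helpers.

```cpp
#include <vector>

// Column-major C(m x n) = alpha * A^T * B, with A stored k x m (lda = k) and
// B stored k x n (ldb = k): the CUBLAS_OP_T / CUBLAS_OP_N case with beta = 0.
void colmajor_gemm_tn(int m, int n, int k, float alpha,
                      const float* A, const float* B, float* C) {
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            float acc = 0.0f;
            for (int p = 0; p < k; ++p) acc += A[p + i * k] * B[p + j * k];
            C[i + j * m] = alpha * acc;
        }
    }
}

// Row-major reference for one head: S[q][t] = scale * dot(Q[q], K[t]).
std::vector<float> scores_reference(const std::vector<float>& Q, const std::vector<float>& K,
                                    int seq_len, int head_dim, float scale) {
    std::vector<float> S(seq_len * seq_len, 0.0f);
    for (int q = 0; q < seq_len; ++q) {
        for (int t = 0; t < seq_len; ++t) {
            float acc = 0.0f;
            for (int d = 0; d < head_dim; ++d) acc += Q[q * head_dim + d] * K[t * head_dim + d];
            S[q * seq_len + t] = scale * acc;
        }
    }
    return S;
}
```

Passing the key as A and the query as B, in that order, is what makes row q of the output hold the scores of query q against every key.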
