Implementation:FMInference FlexLLMGen DeepSpeed Inference CuBLAS Wrappers
| Knowledge Sources | |
|---|---|
| Domains | CUDA, cuBLAS, Linear Algebra, Deep Learning Inference |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
A header-only wrapper library providing simplified C++ interfaces to cuBLAS GEMM and strided batched GEMM operations for both FP32 and FP16 data types, with cross-platform support for NVIDIA CUDA and AMD ROCm.
Description
This file provides four wrapper functions: two overloaded names, each with an FP32 and an FP16 variant. They hide the verbose cuBLAS/rocBLAS extended GEMM APIs behind a consistent, simplified interface:
- cublas_gemm_ex (float version): Wraps cublasGemmEx for single-precision general matrix multiplication (C = alpha * op(A) * op(B) + beta * C), with CUDA_R_32F as both the data type and the compute type.
- cublas_gemm_ex (__half version): Wraps cublasGemmEx for half-precision GEMM. A, B, and C use the CUDA_R_16F data type, while the compute type is CUDA_R_32F, so products are accumulated in FP32 to limit rounding error (mixed-precision accumulation).
- cublas_strided_batched_gemm (float version): Wraps cublasGemmStridedBatchedEx for a batch of single-precision GEMMs whose operands are spaced at fixed strides in memory.
- cublas_strided_batched_gemm (__half version): Wraps cublasGemmStridedBatchedEx for batched half-precision GEMM with FP32 accumulation.
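To make the column-major semantics that these wrappers forward to cublasGemmEx concrete, here is a minimal CPU reference. It is a sketch for checking shapes and transposes on small inputs, not the DeepSpeed code. `Op` is a hypothetical stand-in for `cublasOperation_t`, and `ref_gemm` and `at` are illustrative names.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for cublasOperation_t (CUBLAS_OP_N / CUBLAS_OP_T),
// so the sketch compiles without the CUDA toolkit.
enum class Op { N, T };

// Reads element (r, c) of a column-major matrix whose leading dimension is ld.
inline float at(const float* M, int ld, int r, int c) {
    return M[r + static_cast<std::size_t>(c) * ld];
}

// CPU reference for C = alpha * op(A) * op(B) + beta * C in cuBLAS's column-major
// convention: op(A) is m x k, op(B) is k x n, and C is m x n.
void ref_gemm(Op transa, Op transb, int m, int n, int k,
              float alpha, const float* A, int lda,
              const float* B, int ldb,
              float beta, float* C, int ldc) {
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            float acc = 0.0f;  // FP32 accumulator
            for (int p = 0; p < k; ++p) {
                const float a = (transa == Op::N) ? at(A, lda, i, p) : at(A, lda, p, i);
                const float b = (transb == Op::N) ? at(B, ldb, p, j) : at(B, ldb, j, p);
                acc += a * b;
            }
            float& c = C[i + static_cast<std::size_t>(j) * ldc];
            // With beta == 0, C is treated as write-only (cuBLAS does not require valid input then).
            c = (beta == 0.0f) ? alpha * acc : alpha * acc + beta * c;
        }
    }
}
```

Passing the same logical matrix either as stored A with `Op::N` or as stored A^T with `Op::T` (and the matching leading dimension) gives the same product, which is exactly what the transpose flags buy you.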
Each function derives the leading dimensions from the transpose flags, assuming densely packed column-major operands: lda = m for CUBLAS_OP_N and k otherwise, ldb = k for CUBLAS_OP_N and n otherwise, and ldc = m. On failure, each function prints an error message to stderr and returns EXIT_FAILURE. Dispatch between NVIDIA cuBLAS and AMD rocBLAS is selected at compile time with #ifdef __HIP_PLATFORM_HCC__ preprocessor guards.
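The leading-dimension rule can be written as a small helper. This is a sketch assuming densely packed column-major operands; `Op`, `LeadingDims`, and `packed_leading_dims` are hypothetical names, not part of the header.

```cpp
#include <cassert>

// Hypothetical stand-in for cublasOperation_t (CUBLAS_OP_N / CUBLAS_OP_T).
enum class Op { N, T };

struct LeadingDims {
    int lda, ldb, ldc;
};

// Stored A is m x k when not transposed and k x m when transposed; its leading
// dimension is its row count. The same reasoning applies to B (k x n or n x k).
LeadingDims packed_leading_dims(Op transa, Op transb, int m, int n, int k) {
    return LeadingDims{
        transa == Op::N ? m : k,  // lda
        transb == Op::N ? k : n,  // ldb
        m                         // ldc: C is always m x n
    };
}
```

A consequence of fixing the leading dimensions this way is that the wrappers cannot address a padded or sliced sub-matrix (a leading dimension larger than the packed value); callers must pass contiguous operands.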
The functions accept a GEMM algorithm selector parameter (cublasGemmAlgo_t on CUDA, rocblas_gemm_algo on ROCm) enabling callers to specify CUBLAS_GEMM_DEFAULT_TENSOR_OP for Tensor Core acceleration or other algorithm variants for tuning.
Usage
These wrappers are called from the inference PyTorch binding layer (pt_binding.cpp) and other DeepSpeed CUDA C++ code wherever matrix multiplication is needed, including QKV projections, attention score computation, feed-forward layers, and output projections.
Code Reference
Source Location
- Repository: FMInference_FlexLLMGen
- File: benchmark/third_party/DeepSpeed/csrc/transformer/inference/includes/inference_cublas_wrappers.h
- Lines: 1-417
Signature
```cpp
// Single GEMM: C = alpha * op(A) * op(B) + beta * C
// (CUDA signatures shown; ROCm builds take rocblas_gemm_algo for algo.)
int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa, cublasOperation_t transb,
                   int m, int n, int k,
                   const float* alpha, const float* beta,
                   const float* A, const float* B, float* C,
                   cublasGemmAlgo_t algo);

// FP16 overload: __half matrices, FP32 scalars and FP32 accumulation
int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa, cublasOperation_t transb,
                   int m, int n, int k,
                   const float* alpha, const float* beta,
                   const __half* A, const __half* B, __half* C,
                   cublasGemmAlgo_t algo);

// Strided batched GEMM: C_i = alpha * op(A_i) * op(B_i) + beta * C_i for i in [0, batch)
int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m, int n, int k,
                                const float* alpha, const float* beta,
                                const float* A, const float* B, float* C,
                                cublasOperation_t op_A, cublasOperation_t op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, cublasGemmAlgo_t algo);

int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m, int n, int k,
                                const float* alpha, const float* beta,
                                const __half* A, const __half* B, __half* C,
                                cublasOperation_t op_A, cublasOperation_t op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, cublasGemmAlgo_t algo);
```
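The batched signature follows cublasGemmStridedBatchedEx: matrix b of each operand starts stride elements after matrix b - 1. Below is a CPU sketch of those semantics. The `Op` enum and function name are hypothetical, the scalars are passed by value for brevity (the wrappers take pointers), and the strides are `long long` rather than the wrapper's `int`.

```cpp
#include <cassert>
#include <cstddef>

enum class Op { N, T };  // hypothetical stand-in for cublasOperation_t

// For b in [0, batch): C_b = alpha * op(A_b) * op(B_b) + beta * C_b, where
// A_b = A + b * stride_A (in elements), and likewise for B and C.
// Leading dimensions follow the packed rule: lda = m or k, ldb = k or n, ldc = m.
void ref_strided_batched_gemm(int m, int n, int k, float alpha, float beta,
                              const float* A, const float* B, float* C,
                              Op op_A, Op op_B,
                              long long stride_A, long long stride_B, long long stride_C,
                              int batch) {
    const int lda = (op_A == Op::N) ? m : k;
    const int ldb = (op_B == Op::N) ? k : n;
    const int ldc = m;
    for (int b = 0; b < batch; ++b) {
        const float* Ab = A + b * stride_A;
        const float* Bb = B + b * stride_B;
        float* Cb = C + b * stride_C;
        for (int j = 0; j < n; ++j) {
            for (int i = 0; i < m; ++i) {
                float acc = 0.0f;
                for (int p = 0; p < k; ++p) {
                    const float a = (op_A == Op::N) ? Ab[i + static_cast<std::size_t>(p) * lda]
                                                    : Ab[p + static_cast<std::size_t>(i) * lda];
                    const float bv = (op_B == Op::N) ? Bb[p + static_cast<std::size_t>(j) * ldb]
                                                     : Bb[j + static_cast<std::size_t>(p) * ldb];
                    acc += a * bv;
                }
                float& c = Cb[i + static_cast<std::size_t>(j) * ldc];
                c = (beta == 0.0f) ? alpha * acc : alpha * acc + beta * c;
            }
        }
    }
}
```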
Import
```cpp
#include "inference_cublas_wrappers.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| handle | cublasHandle_t | Yes | cuBLAS library handle; bind it to the intended CUDA stream (e.g., with cublasSetStream) before calling. |
| transa/op_A | cublasOperation_t | Yes | Transpose operation for matrix A: CUBLAS_OP_N (no transpose) or CUBLAS_OP_T (transpose). |
| transb/op_B | cublasOperation_t | Yes | Transpose operation for matrix B. |
| m, n, k | int | Yes | GEMM dimensions in column-major order: op(A) is m x k, op(B) is k x n, and C is m x n. |
| alpha, beta | const float* | Yes | Pointers to the FP32 scalar multipliers (FP32 for both overloads); in host memory under cuBLAS's default host pointer mode. |
| A, B | const float* or const __half* | Yes | Input matrix pointers in device memory. |
| C | float* or __half* | Yes | Output matrix pointer in device memory. |
| algo | cublasGemmAlgo_t (rocblas_gemm_algo on ROCm) | Yes | GEMM algorithm selector (e.g., CUBLAS_GEMM_DEFAULT_TENSOR_OP). |
| stride_A, stride_B, stride_C | int | Batched only | Distance, in elements (not bytes), between the starts of consecutive matrices in the batch. |
| batch | int | Batched only | Number of matrices in the batch. |
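When batch entries are stored back to back, each stride is simply the element count of one matrix, independent of the transpose flags, because a stored matrix holds m * k elements (or k * n, m * n) either way. The sketch below uses hypothetical names. It returns `long long` because the wrapper's `int` stride parameters cannot represent strides above INT_MAX elements, so callers should check before narrowing.

```cpp
#include <cassert>

struct BatchStrides {
    long long stride_A, stride_B, stride_C;  // in elements, not bytes
};

// Strides for back-to-back batch entries with packed leading dimensions.
// Stored A holds m * k elements whether it is m x k (CUBLAS_OP_N) or k x m (CUBLAS_OP_T);
// likewise B holds k * n and C holds m * n.
BatchStrides packed_batch_strides(int m, int n, int k) {
    return BatchStrides{static_cast<long long>(m) * k,
                        static_cast<long long>(k) * n,
                        static_cast<long long>(m) * n};
}
```

For the attention-score example (m = n = seq_len, k = head_dim) this yields seq_len * head_dim, seq_len * head_dim, and seq_len * seq_len.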
Outputs
| Name | Type | Description |
|---|---|---|
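| return value | int | 0 on success; EXIT_FAILURE on error, after an error message is printed to stderr. |
| C | float* or __half* | Overwritten in place with alpha * op(A) * op(B) + beta * C (per matrix for the batched variant). When beta is 0, the previous contents of C need not be valid. |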
Usage Examples
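```cpp
// FP16 GEMM for the QKV projection. With lda = m, the weight is read as a row-major
// [hidden_size, 3 * hidden_size] buffer and the input as row-major
// [batch * seq_len, hidden_size]; the output is row-major [batch * seq_len, 3 * hidden_size].
float alpha = 1.0f;
float beta = 0.0f;
cublas_gemm_ex(handle,
               CUBLAS_OP_N, CUBLAS_OP_N,
               3 * hidden_size,  // m: output features
               batch * seq_len,  // n: number of tokens
               hidden_size,      // k: input features
               &alpha, &beta,
               (__half*)qkv_weight,
               (__half*)input,
               (__half*)qkv_output,
               CUBLAS_GEMM_DEFAULT_TENSOR_OP);
```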
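The QKV call relies on the standard trick for row-major tensors: a row-major [rows, cols] buffer is the column-major [cols, rows] transpose, so row-major Y = X W is obtained as column-major Y^T = W^T X^T by passing the weight as A and the input as B. A CPU sketch of that mapping follows, with hypothetical function names and plain `float` in place of `__half`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Column-major C (m x n) = A (m x k) * B (k x n) with no transposes and packed
// leading dimensions (lda = m, ldb = k, ldc = m), i.e. alpha = 1 and beta = 0.
void colmajor_gemm_nn(int m, int n, int k, const float* A, const float* B, float* C) {
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            float acc = 0.0f;
            for (int p = 0; p < k; ++p) {
                acc += A[i + static_cast<std::size_t>(p) * m] * B[p + static_cast<std::size_t>(j) * k];
            }
            C[i + static_cast<std::size_t>(j) * m] = acc;
        }
    }
}

// Row-major Y[tokens x out] = X[tokens x in] * W[in x out]. Read column-major, the
// row-major buffers hold X^T, W^T and Y^T, and Y^T = W^T * X^T, so W is passed first.
std::vector<float> linear_rowmajor(const std::vector<float>& X, const std::vector<float>& W,
                                   int tokens, int in, int out) {
    std::vector<float> Y(static_cast<std::size_t>(tokens) * out);
    colmajor_gemm_nn(/*m=*/out, /*n=*/tokens, /*k=*/in, W.data(), X.data(), Y.data());
    return Y;
}
```

With out = 3 * hidden_size, tokens = batch * seq_len, and in = hidden_size, the arguments line up with the cublas_gemm_ex call in the QKV example.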
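```cpp
// Batched GEMM for attention scores: scale * Q * K^T for every (batch, head) pair.
// The strides and leading dimensions imply query and key are packed row-major as
// [batch * num_heads, seq_len, head_dim] and attn_scores as [batch * num_heads, seq_len, seq_len].
// beta is 0.0f, as declared in the previous example.
float scale = 1.0f / sqrtf((float)head_dim);
cublas_strided_batched_gemm(handle,
                            seq_len, seq_len, head_dim,  // m: keys, n: queries, k: head_dim
                            &scale, &beta,
                            (__half*)key, (__half*)query, (__half*)attn_scores,
                            CUBLAS_OP_T, CUBLAS_OP_N,    // A = key (transposed), B = query
                            seq_len * head_dim, seq_len * head_dim,  // stride_A, stride_B
                            seq_len * seq_len,                       // stride_C
                            batch * num_heads,
                            CUBLAS_GEMM_DEFAULT_TENSOR_OP);
```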
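To check the argument choices in the attention call, the following CPU sketch performs the same column-major computation for every (batch, head) slice. It shows that the result, read as a row-major [seq_len, seq_len] buffer, holds scale * Q K^T with queries along rows and keys along columns. The function name is hypothetical, and it uses `float` rather than `__half`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU emulation of the strided batched attention-score call:
// m = n = seq_len, k = head_dim, A = key with op T, B = query with op N,
// packed leading dims lda = ldb = head_dim, ldc = seq_len, and beta = 0.
std::vector<float> attention_scores(const std::vector<float>& query,
                                    const std::vector<float>& key,
                                    int batch_heads, int seq_len, int head_dim) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    const std::size_t qk_stride = static_cast<std::size_t>(seq_len) * head_dim;  // stride_A, stride_B
    const std::size_t s_stride = static_cast<std::size_t>(seq_len) * seq_len;    // stride_C
    std::vector<float> scores(s_stride * batch_heads);
    for (int b = 0; b < batch_heads; ++b) {
        const float* A = key.data() + b * qk_stride;
        const float* B = query.data() + b * qk_stride;
        float* C = scores.data() + b * s_stride;
        for (int j = 0; j < seq_len; ++j) {      // column j of C: query j
            for (int i = 0; i < seq_len; ++i) {  // row i of C: key i
                float acc = 0.0f;
                for (int p = 0; p < head_dim; ++p) {
                    // op(A)(i, p) = A(p, i) = A[p + i * lda]; B(p, j) = B[p + j * ldb]
                    acc += A[p + static_cast<std::size_t>(i) * head_dim] *
                           B[p + static_cast<std::size_t>(j) * head_dim];
                }
                C[i + static_cast<std::size_t>(j) * seq_len] = scale * acc;  // ldc = seq_len
            }
        }
    }
    return scores;
}
```

Element C(i, j) lands at offset j * seq_len + i, which in row-major terms is row j (query) and column i (key), so no extra transpose is needed before the softmax over keys.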