Principle:FMInference FlexLLMGen CuBLAS Inference Wrappers
| Knowledge Sources | |
|---|---|
| Domains | Linear Algebra, GPU Computing, API Design, Mixed Precision |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
Wrapping vendor-specific BLAS libraries behind a simplified interface enables portable, type-safe matrix multiplication across GPU platforms while abstracting away verbose API details and mixed-precision accumulation semantics.
Description
Modern GPU BLAS libraries (cuBLAS for NVIDIA, rocBLAS for AMD) provide highly optimized matrix multiplication routines, but their APIs are verbose and error-prone. A single GEMM call requires specifying data types, leading dimensions, compute types, and algorithm selectors. Wrapper functions address several concerns:
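Leading dimension inference: The leading dimension (lda, ldb, ldc) is the stride, in elements, between consecutive columns of a stored matrix, and it depends on whether that matrix is transposed. For column-major storage (the cuBLAS convention) with op(A) of size m × k: if op(A) = A, A is stored as m × k and lda = m; if op(A) = A^T, A is stored as k × m and lda = k. Likewise, for op(B) of size k × n, ldb = k (no transpose) or n (transpose), and ldc = m. The wrapper computes these automatically from the transpose flags.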
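The rule can be written down directly. This is a minimal C++ sketch, assuming tightly packed column-major operands with no padding between columns; the helper names `infer_lda`, `infer_ldb`, and `infer_ldc` are hypothetical and not taken from the FlexLLMGen source.

```cpp
#include <cassert>

// Hypothetical helpers for C = op(A) * op(B) with tightly packed column-major
// operands: op(A) is m x k, op(B) is k x n, and C is m x n.

// A is stored as m x k when not transposed, and as k x m when transposed.
int infer_lda(bool trans_a, int m, int k) {
    assert(m > 0 && k > 0);
    return trans_a ? k : m;
}

// B is stored as k x n when not transposed, and as n x k when transposed.
int infer_ldb(bool trans_b, int k, int n) {
    assert(k > 0 && n > 0);
    return trans_b ? n : k;
}

// C is never transposed, so its leading dimension is always m.
int infer_ldc(int m) {
    assert(m > 0);
    return m;
}
```

Padded buffers would instead pass their larger allocated row count; BLAS libraries accept any leading dimension at least as large as the row count of the matrix as stored.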
Mixed-precision accumulation: For half-precision (FP16) GEMM, accumulating each dot product in FP32 and rounding to FP16 only once at the end significantly improves numerical accuracy. The wrappers configure CUDA_R_16F for the input/output data types and CUDA_R_32F for the compute type (expressed as CUBLAS_COMPUTE_32F in the CUDA 11+ API), enabling Tensor Core utilization with minimal accuracy loss.
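The effect of the accumulator width can be reproduced on the CPU. This C++ sketch emulates FP16 rounding by keeping 11 significant bits of a `float` (assumption: normal-range values only; FP16 overflow above 65504 and subnormals are ignored) and compares an FP16 accumulator with an FP32 accumulator. All function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Round a float to FP16's 11 significant bits (round-to-nearest, ties-to-even).
// Assumption: |x| is in FP16's normal range; overflow and subnormals are ignored.
float round_to_fp16_precision(float x) {
    int exp = 0;
    float mant = std::frexp(x, &exp);                 // x = mant * 2^exp, |mant| in [0.5, 1)
    mant = std::nearbyint(mant * 2048.0f) / 2048.0f;  // keep 11 significant bits
    return std::ldexp(mant, exp);
}

// FP16 accumulator: every product and every partial sum is rounded to FP16.
float dot_fp16_accumulate(const std::vector<float>& a, const std::vector<float>& b) {
    assert(a.size() == b.size());
    float acc = 0.0f;
    for (std::size_t i = 0; i < a.size(); ++i) {
        acc = round_to_fp16_precision(acc + round_to_fp16_precision(a[i] * b[i]));
    }
    return acc;
}

// FP32 accumulator: partial sums stay in float; only the result is rounded to FP16.
float dot_fp32_accumulate(const std::vector<float>& a, const std::vector<float>& b) {
    assert(a.size() == b.size());
    float acc = 0.0f;
    for (std::size_t i = 0; i < a.size(); ++i) {
        acc += a[i] * b[i];
    }
    return round_to_fp16_precision(acc);
}
```

Summing 4096 products of 1.0 shows the failure mode: FP16 values between 2048 and 4096 are spaced 2 apart, so 2048 + 1 rounds back to 2048 (ties-to-even) and the FP16 accumulator stalls at 2048, while the FP32 accumulator reaches the exact 4096.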
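Tensor Core algorithm selection: The CUBLAS_GEMM_DEFAULT_TENSOR_OP algorithm hint instructs cuBLAS to use Tensor Cores (matrix multiply-accumulate units) when available. In CUDA 11 and later this hint is deprecated, because cuBLAS uses Tensor Cores by default whenever a GEMM is eligible. Tensor Cores can deliver roughly 4-8x the throughput of standard CUDA cores for supported matrix sizes and data types, with the exact gain depending on the GPU generation.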
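"Supported matrix sizes" can be approximated with a simple divisibility check. This sketch assumes NVIDIA's commonly cited guidance that FP16 Tensor Core GEMMs run most efficiently when m, n, and k are multiples of 8 (16 bytes of FP16); actual eligibility rules vary by cuBLAS version and GPU, and both helpers are hypothetical.

```cpp
#include <cassert>

// Assumption: FP16 Tensor Core paths prefer dimensions that are multiples of 8.
constexpr int kFp16Alignment = 8;

// True if every GEMM dimension meets the assumed alignment.
bool is_tensor_core_friendly(int m, int n, int k) {
    return m % kFp16Alignment == 0 && n % kFp16Alignment == 0 &&
           k % kFp16Alignment == 0;
}

// Round a dimension up to the next multiple of the alignment, e.g. to zero-pad
// a weight matrix once at load time.
int pad_to_alignment(int dim) {
    assert(dim >= 0);
    return (dim + kFp16Alignment - 1) / kFp16Alignment * kFp16Alignment;
}
```

Padding trades a little extra memory and wasted arithmetic for the faster kernel path; whether that pays off depends on how far the original size is from the next multiple.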
Batched GEMM for attention: Transformer attention requires computing batch_size * num_heads independent matrix multiplications (Q*K^T and attn*V). Strided batched GEMM executes all these as a single API call, allowing the GPU to schedule work across streaming multiprocessors more efficiently than individual GEMM calls.
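For the Q*K^T and attn*V products, the batched-GEMM parameters follow directly from the tensor shapes. This sketch assumes Q, K, and V are contiguous row-major tensors of shape [batch, heads, seq, head_dim] and describes each product in row-major terms; the `BatchedGemmShape` struct and both helpers are hypothetical, and issuing them as a column-major cuBLAS call would additionally need the operand swap used for row-major data.

```cpp
#include <cassert>

// Hypothetical parameter bundle for one strided batched GEMM, in row-major terms:
// C[i] (rows x cols) = op(A[i]) * op(B[i]), contracting over `inner`.
struct BatchedGemmShape {
    int batch_count;  // number of independent matrix products
    int rows;         // rows of each output matrix
    int cols;         // columns of each output matrix
    int inner;        // contracted dimension
    long long stride_a, stride_b, stride_c;  // elements between consecutive matrices
};

// scores[b, h] = Q[b, h] * K[b, h]^T, shape (seq_q x seq_k); op(B) = K^T.
BatchedGemmShape attention_scores_shape(int batch, int heads, int seq_q, int seq_k,
                                        int head_dim) {
    assert(batch > 0 && heads > 0 && seq_q > 0 && seq_k > 0 && head_dim > 0);
    return {batch * heads, seq_q, seq_k, head_dim,
            1LL * seq_q * head_dim, 1LL * seq_k * head_dim, 1LL * seq_q * seq_k};
}

// out[b, h] = attn[b, h] * V[b, h], shape (seq_q x head_dim).
BatchedGemmShape attention_output_shape(int batch, int heads, int seq_q, int seq_k,
                                        int head_dim) {
    assert(batch > 0 && heads > 0 && seq_q > 0 && seq_k > 0 && head_dim > 0);
    return {batch * heads, seq_q, head_dim, seq_k,
            1LL * seq_q * seq_k, 1LL * seq_k * head_dim, 1LL * seq_q * head_dim};
}
```

For batch 2, 12 heads, 128 tokens, and head_dim 64, the scores product becomes 24 independent (128 × 64)·(64 × 128) multiplications, with consecutive Q and K matrices 8192 elements apart and consecutive score matrices 16384 elements apart.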
Cross-platform portability: Conditional compilation supports both cuBLAS (NVIDIA) and rocBLAS (AMD) backends with the same wrapper interface. The caller code remains unchanged while the preprocessor selects the correct vendor-specific API calls, data type enums, and error handling.
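A compile-time sketch of the backend switch. `FLEXGEN_USE_ROCBLAS` is a hypothetical build flag, and instead of including the vendor headers and issuing real calls, each branch only reports the vendor entry points it would use (`cublasGemmEx` and `cublasGemmStridedBatchedEx` for cuBLAS, `rocblas_gemm_ex` and `rocblas_gemm_strided_batched_ex` for rocBLAS), so the example compiles without either SDK.

```cpp
#include <cassert>
#include <string>

// FLEXGEN_USE_ROCBLAS is a hypothetical build flag. Real code would include the
// vendor header (cuBLAS or rocBLAS) in each branch and call the functions directly.
#if defined(FLEXGEN_USE_ROCBLAS)
namespace blas_backend {
inline std::string name() { return "rocBLAS"; }
inline std::string gemm_entry_point() { return "rocblas_gemm_ex"; }
inline std::string batched_gemm_entry_point() { return "rocblas_gemm_strided_batched_ex"; }
}  // namespace blas_backend
#else
namespace blas_backend {
inline std::string name() { return "cuBLAS"; }
inline std::string gemm_entry_point() { return "cublasGemmEx"; }
inline std::string batched_gemm_entry_point() { return "cublasGemmStridedBatchedEx"; }
}  // namespace blas_backend
#endif

// Caller code depends only on blas_backend, never on vendor-specific names.
std::string describe_gemm_backend() {
    assert(!blas_backend::name().empty());
    return blas_backend::name() + " via " + blas_backend::gemm_entry_point();
}
```

Building with `-DFLEXGEN_USE_ROCBLAS` switches every caller of `blas_backend` to the rocBLAS branch without changing the caller code.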
Usage
Apply this principle when building GPU-accelerated linear algebra pipelines that must work across GPU vendors, support multiple precision modes, or when the raw BLAS API verbosity would obscure the algorithmic intent of the calling code.
Theoretical Basis
GEMM (General Matrix Multiply) computes C = alpha * op(A) * op(B) + beta * C, where op() is either the identity or the transpose, op(A) is m × k, op(B) is k × n, and C is m × n. It is typically the most performance-critical operation in deep learning, dominating the compute time for linear layers, attention scores, and MLP projections.
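A naive CPU reference for this definition is handy for validating a GPU wrapper on small inputs. This C++ sketch assumes column-major storage with explicit leading dimensions, mirroring the cuBLAS argument order; `gemm_reference` is a hypothetical name and does not reflect how cuBLAS computes the product.

```cpp
#include <cassert>

// Reference column-major GEMM: C = alpha * op(A) * op(B) + beta * C.
// op(A) is m x k, op(B) is k x n, C is m x n; element (r, c) lives at r + c * ld.
void gemm_reference(bool trans_a, bool trans_b, int m, int n, int k,
                    float alpha, const float* A, int lda,
                    const float* B, int ldb,
                    float beta, float* C, int ldc) {
    assert(lda >= (trans_a ? k : m));
    assert(ldb >= (trans_b ? n : k));
    assert(ldc >= m);
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            float sum = 0.0f;
            for (int p = 0; p < k; ++p) {
                // op(A)(i, p): stored at (i, p), or at (p, i) when transposed.
                float a = trans_a ? A[p + i * lda] : A[i + p * lda];
                // op(B)(p, j): stored at (p, j), or at (j, p) when transposed.
                float b = trans_b ? B[j + p * ldb] : B[p + j * ldb];
                sum += a * b;
            }
            float& c = C[i + j * ldc];
            // As in BLAS, C is not read when beta == 0, so it may be uninitialized.
            c = (beta == 0.0f) ? alpha * sum : alpha * sum + beta * c;
        }
    }
}
```

For example, with A = [[1, 2, 3], [4, 5, 6]] stored column-major as {1, 4, 2, 5, 3, 6} and B = [[7, 8], [9, 10], [11, 12]] stored as {7, 9, 11, 8, 10, 12}, the call with alpha = 1 and beta = 0 fills C with {58, 139, 64, 154}, the column-major form of [[58, 64], [139, 154]].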
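Column-major storage is the default in cuBLAS (inherited from Fortran BLAS conventions). For a matrix with M rows and N columns stored column-major, the leading dimension is M: consecutive elements of a column are contiguous in memory, and element (i, j) sits at offset i + j * M. Row-major frameworks (such as PyTorch, which uses C-style row-major layout) must swap operand order or transpose flags when calling cuBLAS. The reason is that a row-major M × N buffer, read as column-major, is the N × M transpose, so the row-major product C = A * B can be obtained as the column-major product C^T = B^T * A^T by passing B before A.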
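The operand swap can be checked with a small CPU sketch. `gemm_col_major` is a hypothetical stand-in for a column-major cuBLAS call without transposes, and `gemm_row_major` shows one common way to reuse it on row-major buffers by computing C^T = B^T * A^T.

```cpp
#include <cassert>

// Stand-in for a column-major GEMM with no transposes: C = A * B,
// where A is m x k (lda), B is k x n (ldb), and C is m x n (ldc).
void gemm_col_major(int m, int n, int k, const float* A, int lda,
                    const float* B, int ldb, float* C, int ldc) {
    assert(lda >= m && ldb >= k && ldc >= m);
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            float sum = 0.0f;
            for (int p = 0; p < k; ++p) sum += A[i + p * lda] * B[p + j * ldb];
            C[i + j * ldc] = sum;
        }
    }
}

// Row-major C (M x N) = A (M x K) * B (K x N).
// Read column-major, the row-major buffers hold A^T (K x M), B^T (N x K), and
// C^T (N x M). Since C^T = B^T * A^T, pass B first and A second, with the
// row lengths N, K, N as leading dimensions and no transpose flags.
void gemm_row_major(int M, int N, int K, const float* A, const float* B, float* C) {
    gemm_col_major(N, M, K, B, N, A, K, C, N);
}
```

With row-major A = {1, 2, 3, 4, 5, 6} (2 × 3) and B = {7, 8, 9, 10, 11, 12} (3 × 2), `gemm_row_major(2, 2, 3, A, B, C)` writes the row-major result {58, 64, 139, 154} without any explicit transposition of the data.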
Mixed-precision computation exploits the observation that FP16 (11 significant bits, maximum finite value 65504) is usually sufficient for storing weights and activations, which typically occupy a limited dynamic range, while FP32 is needed to accumulate the partial sums of dot products, where rounding errors compound with every addition. Tensor Cores natively support this FP16-input, FP32-accumulate mode.
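A standard worst-case bound for floating-point dot products makes the compounding concrete. With unit roundoff u (2^-11 for FP16, 2^-24 for FP32) and inner dimension n:

```latex
\left| \mathrm{fl}(x^{\top} y) - x^{\top} y \right| \le \gamma_n \sum_{i=1}^{n} |x_i|\,|y_i|,
\qquad \gamma_n = \frac{n u}{1 - n u} \quad (n u < 1)

n = 4096:\quad n\,u_{\mathrm{FP16}} = 4096 \cdot 2^{-11} = 2 \;(\text{no guarantee}),
\qquad n\,u_{\mathrm{FP32}} = 4096 \cdot 2^{-24} = 2^{-12} \approx 2.44 \times 10^{-4}
```

Actual errors are usually far smaller than this worst case, but the gap between the two precisions illustrates why the accumulator width, rather than the storage format, limits the accuracy of long dot products.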
Strided batched GEMM assumes that the batch dimension is encoded as a fixed stride between consecutive matrices in device memory. The stride_A, stride_B, stride_C parameters allow a single kernel launch to process all batch elements, amortizing launch overhead and enabling better GPU utilization.
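A CPU reference that spells out the stride semantics, as a sketch only: it assumes column-major matrices without transposes and loops over the batch sequentially, whereas a GPU library processes all batch elements in one launch. The function name is hypothetical; its argument order mirrors `cublasSgemmStridedBatched` minus the handle and transpose flags.

```cpp
#include <cassert>

// For each b in [0, batch_count): C_b = alpha * A_b * B_b + beta * C_b,
// where A_b starts at A + b * stride_a (likewise for B and C).
void gemm_strided_batched_reference(int m, int n, int k, float alpha,
                                    const float* A, int lda, long long stride_a,
                                    const float* B, int ldb, long long stride_b,
                                    float beta, float* C, int ldc, long long stride_c,
                                    int batch_count) {
    assert(lda >= m && ldb >= k && ldc >= m && batch_count >= 0);
    for (int b = 0; b < batch_count; ++b) {
        const float* Ab = A + b * stride_a;  // fixed stride locates each matrix
        const float* Bb = B + b * stride_b;
        float* Cb = C + b * stride_c;
        for (int j = 0; j < n; ++j) {
            for (int i = 0; i < m; ++i) {
                float sum = 0.0f;
                for (int p = 0; p < k; ++p) sum += Ab[i + p * lda] * Bb[p + j * ldb];
                float& c = Cb[i + j * ldc];
                c = (beta == 0.0f) ? alpha * sum : alpha * sum + beta * c;
            }
        }
    }
}
```

Because each batch element is located purely by arithmetic on a base pointer, no per-matrix pointer array has to be built or copied to the device, which is what lets the library cover the whole batch with a single call.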
Error handling is critical at the BLAS wrapper level because matrix dimension mismatches or memory corruption can produce silent numerical errors rather than crashes. The wrappers check the return status and print diagnostic information (m, n, k, error code) to stderr.
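A sketch of the status check, assuming that a status value of 0 means success (true of both `CUBLAS_STATUS_SUCCESS` and `rocblas_status_success`). The helper names are hypothetical, and a real wrapper would receive the vendor's status enum rather than a plain `int`.

```cpp
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>

// Build the diagnostic line (operation, m, n, k, status) for a failed call.
std::string format_gemm_error(const char* op, int status, int m, int n, int k) {
    assert(op != nullptr);
    std::ostringstream msg;
    msg << op << " failed: m=" << m << " n=" << n << " k=" << k
        << " status=" << status;
    return msg.str();
}

// Returns true on success; otherwise prints the diagnostic to stderr.
bool check_gemm_status(const char* op, int status, int m, int n, int k) {
    if (status == 0) return true;
    std::cerr << format_gemm_error(op, status, m, n, k) << '\n';
    return false;
}
```

Keeping the formatting separate from the printing makes the diagnostic text testable, and logging m, n, and k alongside the code usually points straight at the mismatched dimension.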