Principle:FMInference FlexLLMGen CuBLAS Inference Wrappers

From Leeroopedia


Knowledge Sources
Domains: Linear Algebra, GPU Computing, API Design, Mixed Precision
Last Updated: 2026-02-09 12:00 GMT

Overview

Wrapping vendor-specific BLAS libraries behind a simplified interface enables portable, type-safe matrix multiplication across GPU platforms while abstracting away verbose API details and mixed-precision accumulation semantics.

Description

Modern GPU BLAS libraries (cuBLAS for NVIDIA, rocBLAS for AMD) provide highly optimized matrix multiplication routines, but their APIs are verbose and error-prone. A single GEMM call requires specifying data types, leading dimensions, compute types, and algorithm selectors. Wrapper functions address several concerns:

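Leading dimension inference: In column-major storage (the cuBLAS convention), the leading dimension of a stored matrix is its number of stored rows, so lda and ldb depend on the transpose flags. With op(A) of size m x k and op(B) of size k x n: if op(A) = A, A is stored as m x k and lda = m; if op(A) = A^T, A is stored as k x m and lda = k. Likewise ldb = k without a transpose and ldb = n with one, while ldc = m always, because C is never transposed. These are the values for tightly packed matrices; cuBLAS also accepts any larger leading dimension. The wrapper derives them automatically from the transpose flags.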
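A minimal sketch of this inference, assuming tightly packed column-major matrices and cuBLAS-style 'N'/'T' flags. infer_leading_dims is a hypothetical helper written for illustration, not the wrapper's actual function.

```python
def infer_leading_dims(trans_a: str, trans_b: str, m: int, n: int, k: int) -> tuple:
    """Return (lda, ldb, ldc) for a tightly packed column-major GEMM.

    The GEMM computes C (m x n) = op(A) (m x k) @ op(B) (k x n), where
    'N' means op(X) = X and 'T' means op(X) = X^T.
    """
    if trans_a not in ("N", "T") or trans_b not in ("N", "T"):
        raise ValueError("transpose flags must be 'N' or 'T'")
    # A is stored as m x k ('N') or k x m ('T'); ld = number of stored rows.
    lda = m if trans_a == "N" else k
    # B is stored as k x n ('N') or n x k ('T').
    ldb = k if trans_b == "N" else n
    # C is never transposed, so its leading dimension is always m.
    ldc = m
    return lda, ldb, ldc


print(infer_leading_dims("N", "T", m=4, n=5, k=3))  # (4, 5, 4)
```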

Mixed-precision accumulation: For half-precision (FP16) GEMM, accumulating each dot product in FP32 and rounding the result back to FP16 only once significantly improves numerical accuracy. The wrappers pass CUDA_R_16F as the data type of A, B, and C and CUDA_R_32F as the compute type (expressed as CUBLAS_COMPUTE_32F in the cublasComputeType_t enum used since CUDA 11), which allows Tensor Core execution with minimal accuracy loss. In this mode the alpha and beta scalars are FP32 values.
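The effect can be reproduced without a GPU. The sketch below is a pure-Python simulation that assumes Python's struct 'e' format rounds the way GPU FP16 does (IEEE 754 binary16, round-half-to-even). It computes the same FP16 dot product once with an FP16 accumulator and once with an FP32 accumulator, modeling one rounding per multiply-add in sequential order. Real cuBLAS and Tensor Core kernels sum in a different order, so only the qualitative effect carries over; the function names are illustrative.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE 754 half-precision value (ties to even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def dot_fp16_accumulate(a, b):
    """FP16 inputs, FP16 accumulator: every partial sum is rounded to FP16."""
    acc = 0.0
    for x, y in zip(a, b):
        # The product of two FP16 values is exact in a Python float, so a
        # single rounding per step models a fused multiply-add.
        acc = to_fp16(acc + to_fp16(x) * to_fp16(y))
    return acc


def dot_fp32_accumulate(a, b):
    """FP16 inputs, FP32 accumulator, one final rounding to FP16."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = to_fp32(acc + to_fp16(x) * to_fp16(y))
    return to_fp16(acc)


ones = [1.0] * 4096
print(dot_fp16_accumulate(ones, ones))  # 2048.0
print(dot_fp32_accumulate(ones, ones))  # 4096.0
```

With 4096 products of 1.0, the FP16 accumulator stops growing at 2048: there the spacing between adjacent FP16 values is 2, so 2048 + 1 rounds back to 2048. The FP32 accumulator reaches the exact 4096, which FP16 can represent.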

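Tensor Core algorithm selection: The CUBLAS_GEMM_DEFAULT_TENSOR_OP algorithm hint asks cuBLAS to use Tensor Cores (dedicated matrix multiply-accumulate units) when they are available. For supported data types and matrix shapes, Tensor Cores can deliver roughly 4-8x the throughput of standard CUDA cores; the exact gain depends on the GPU architecture and the problem size. Since CUDA 11 the *_TENSOR_OP algorithm values are deprecated and cuBLAS uses Tensor Cores by default where the compute type allows it, so the hint matters mainly for older toolkits.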

Batched GEMM for attention: Transformer attention computes two products per layer, Q*K^T and attn*V, and each of them consists of batch_size * num_heads independent matrix multiplications. Strided batched GEMM executes all of them in a single API call, letting the GPU schedule the work across its streaming multiprocessors more efficiently than a loop of individual GEMM calls, each of which would pay its own launch overhead.
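A sketch of how the two attention products could map onto strided batched GEMM arguments. It assumes contiguous row-major tensors Q of shape [batch*heads, q_len, head_dim] and K, V of shape [batch*heads, kv_len, head_dim], and a column-major BLAS. attention_gemm_params is a hypothetical helper; its dictionary keys only loosely mirror the parameter names of cublasGemmStridedBatchedEx, and it computes arguments without calling any library. Each row-major product X*Y is issued as the column-major product Y^T*X^T, whose output buffer, read back row-major, holds X*Y.

```python
def attention_gemm_params(batch, heads, q_len, kv_len, head_dim):
    """Hypothetical helper: column-major strided batched GEMM arguments
    for scores = Q @ K^T and out = attn @ V on row-major tensors."""
    count = batch * heads  # one independent GEMM per (batch, head) pair

    # scores[q_len, kv_len] = Q @ K^T  ->  column-major scores^T = K @ Q^T.
    # The K buffer read column-major (ld = head_dim) is K^T, so transpose it back;
    # the Q buffer read column-major is already Q^T.
    qk = dict(
        A="K", transa="T", lda=head_dim, strideA=kv_len * head_dim,
        B="Q", transb="N", ldb=head_dim, strideB=q_len * head_dim,
        m=kv_len, n=q_len, k=head_dim,
        ldc=kv_len, strideC=q_len * kv_len,
        batchCount=count,
    )

    # out[q_len, head_dim] = attn @ V  ->  column-major out^T = V^T @ attn^T.
    # Both buffers read column-major are already the needed transposes.
    av = dict(
        A="V", transa="N", lda=head_dim, strideA=kv_len * head_dim,
        B="attn", transb="N", ldb=kv_len, strideB=q_len * kv_len,
        m=head_dim, n=q_len, k=kv_len,
        ldc=head_dim, strideC=q_len * head_dim,
        batchCount=count,
    )
    return qk, av


qk, av = attention_gemm_params(batch=2, heads=4, q_len=1, kv_len=16, head_dim=64)
print(qk["m"], qk["n"], qk["k"], qk["batchCount"])  # 16 1 64 8
```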

Cross-platform portability: Conditional compilation supports both cuBLAS (NVIDIA) and rocBLAS (AMD) backends with the same wrapper interface. The caller code remains unchanged while the preprocessor selects the correct vendor-specific API calls, data type enums, and error handling.

Usage

Apply this principle when building GPU-accelerated linear algebra pipelines that must work across GPU vendors, support multiple precision modes, or when the raw BLAS API verbosity would obscure the algorithmic intent of the calling code.

Theoretical Basis

GEMM (General Matrix Multiply) computes C = alpha * op(A) * op(B) + beta * C, where op() is either the identity or the transpose, op(A) is m x k, op(B) is k x n, and C is m x n. It is typically the most performance-critical operation in deep learning, dominating the compute time of linear layers (including MLP projections) and attention score computation.
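Written element by element with the dimension names m, n, k used by the wrappers, the definition also makes the cost visible: each of the m*n outputs needs k multiply-add pairs, about 2mnk floating-point operations in total (ignoring the alpha and beta scaling).

```latex
\operatorname{op}(A) \in \mathbb{R}^{m \times k},\quad
\operatorname{op}(B) \in \mathbb{R}^{k \times n},\quad
C \in \mathbb{R}^{m \times n}

C_{ij} \leftarrow \alpha \sum_{p=1}^{k} \operatorname{op}(A)_{ip}\,\operatorname{op}(B)_{pj} + \beta\, C_{ij},
\qquad 1 \le i \le m,\ 1 \le j \le n

\text{FLOPs} \approx 2mnk
```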

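Column-major storage is used throughout cuBLAS, following the Fortran BLAS convention. For a tightly packed matrix with M rows and N columns stored column-major, the leading dimension is M: elements within a column are contiguous in memory, and consecutive columns start M elements apart. Row-major frameworks such as PyTorch must compensate. Because a row-major buffer read as column-major is the transpose of the matrix, a row-major product C = A * B can be computed as the column-major product C^T = B^T * A^T by swapping the A and B operands (and m with n), without moving any data.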
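A small pure-Python demonstration of the operand swap, assuming flat Python lists as memory buffers. gemm_colmajor stands in for a column-major BLAS call (it is not the cuBLAS API), and both function names are hypothetical.

```python
def gemm_colmajor(m, n, k, A, lda, B, ldb):
    """Reference column-major C (m x n) = A (m x k) @ B (k x n); returns C with ldc = m."""
    C = [0.0] * (m * n)
    for j in range(n):
        for i in range(m):
            C[i + j * m] = sum(A[i + p * lda] * B[p + j * ldb] for p in range(k))
    return C


def matmul_rowmajor_via_colmajor(A, B, M, K, N):
    """Row-major C (M x N) = A (M x K) @ B (K x N) using only a column-major GEMM.

    Read column-major, the row-major B buffer is B^T (N x K, ld = N) and the
    row-major A buffer is A^T (K x M, ld = K). Computing B^T @ A^T = C^T
    column-major leaves exactly C in row-major order in the output buffer.
    """
    return gemm_colmajor(N, M, K, B, N, A, K)


A = [1, 2, 3,
     4, 5, 6]           # 2 x 3, row-major
B = [7, 8,
     9, 10,
     11, 12]            # 3 x 2, row-major
print(matmul_rowmajor_via_colmajor(A, B, M=2, K=3, N=2))  # [58, 64, 139, 154] as floats
```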

Mixed-precision computation exploits the observation that FP16 has enough precision for storing weights and activations, whose values usually fall within its limited dynamic range, while FP32 is needed to accumulate the partial sums of dot products, where rounding errors compound over many additions. Tensor Cores natively support this FP16-input, FP32-accumulate mode.
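A pure-Python illustration of the storage side of this argument, assuming Python's struct 'e' (IEEE 754 half-precision) format as a stand-in for GPU FP16 rounding. Rounding a normal-range value to FP16 once changes it by a relative error of at most 2^-11, while values beyond 65504 cannot be stored at all. The sample weights are arbitrary and the names are illustrative.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE 754 half-precision value (ties to even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


FP16_MAX = 65504.0               # largest finite half-precision value
FP16_UNIT_ROUNDOFF = 2.0 ** -11  # bound on relative error of one rounding (normal range)


def storage_relative_error(x: float) -> float:
    """Relative error introduced by storing x once in FP16."""
    return abs(to_fp16(x) - x) / abs(x)


weights = [0.0123, -0.75, 1.337, 3.1416, -42.5]
print(max(storage_relative_error(w) for w in weights) <= FP16_UNIT_ROUNDOFF)  # True

# The limited range is the other constraint: this value has no FP16 encoding.
try:
    to_fp16(70000.0)
except OverflowError:
    print("70000.0 exceeds the FP16 range")
```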

Strided batched GEMM assumes that the batch dimension is encoded as a fixed stride between consecutive matrices in device memory. The stride_A, stride_B, stride_C parameters allow a single kernel launch to process all batch elements, amortizing launch overhead and enabling better GPU utilization.
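A pure-Python reference for the strided batched semantics, assuming column-major storage and cuBLAS-style 'N'/'T' flags. The function name and argument order are illustrative (loosely modeled on cublasGemmStridedBatchedEx), and it is meant for checking results on small inputs, not for speed. Matrix b of each operand begins at offset b * stride in one flat buffer. Following the BLAS convention, C is not read when beta is 0.

```python
def gemm_strided_batched_ref(transa, transb, m, n, k, alpha,
                             A, lda, stride_a, B, ldb, stride_b,
                             beta, C, ldc, stride_c, batch_count):
    """Reference column-major strided batched GEMM, updating C in place:
    C_b = alpha * op(A_b) @ op(B_b) + beta * C_b for b in range(batch_count)."""
    for b in range(batch_count):
        a0, b0, c0 = b * stride_a, b * stride_b, b * stride_c
        for j in range(n):
            for i in range(m):
                acc = 0.0
                for p in range(k):
                    # op(A)[i, p]: A is stored m x k ('N') or k x m ('T').
                    a = A[a0 + i + p * lda] if transa == "N" else A[a0 + p + i * lda]
                    # op(B)[p, j]: B is stored k x n ('N') or n x k ('T').
                    bv = B[b0 + p + j * ldb] if transb == "N" else B[b0 + j + p * ldb]
                    acc += a * bv
                idx = c0 + i + j * ldc
                # With beta == 0, C is treated as output-only (never read).
                C[idx] = alpha * acc + (beta * C[idx] if beta != 0 else 0.0)
    return C


# Two 2x2 problems in one call: A_0 = I, A_1 = 2I, and the same B for both.
A = [1.0, 0.0, 0.0, 1.0,  2.0, 0.0, 0.0, 2.0]
B = [1.0, 2.0, 3.0, 4.0,  1.0, 2.0, 3.0, 4.0]
C = [0.0] * 8
gemm_strided_batched_ref("N", "N", 2, 2, 2, 1.0, A, 2, 4, B, 2, 4, 0.0, C, 2, 4, 2)
print(C)  # [1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0]
```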

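Error handling is critical at the BLAS wrapper level because cuBLAS cannot check buffer sizes: a dimension mismatch or memory corruption usually produces silently wrong numbers or out-of-bounds accesses rather than an immediate, attributable crash. cuBLAS does reject some invalid arguments, such as a leading dimension smaller than the stored row count, by returning CUBLAS_STATUS_INVALID_VALUE. The wrappers check the returned status and print diagnostic information (m, n, k, and the error code) to stderr.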
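A minimal sketch of argument checks a wrapper could run before calling cuBLAS, assuming only the 'N'/'T' flags are used. The conditions mirror the CUBLAS_STATUS_INVALID_VALUE cases documented for cublas<t>gemm; check_gemm_args is a hypothetical name, and the status values are plain strings standing in for the real cublasStatus_t enum. No check of this kind can detect an undersized or corrupted buffer, which is why those failures remain silent.

```python
import sys

STATUS_SUCCESS = "CUBLAS_STATUS_SUCCESS"              # stand-in for the enum value
STATUS_INVALID_VALUE = "CUBLAS_STATUS_INVALID_VALUE"  # stand-in for the enum value


def check_gemm_args(transa, transb, m, n, k, lda, ldb, ldc):
    """Hypothetical pre-flight check for a column-major GEMM call."""
    bad = (
        m < 0 or n < 0 or k < 0
        or transa not in ("N", "T") or transb not in ("N", "T")
        # A is stored m x k ('N') or k x m ('T').
        or lda < max(1, m if transa == "N" else k)
        # B is stored k x n ('N') or n x k ('T').
        or ldb < max(1, k if transb == "N" else n)
        or ldc < max(1, m)
    )
    if bad:
        # Report the shape information needed to locate the faulty call.
        print(f"GEMM error: m={m} n={n} k={k} lda={lda} ldb={ldb} ldc={ldc} "
              f"status={STATUS_INVALID_VALUE}", file=sys.stderr)
        return STATUS_INVALID_VALUE
    return STATUS_SUCCESS


print(check_gemm_args("N", "N", 4, 5, 3, lda=3, ldb=3, ldc=4))  # CUBLAS_STATUS_INVALID_VALUE
```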
