Implementation:NVIDIA TransformerEngine Cpp GEMM
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Python interface to cuBLAS-backed GEMM operations with FP8 quantization, communication overlap, and grouped GEMM support for Mixture-of-Experts workloads.
Description
general_gemm is the primary GEMM entry point. It supports FP8 inputs, output quantization, fused GeLU activation, bias addition, accumulation into an existing output, and communication-computation overlap (AG+GEMM and GEMM+RS) through user buffer (UB) communicators. It validates the layout (TN, NN, or NT) and allocates a cuBLAS workspace that is cached per device (32 MiB on Hopper, 4 MiB on other architectures). It then dispatches either to a custom GEMM implementation for custom tensor types or to the native tex.te_general_gemm. general_grouped_gemm runs batched/grouped GEMMs for MoE workloads and allocates cuBLAS workspaces for multiple streams. Helper functions handle workspace sizing, scale validation, and device detection for quantized tensor types.
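The workspace sizing and per-device caching described above can be sketched in plain Python. This is a minimal sketch, not TransformerEngine's code. The names workspace_size_bytes and WorkspaceCache are hypothetical, and treating "Hopper" as compute-capability major version 9 or newer is an assumption. A fake allocator stands in for the real CUDA buffer so the logic runs anywhere.

```python
from typing import Callable, Dict

# Sizes taken from the description above: 32 MiB on Hopper, 4 MiB elsewhere.
HOPPER_WORKSPACE_BYTES = 32 * 1024 * 1024
DEFAULT_WORKSPACE_BYTES = 4 * 1024 * 1024


def workspace_size_bytes(compute_capability_major: int) -> int:
    """Choose a workspace size from the GPU's compute-capability major version.

    Assumption: "Hopper" is treated as major version 9 or newer.
    """
    if compute_capability_major >= 9:
        return HOPPER_WORKSPACE_BYTES
    return DEFAULT_WORKSPACE_BYTES


class WorkspaceCache:
    """Hypothetical per-device cache: allocate a workspace once, then reuse it."""

    def __init__(self, allocate: Callable[[int, int], object]):
        # allocate(device, nbytes) stands in for a real CUDA allocation, for example
        # torch.empty(nbytes, dtype=torch.uint8, device=f"cuda:{device}").
        self._allocate = allocate
        self._buffers: Dict[int, object] = {}

    def get(self, device: int, compute_capability_major: int) -> object:
        if device not in self._buffers:
            nbytes = workspace_size_bytes(compute_capability_major)
            self._buffers[device] = self._allocate(device, nbytes)
        return self._buffers[device]


if __name__ == "__main__":
    allocations = []

    def fake_allocate(device: int, nbytes: int) -> object:
        allocations.append((device, nbytes))
        return ("buffer", device, nbytes)

    cache = WorkspaceCache(fake_allocate)
    first = cache.get(device=0, compute_capability_major=9)
    again = cache.get(device=0, compute_capability_major=9)  # cache hit, no new allocation
    cache.get(device=1, compute_capability_major=8)
    print(first is again, allocations)  # True [(0, 33554432), (1, 4194304)]
```

Keying the cache by device index means each GPU pays for the allocation once, and later GEMM calls on that device reuse the same buffer.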
Usage
All linear layer computations in TransformerEngine (projections, MLPs, MoE experts) flow through this GEMM interface. It is the foundation for FP8 training performance.
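To make the three layouts concrete for a single linear layer, the sketch below computes the output shapes of its forward, data-gradient (dgrad), and weight-gradient (wgrad) GEMMs. It rests on two assumptions worth checking against the source. First, layout[0] and layout[1] say whether A and B are transposed, and the row-major result equals op(B) @ op(A). Second, TransformerEngine's linear layers use TN for forward, NN for dgrad, and NT for wgrad. gemm_output_shape is a hypothetical helper, not part of the library.

```python
from typing import Tuple

Shape = Tuple[int, int]


def gemm_output_shape(a_shape: Shape, b_shape: Shape, layout: str) -> Shape:
    """Row-major output shape of op(B) @ op(A) for a two-letter layout string.

    Assumption: layout[0] marks A as transposed ("T") or not ("N"), layout[1] does
    the same for B, and the row-major result is op(B) @ op(A).
    """
    if layout not in ("TN", "NN", "NT"):
        raise ValueError(f"unsupported layout {layout!r}")
    op_a = a_shape[::-1] if layout[0] == "T" else a_shape
    op_b = b_shape[::-1] if layout[1] == "T" else b_shape
    if op_b[1] != op_a[0]:
        raise ValueError(f"inner dimensions differ: {op_b} @ {op_a}")
    return (op_b[0], op_a[1])


if __name__ == "__main__":
    M, K, N = 8, 16, 32  # tokens, in_features, out_features
    weight, inp, grad_out = (N, K), (M, K), (M, N)
    print(gemm_output_shape(weight, inp, "TN"))       # forward output: (8, 32)
    print(gemm_output_shape(weight, grad_out, "NN"))  # dgrad:          (8, 16)
    print(gemm_output_shape(inp, grad_out, "NT"))     # wgrad:          (32, 16)
```

Under these assumptions the weight gradient comes out in the weight's own (out_features, in_features) shape, which is why the wgrad GEMM transposes B rather than A.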
Code Reference
Source Location
- Repository
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/cpp_extensions/gemm.py
- Lines: 1-308
Signature
```python
def get_cublas_workspace_size_bytes() -> int: ...
def get_cublas_workspace(device: int, ub: bool, grouped_gemm: bool) -> torch.Tensor: ...
def validate_gemm_scale(scale: Optional[float], required: bool) -> float: ...
def get_tensor_device(tensor: torch.Tensor) -> int: ...
def general_gemm(
    A, B, workspace, layout, out=None, bias=None, gelu=False,
    gelu_input=None, grad=False, accumulate=False, ...
): ...
def general_grouped_gemm(
    A_list, B_list, workspace_list, layout, out_list=None, ...
): ...
```
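general_grouped_gemm exists because MoE experts receive different numbers of tokens, so their GEMMs do not share one uniform shape. The pure-Python reference below shows what a grouped GEMM computes: one independent GEMM per group. It assumes the same layout convention as general_gemm (row-major result op(B) @ op(A)) and ignores workspaces and streams. reference_grouped_gemm is a hypothetical name, not the library's API.

```python
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    # x is (rows, k) and y is (k, cols); plain triple loop, no BLAS.
    k, cols = len(y), len(y[0])
    return [[sum(x[i][p] * y[p][j] for p in range(k)) for j in range(cols)]
            for i in range(len(x))]


def reference_grouped_gemm(A_list: List[Matrix], B_list: List[Matrix], layout: str) -> List[Matrix]:
    """Hypothetical reference: one independent GEMM per group (e.g. per MoE expert)."""
    if layout not in ("TN", "NN", "NT"):
        raise ValueError(f"unsupported layout {layout!r}")
    if len(A_list) != len(B_list):
        raise ValueError("A_list and B_list must have the same length")
    outs = []
    for A, B in zip(A_list, B_list):
        op_a = transpose(A) if layout[0] == "T" else A
        op_b = transpose(B) if layout[1] == "T" else B
        outs.append(matmul(op_b, op_a))  # row-major op(B) @ op(A)
    return outs


if __name__ == "__main__":
    # Two experts with 2x2 weights (out_features x in_features).
    # Expert 0 receives 1 token, expert 1 receives 2 tokens.
    weights = [[[1, 0], [0, 1]], [[2, 1], [0, 3]]]
    tokens = [[[1, 2]], [[1, 1], [2, 0]]]
    print(reference_grouped_gemm(weights, tokens, "TN"))  # [[[1, 2]], [[3, 3], [4, 0]]]
```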
Import
```python
from transformer_engine.pytorch.cpp_extensions.gemm import (
    general_gemm,
    general_grouped_gemm,
    get_cublas_workspace_size_bytes,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| A | torch.Tensor | Yes | Left operand matrix (may be quantized) |
| B | torch.Tensor | Yes | Right operand matrix (may be quantized) |
| workspace | torch.Tensor | Yes | cuBLAS workspace buffer |
| layout | str | Yes | Matrix layout: "TN", "NN", or "NT" |
| out | torch.Tensor | No | Pre-allocated output tensor |
| bias | torch.Tensor | No | Optional bias to fuse with the GEMM |
| gelu | bool | No | Whether to fuse GeLU activation |
| accumulate | bool | No | Whether to accumulate into the output |
| ub_algo | CommOverlapType | No | Communication overlap algorithm (AG, RS) |
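The two overlap patterns selected by ub_algo must produce the same numbers as their non-overlapped counterparts. The sketch below shows only that reference math and assumes a Megatron-style sequence-parallel setup. AG+GEMM gathers the row shards from every rank before multiplying. GEMM+RS sums per-rank partial products over a split inner dimension, then scatters the rows. Ranks are simulated as list entries. The chunked pipelining that hides communication behind compute is deliberately omitted, and both function names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    # x is (rows, k) and y is (k, cols); plain triple loop, no BLAS.
    k, cols = len(y), len(y[0])
    return [[sum(x[i][p] * y[p][j] for p in range(k)) for j in range(cols)]
            for i in range(len(x))]


def all_gather_then_gemm(x_shards: List[Matrix], w: Matrix) -> Matrix:
    """AG+GEMM reference: concatenate every rank's row shard, then multiply."""
    gathered = [row for shard in x_shards for row in shard]
    return matmul(gathered, w)


def gemm_then_reduce_scatter(x_parts: List[Matrix], w_parts: List[Matrix]) -> List[Matrix]:
    """GEMM+RS reference: each rank multiplies its slice of the inner dimension,
    the partial products are summed, and the summed rows are split across ranks."""
    partials = [matmul(x, w) for x, w in zip(x_parts, w_parts)]
    rows, cols = len(partials[0]), len(partials[0][0])
    total = [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]
    world = len(partials)
    if rows % world:
        raise ValueError("rows must divide evenly across ranks in this sketch")
    chunk = rows // world
    return [total[r * chunk:(r + 1) * chunk] for r in range(world)]


if __name__ == "__main__":
    # AG+GEMM: two ranks each hold one row of X; W is the identity.
    print(all_gather_then_gemm([[[1, 2]], [[3, 4]]], [[1, 0], [0, 1]]))  # [[1, 2], [3, 4]]
    # GEMM+RS: X (2x4) split by columns, W (4x1) split by rows, across two ranks.
    x_parts = [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
    w_parts = [[[1], [1]], [[1], [1]]]
    print(gemm_then_reduce_scatter(x_parts, w_parts))  # [[[10]], [[26]]]
```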
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Result of the matrix multiplication |
| gelu_input | torch.Tensor | Pre-GeLU activations (if gelu=True, for backward) |
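The fused options in the tables above (accumulate, bias, gelu) and the gelu_input output can be illustrated with a reference epilogue applied to an already-computed GEMM product. This is a sketch under assumptions. It follows a cuBLASLt-style order: accumulate into the existing output, add a per-column bias, then apply GeLU. It also uses the tanh approximation of GeLU, which may not match the fused kernel exactly. apply_epilogue is a hypothetical name.

```python
import math
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def gelu_tanh(x: float) -> float:
    # tanh approximation of GeLU (assumed; the fused kernel's exact variant may differ)
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def apply_epilogue(
    product: Matrix,
    out: Optional[Matrix] = None,
    bias: Optional[List[float]] = None,
    gelu: bool = False,
    accumulate: bool = False,
) -> Tuple[Matrix, Optional[Matrix]]:
    """Hypothetical reference for a fused GEMM epilogue.

    Assumed order: accumulate into `out`, add a per-column bias, then GeLU.
    Returns (output, gelu_input); gelu_input holds pre-activation values for backward.
    """
    if accumulate and out is None:
        raise ValueError("accumulate=True needs an existing output to add into")
    rows, cols = len(product), len(product[0])
    pre = [[product[i][j]
            + (out[i][j] if accumulate else 0.0)
            + (bias[j] if bias is not None else 0.0)
            for j in range(cols)] for i in range(rows)]
    if not gelu:
        return pre, None
    activated = [[gelu_tanh(v) for v in row] for row in pre]
    return activated, pre


if __name__ == "__main__":
    result, saved = apply_epilogue([[1.0, -1.0]], out=[[1.0, 1.0]], bias=[0.5, 0.5], accumulate=True)
    print(result, saved)  # [[2.5, 0.5]] None
```

Returning the pre-activation alongside the output mirrors why gelu_input exists: the GeLU derivative in the backward pass needs the value before activation.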
Usage Examples
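```python
from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm, get_cublas_workspace

# weight_fp8 and input_fp8 are assumed to be FP8-quantized tensors created elsewhere.
workspace = get_cublas_workspace(device=0, ub=False, grouped_gemm=False)
# general_gemm returns a tuple of outputs; the GEMM result is the first element.
output, *_ = general_gemm(
    A=weight_fp8,
    B=input_fp8,
    workspace=workspace,
    layout="TN",
)
```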