Implementation:NVIDIA TransformerEngine Cpp GEMM

From Leeroopedia


| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |

Overview

Python interface to cuBLAS-backed GEMM operations with FP8 quantization, communication overlap, and grouped GEMM support for Mixture-of-Experts workloads.

Description

general_gemm is the primary GEMM function. It supports FP8 inputs, output quantization, fused GeLU activation, bias addition, accumulation into the output, and communication-computation overlap (all-gather + GEMM and GEMM + reduce-scatter, abbreviated AG+GEMM and GEMM+RS) through user buffer (UB) communicators. It validates the layout (TN, NN, or NT), allocates a cuBLAS workspace that is cached per device (32 MiB on Hopper, 4 MiB otherwise), and dispatches either to a custom GEMM implementation for custom tensor types or to the native tex.te_general_gemm.

general_grouped_gemm handles batched/grouped GEMMs for MoE workloads and allocates cuBLAS workspaces for multiple streams.

Helper functions manage workspace sizing, scale validation, and device detection for quantized tensor types.
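
The workspace sizing and per-device caching described above can be sketched in plain Python. This is a minimal illustration, not the library's code: the names workspace_size_bytes and cached_workspace are hypothetical, a bytearray stands in for a CUDA byte tensor, detecting Hopper as "compute capability major >= 9" is an assumption, and so is the choice of cache key.

```python
from typing import Dict, Tuple

MIB = 1024 * 1024


def workspace_size_bytes(cc_major: int) -> int:
    """Return 32 MiB for Hopper-class devices and 4 MiB otherwise.

    Assumption: Hopper is detected as compute capability major >= 9.
    """
    return 32 * MIB if cc_major >= 9 else 4 * MIB


# Hypothetical cache keyed by (device index, ub flag, grouped_gemm flag).
_workspace_cache: Dict[Tuple[int, bool, bool], bytearray] = {}


def cached_workspace(
    device: int, cc_major: int, ub: bool = False, grouped_gemm: bool = False
) -> bytearray:
    """Allocate the workspace once per key and reuse it on later calls."""
    key = (device, ub, grouped_gemm)
    if key not in _workspace_cache:
        # A bytearray stands in for a uint8 tensor allocated on the device.
        _workspace_cache[key] = bytearray(workspace_size_bytes(cc_major))
    return _workspace_cache[key]
```

Caching avoids re-allocating a multi-megabyte buffer on every GEMM call; the trade-off is that the buffer stays resident for the life of the process.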

Usage

All linear layer computations in TransformerEngine (projections, MLPs, MoE experts) flow through this GEMM interface. It is the foundation for FP8 training performance.

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/cpp_extensions/gemm.py
Lines: 1-308

Signature

# Annotated "-> None" as listed; the name and description suggest it yields a size in bytes.
def get_cublas_workspace_size_bytes() -> None: ...
def get_cublas_workspace(device: int, ub: bool, grouped_gemm: bool) -> torch.Tensor: ...
def validate_gemm_scale(scale: Optional[float], required: bool) -> float: ...
def get_tensor_device(tensor: torch.Tensor) -> int: ...

def general_gemm(
    A, B, workspace, layout, out=None, bias=None, gelu=False,
    gelu_input=None, grad=False, accumulate=False, ...
): ...

def general_grouped_gemm(
    A_list, B_list, workspace_list, layout, out_list=None, ...
): ...
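
The signature of validate_gemm_scale (an optional scale plus a required flag, returning a float) suggests a small normalisation step. The sketch below shows one plausible reading of that contract. The function name validate_scale_sketch, the default of 1.0, and the rule that an unused scale must be zero are all assumptions for illustration, not confirmed behaviour of the library.

```python
from typing import Optional


def validate_scale_sketch(scale: Optional[float], required: bool) -> float:
    """Normalise an optional GEMM scaling factor to a concrete float.

    Assumed rules (not taken from the library source):
      * required: a missing scale defaults to 1.0
      * not required: only None or 0.0 is accepted, and 0.0 is returned
    """
    if required:
        return 1.0 if scale is None else float(scale)
    if scale not in (None, 0.0):
        raise ValueError("scale must be None or zero when it is not used")
    return 0.0
```

Returning a plain float in every case lets the caller pass the value straight to the backend without another None check.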

Import

from transformer_engine.pytorch.cpp_extensions.gemm import (
    general_gemm,
    general_grouped_gemm,
    get_cublas_workspace_size_bytes,
)

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| A | torch.Tensor | Yes | Left operand matrix (may be quantized) |
| B | torch.Tensor | Yes | Right operand matrix (may be quantized) |
| workspace | torch.Tensor | Yes | cuBLAS workspace buffer |
| layout | str | Yes | Matrix layout: "TN", "NN", or "NT" |
| out | torch.Tensor | No | Pre-allocated output tensor |
| bias | torch.Tensor | No | Optional bias to fuse with GEMM |
| gelu | bool | No | Whether to fuse GeLU activation |
| accumulate | bool | No | Whether to accumulate into output |
| ub_algo | CommOverlapType | No | Communication overlap algorithm (AG, RS) |
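
The layout argument accepts only three strings, so validation can be shown in a few lines. This is a sketch under assumptions: parse_layout is a hypothetical helper, and mapping the first letter to operand A and the second to operand B ("T" for transposed, "N" for not transposed) follows the usual BLAS convention rather than anything stated on this page.

```python
from typing import Tuple

SUPPORTED_LAYOUTS = ("TN", "NN", "NT")


def parse_layout(layout: str) -> Tuple[bool, bool]:
    """Validate a layout string and return (transpose_a, transpose_b).

    Assumption: the first character refers to A and the second to B.
    """
    if layout not in SUPPORTED_LAYOUTS:
        raise ValueError(
            f"GEMM layout {layout!r} is not supported; expected one of {SUPPORTED_LAYOUTS}"
        )
    return layout[0] == "T", layout[1] == "T"
```

Rejecting unknown layouts before any allocation or kernel launch keeps the failure cheap and the error message specific.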

Outputs

| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Result of the matrix multiplication |
| gelu_input | torch.Tensor | Pre-GeLU activations (if gelu=True, for backward) |
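
To make the two outputs concrete, here is a pure-Python reference of a GEMM with optional fused bias and GeLU that also returns the pre-GeLU values. It illustrates the output contract only: reference_gemm is a hypothetical name, operands are nested lists rather than tensors, no layout handling or quantization is modelled, and the exact erf-based GeLU is an assumption (the page does not say which GeLU variant is fused).

```python
import math
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def gelu(x: float) -> float:
    """Exact GeLU using the Gaussian error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def reference_gemm(
    a: Matrix,
    b: Matrix,
    bias: Optional[List[float]] = None,
    apply_gelu: bool = False,
) -> Tuple[Matrix, Optional[Matrix]]:
    """Compute a @ b (+ bias), optionally followed by GeLU.

    Returns (output, pre_gelu); pre_gelu is None unless apply_gelu is True.
    """
    inner, cols = len(b), len(b[0])
    pre = [
        [
            sum(row[k] * b[k][j] for k in range(inner))
            + (bias[j] if bias is not None else 0.0)
            for j in range(cols)
        ]
        for row in a
    ]
    if not apply_gelu:
        return pre, None
    # Keep the pre-activation values: the backward pass of GeLU needs them.
    return [[gelu(v) for v in row] for row in pre], pre
```

Returning the pre-GeLU values from the forward call saves the backward pass from recomputing the matrix product.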

Usage Examples

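from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm, get_cublas_workspace

# weight_fp8 and input_fp8 are assumed to be FP8-quantized tensors prepared elsewhere.
# Fetch the cached cuBLAS workspace for device 0 (no UB overlap, no grouped GEMM).
workspace = get_cublas_workspace(device=0, ub=False, grouped_gemm=False)

# Per the I/O contract, the result may carry gelu_input alongside the output.
output = general_gemm(
    A=weight_fp8,
    B=input_fp8,
    workspace=workspace,
    layout="TN",
)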
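
The example above covers a single GEMM. For the grouped variant used with MoE experts, the semantics can be sketched as one independent product per expert. This is a reference for the math only, with hypothetical names (matmul, grouped_gemm_reference) and nested lists in place of tensors; it does not model workspaces, streams, layouts, or quantization, and it does not call general_grouped_gemm.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain row-by-column matrix product."""
    inner, cols = len(b), len(b[0])
    return [
        [sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for row in a
    ]


def grouped_gemm_reference(a_list: List[Matrix], b_list: List[Matrix]) -> List[Matrix]:
    """Multiply the i-th matrix of a_list with the i-th matrix of b_list.

    Each pair may have a different number of rows, which is the typical
    MoE situation where experts receive different numbers of tokens.
    """
    if len(a_list) != len(b_list):
        raise ValueError("a_list and b_list must have the same length")
    return [matmul(a, b) for a, b in zip(a_list, b_list)]


# Two experts: the first receives two tokens, the second receives one.
tokens = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 2.0]]]
weights = [[[1.0, 2.0], [3.0, 4.0]], [[1.0], [1.0]]]
outputs = grouped_gemm_reference(tokens, weights)
```

A grouped call exists because launching the per-expert products together is cheaper than issuing many small, separate GEMMs.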
