Implementation:NVIDIA TransformerEngine PyTorch Ext GEMM
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements the main C++ GEMM operation and its variants (atomic GEMM, grouped GEMM) with FP8 inputs, quantized outputs, bias fusion, GeLU fusion, and communication overlap.
Description
The primary gemm function converts input tensors A and B to TE TensorWrappers, validates matrix dimensions, creates or reuses the output tensor D (optionally quantized), and calls nvte_gemm with parameters for accumulation, split accumulation, alpha/beta scaling, and optional communication overlap. The helper functions getGemmOutputShape and checkGemmShape handle shape computation and validation for transposed and non-transposed matrices. te_atomic_gemm provides split-K atomic accumulation with counters for deterministic reduction. te_general_grouped_gemm handles batched GEMM with per-group m-splits and optional single-output packing. The file supports FP8 block scaling (1D/2D) and delayed/current scaling modes through the quantizer abstraction.
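The shape handling described above can be pictured with a small pure-Python sketch. It is illustrative only: the function name `gemm_output_shape` is hypothetical, it uses the textbook row-major convention `op(A) @ op(B)`, and it assumes leading dimensions are flattened into the row dimension. The layout convention used by the actual C++ helpers may differ.

```python
from math import prod


def gemm_output_shape(a_shape, b_shape, transa=False, transb=False):
    """Return the 2D output shape (m, n) of op(A) @ op(B).

    Hypothetical helper for illustration; not the TransformerEngine API.
    Each operand is viewed as a matrix whose row count is the product of
    its leading dimensions and whose column count is its last dimension.
    """
    a_rows, a_cols = prod(a_shape[:-1]), a_shape[-1]
    b_rows, b_cols = prod(b_shape[:-1]), b_shape[-1]

    # op(A) has shape (m, k); transposing swaps rows and columns.
    m, k_a = (a_cols, a_rows) if transa else (a_rows, a_cols)
    # op(B) has shape (k, n).
    k_b, n = (b_cols, b_rows) if transb else (b_rows, b_cols)

    # The inner (reduction) dimensions must agree.
    if k_a != k_b:
        raise ValueError(
            f"Inner dimensions do not match: op(A) has k={k_a}, op(B) has k={k_b}"
        )
    return (m, n)
```

For example, `gemm_output_shape((2, 3, 8), (16, 8), transb=True)` flattens A to 6 rows of 8 columns, transposes B to 8 by 16, and yields `(6, 16)`.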
Usage
This file is the performance-critical matrix-multiply backbone: every linear layer in the Transformer passes through this code. Its FP8 support is the primary mechanism for achieving speedups over higher-precision GEMMs.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/csrc/extensions/gemm.cpp
- Lines: 1-573
Signature
namespace transformer_engine::pytorch {

py::object gemm(at::Tensor A, at::Tensor B,
                at::Tensor D, at::Tensor D_bias,
                py::handle D_quantizer, at::Tensor workspace,
                bool accumulate, bool use_split_accumulator,
                int math_sm_count, ...);

py::object te_atomic_gemm(at::Tensor A, at::Tensor B,
                          at::Tensor D, at::Tensor D_bias,
                          py::handle D_quantizer, at::Tensor workspace,
                          at::Tensor counter, ...);

py::object te_general_grouped_gemm(
    std::vector<at::Tensor> A_list,
    std::vector<at::Tensor> B_list,
    std::vector<at::Tensor> D_list, ...);

std::vector<int64_t> getGemmOutputShape(
    at::Tensor A, at::Tensor B, bool transa, bool transb);

}  // namespace transformer_engine::pytorch
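To make the idea of per-group m-splits with single-output packing concrete, here is a minimal pure-Python sketch. The names `matmul` and `grouped_gemm` are hypothetical, the operands are plain lists of rows, and the sketch assumes each group multiplies a contiguous block of rows from one packed input by its own right-hand matrix. It does not reproduce the argument list of te_general_grouped_gemm.

```python
def matmul(a, b):
    """Plain-Python matrix product of two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def grouped_gemm(a_packed, b_list, m_splits):
    """Multiply consecutive row blocks of a_packed by per-group matrices.

    Hypothetical illustration: group i owns m_splits[i] consecutive rows of
    a_packed and multiplies them by b_list[i]. Results are written back to
    back into a single output (the "single-output packing" case).
    """
    if len(b_list) != len(m_splits):
        raise ValueError("Need exactly one B matrix per split")
    if sum(m_splits) != len(a_packed):
        raise ValueError("m_splits must sum to the number of rows in a_packed")

    out, start = [], 0
    for m, b in zip(m_splits, b_list):
        # Rows [start, start + m) belong to the current group.
        out.extend(matmul(a_packed[start:start + m], b))
        start += m
    return out
```

With `m_splits=[2, 1]`, the first two rows of the packed input use the first matrix in `b_list` and the last row uses the second.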
Import
#include "../extensions.h"
#include "../common.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| A | at::Tensor | Yes | Left matrix operand (may be FP8) |
| B | at::Tensor | Yes | Right matrix operand (may be FP8) |
| D | at::Tensor | No | Pre-allocated output tensor |
| D_quantizer | py::handle | No | Quantizer for output tensor |
| workspace | at::Tensor | Yes | cuBLAS workspace buffer |
| accumulate | bool | No | Whether to accumulate into D |
Outputs
| Name | Type | Description |
|---|---|---|
| D | py::object | Result of the GEMM operation (possibly quantized) |
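The scaling factor behind a quantized FP8 result can be sketched as follows. This is a simplified illustration, not the quantizer implementation: the helper names are hypothetical, 448 is the largest finite value of the FP8 E4M3 format, the delayed variant assumes the maximum over a recorded amax history is used, and the final rounding to the FP8 grid is deliberately omitted.

```python
FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3


def current_scale(values, fp8_max=FP8_E4M3_MAX):
    """Scale derived from the absolute maximum of the tensor being quantized."""
    amax = max(abs(v) for v in values)
    return fp8_max / amax if amax > 0 else 1.0


def delayed_scale(amax_history, fp8_max=FP8_E4M3_MAX):
    """Scale derived from amax values recorded in earlier iterations.

    Assumption: the history is reduced with max().
    """
    amax = max(amax_history)
    return fp8_max / amax if amax > 0 else 1.0


def scale_and_clamp(values, scale, fp8_max=FP8_E4M3_MAX):
    """Apply the scale and saturate to the FP8 range (no rounding simulated)."""
    return [min(max(v * scale, -fp8_max), fp8_max) for v in values]
```

Current scaling always fits the present tensor into range, while a delayed scale can be stale: in `scale_and_clamp([1.0, -2.0, 8.0], delayed_scale([1.0, 4.0, 2.0]))` the value 8.0 exceeds the recorded history and saturates at 448.0.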
Usage Examples
import transformer_engine_torch as tex

# Low-level GEMM call (usually called through general_gemm wrapper)
result = tex.gemm(A, B, D, D_bias, D_quantizer, workspace,
                  accumulate=False, use_split_accumulator=False,
                  math_sm_count=0)