

Implementation:NVIDIA TransformerEngine PyTorch Ext GEMM

From Leeroopedia


Field | Value
Sources | TransformerEngine
Domains | Deep_Learning, PyTorch, Quantization
Last Updated | 2026-02-07 14:00 GMT

Overview

Implements the main C++ GEMM operation and its variants (atomic GEMM, grouped GEMM) with FP8 inputs, quantized outputs, bias fusion, GeLU fusion, and communication overlap.

Description

The primary gemm function converts the input tensors A and B to Transformer Engine (TE) TensorWrapper objects, validates the matrix dimensions, creates or reuses the output tensor D (optionally quantized), and calls nvte_gemm with parameters for accumulation, split accumulation, alpha/beta scaling, and optional communication overlap. The helper functions getGemmOutputShape and checkGemmShape handle shape computation for transposed and non-transposed matrices. te_atomic_gemm provides split-K atomic accumulation with counters for deterministic reduction. te_general_grouped_gemm handles batched GEMM with per-group m-splits and optional packing of all results into a single output. The file supports FP8 block scaling (1D/2D) and the delayed and current scaling modes through the quantizer abstraction.
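The shape handling described above can be illustrated with a small pure-Python sketch. It is not the TE implementation: the names flatten_2d and gemm_output_shape are hypothetical, and it assumes the textbook row-major convention in which leading dimensions are collapsed into one before the transpose flags are applied. The operand ordering and layout convention used by getGemmOutputShape in gemm.cpp may differ.

```python
from math import prod


def flatten_2d(shape):
    """Collapse all leading dimensions into one, keeping the last dimension."""
    if len(shape) == 0:
        return (1, 1)
    return (prod(shape[:-1]), shape[-1])


def gemm_output_shape(a_shape, b_shape, transa=False, transb=False):
    """Return (m, n) for op(A) @ op(B); raise if the inner dimensions differ.

    Hypothetical helper for illustration only (row-major, textbook convention).
    """
    a_rows, a_cols = flatten_2d(a_shape)
    b_rows, b_cols = flatten_2d(b_shape)
    # Apply the transpose flags to obtain the logical (m, k) and (k, n) shapes.
    m, k_a = (a_cols, a_rows) if transa else (a_rows, a_cols)
    k_b, n = (b_cols, b_rows) if transb else (b_rows, b_cols)
    if k_a != k_b:
        raise ValueError(f"inner dimensions differ: {k_a} vs {k_b}")
    return (m, n)
```

For example, under these assumptions gemm_output_shape((2, 3, 4), (4, 5)) flattens A to 6 x 4 and returns (6, 5), while mismatched inner dimensions raise a ValueError, which mirrors the validation role the text assigns to checkGemmShape.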

Usage

This file is the performance-critical matrix-multiply backbone: every linear layer in the Transformer passes through this code. Its FP8 support is the primary mechanism for achieving speedups over higher-precision formats such as FP16 or BF16.

Code Reference

Source Location

Repository
NVIDIA/TransformerEngine
File
transformer_engine/pytorch/csrc/extensions/gemm.cpp
Lines
1-573

Signature

namespace transformer_engine::pytorch {

// Main GEMM entry point; trailing parameters are elided here.
py::object gemm(at::Tensor A, at::Tensor B,
    at::Tensor D, at::Tensor D_bias,
    py::handle D_quantizer, at::Tensor workspace,
    bool accumulate, bool use_split_accumulator,
    int math_sm_count, ...);

// Atomic GEMM variant that takes an additional counter tensor.
py::object te_atomic_gemm(at::Tensor A, at::Tensor B,
    at::Tensor D, at::Tensor D_bias,
    py::handle D_quantizer, at::Tensor workspace,
    at::Tensor counter, ...);

// Grouped GEMM over lists of operand and output tensors.
py::object te_general_grouped_gemm(
    std::vector<at::Tensor> A_list,
    std::vector<at::Tensor> B_list,
    std::vector<at::Tensor> D_list, ...);

// Computes the output shape for the given transpose flags.
std::vector<int64_t> getGemmOutputShape(
    at::Tensor A, at::Tensor B, bool transa, bool transb);

}  // namespace transformer_engine::pytorch
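To make the idea of per-group m-splits and single-output packing more concrete, here is a minimal pure-Python sketch. It assumes that the m-splits partition the rows of one stacked input and that each group has its own right-hand matrix. The actual operand roles in te_general_grouped_gemm may differ, and the names matmul, grouped_gemm, and pack_single_output are hypothetical.

```python
def matmul(a, b):
    """Plain row-major matrix product of two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def grouped_gemm(a, b_list, m_splits):
    """Multiply consecutive row blocks of `a` by the matching matrix in `b_list`.

    Illustrative only: block i has m_splits[i] rows and is paired with b_list[i].
    """
    if len(b_list) != len(m_splits):
        raise ValueError("need one right-hand matrix per split")
    if sum(m_splits) != len(a):
        raise ValueError("m_splits must cover every row of the input")
    outputs, start = [], 0
    for rows, b in zip(m_splits, b_list):
        outputs.append(matmul(a[start:start + rows], b))
        start += rows
    return outputs


def pack_single_output(outputs):
    """Concatenate the per-group results row-wise into one matrix."""
    return [row for block in outputs for row in block]
```

With three input rows, m_splits = [1, 2], and two 2 x 2 matrices in b_list, the first row is multiplied by the first matrix and the remaining two rows by the second. pack_single_output then stacks the two results into one 3 x 2 matrix.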

Import

#include "../extensions.h"
#include "../common.h"

I/O Contract

Inputs

Name | Type | Required | Description
A | at::Tensor | Yes | Left matrix operand (may be FP8)
B | at::Tensor | Yes | Right matrix operand (may be FP8)
D | at::Tensor | No | Pre-allocated output tensor
D_quantizer | py::handle | No | Quantizer for output tensor
workspace | at::Tensor | Yes | cuBLAS workspace buffer
accumulate | bool | No | Whether to accumulate into D
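The effect of accumulate and of alpha/beta scaling can be sketched with the standard BLAS update rule D = alpha * (A @ B) + beta * D. The sketch assumes the usual convention that accumulating into D corresponds to beta = 1 and overwriting D to beta = 0. How gemm.cpp maps its accumulate flag onto these factors is not stated on this page, and gemm_update is a hypothetical name.

```python
def gemm_update(a, b, d=None, alpha=1.0, beta=0.0):
    """Return alpha * (a @ b) + beta * d for matrices given as lists of rows.

    Assumed convention (not taken from the source): beta = 0 overwrites the
    output, beta = 1 accumulates into an existing output `d`.
    """
    product = [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]
    if d is None or beta == 0.0:
        # Overwrite: the previous contents of the output are ignored.
        return [[alpha * p for p in row] for row in product]
    # Accumulate: blend the new product with the existing output.
    return [
        [alpha * p + beta * old for p, old in zip(p_row, d_row)]
        for p_row, d_row in zip(product, d)
    ]
```

Accumulation of this kind is what allows, for instance, a gradient buffer to be updated in place across several GEMM calls instead of being summed in a separate pass.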

Outputs

Name | Type | Description
D | py::object | Result of the GEMM operation (possibly quantized)
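What "possibly quantized" means for the output can be illustrated with a simplified per-tensor scaling sketch. It assumes the FP8 E4M3 format, whose largest finite magnitude is 448, and derives the scale from the current tensor's absolute maximum. The final rounding onto the FP8 grid is deliberately omitted, and compute_scale, quantize, and dequantize are hypothetical names rather than part of the TE quantizer API.

```python
FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3


def compute_scale(values, fp8_max=FP8_E4M3_MAX):
    """Pick a scale that maps the largest magnitude onto the FP8 maximum."""
    amax = max((abs(v) for v in values), default=0.0)
    return fp8_max / amax if amax > 0.0 else 1.0


def quantize(values, scale, fp8_max=FP8_E4M3_MAX):
    """Scale and clamp to the FP8 range (rounding to the FP8 grid is omitted)."""
    return [max(-fp8_max, min(fp8_max, v * scale)) for v in values]


def dequantize(scaled, scale):
    """Recover approximate original values using the inverse scale."""
    scale_inv = 1.0 / scale
    return [q * scale_inv for q in scaled]
```

In this sketch the scale is computed from the tensor being quantized, which corresponds to the idea of current scaling. Under delayed scaling the scale would instead come from amax values recorded in earlier iterations, so the clamp in quantize matters whenever the current tensor exceeds that history.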

Usage Examples

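import transformer_engine_torch as tex

# Low-level GEMM call; most code reaches it through the general_gemm wrapper.
# A, B, D, D_bias, D_quantizer, and workspace are assumed to be prepared by
# the caller (see the Inputs table above). The argument list is abbreviated
# in the same way as the signature above.
result = tex.gemm(A, B, D, D_bias, D_quantizer, workspace,
                  accumulate=False, use_split_accumulator=False,
                  math_sm_count=0)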
