Implementation:NVIDIA TransformerEngine Ops BasicLinear

From Leeroopedia
Revision as of 15:59, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/NVIDIA_TransformerEngine_Ops_BasicLinear.md)

Field | Value
Sources | TransformerEngine
Domains | Deep_Learning, PyTorch, Quantization, Distributed
Last Updated | 2026-02-07 14:00 GMT

Overview

Core fusible operation implementing a bias-free linear transformation, y = xA^T, where A is the weight matrix. It serves as the fundamental GEMM building block of the ops framework.
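
The formula can be checked on the CPU with a minimal NumPy sketch. This is not Transformer Engine code: linear_no_bias is a hypothetical helper, and the weight layout (out_features, in_features) is assumed to follow the usual PyTorch convention.

```python
import numpy as np

def linear_no_bias(x, weight):
    """Compute y = x @ weight.T for x of shape (..., in_features).

    Hypothetical helper for illustration; weight is assumed to have
    shape (out_features, in_features).
    """
    out_features, in_features = weight.shape
    if x.shape[-1] != in_features:
        raise ValueError(f"expected last dim {in_features}, got {x.shape[-1]}")
    # Flatten the leading dims so the product is a single 2D GEMM,
    # then restore them on the output.
    x2d = x.reshape(-1, in_features)
    y2d = x2d @ weight.T
    return y2d.reshape(*x.shape[:-1], out_features)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8))      # (batch, seq, in_features)
weight = rng.standard_normal((16, 8))   # (out_features, in_features)
y = linear_no_bias(x, weight)
print(y.shape)
```

Flattening the leading dimensions mirrors how a linear layer over (..., in_features) inputs reduces to one matrix multiplication.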

Description

BasicLinear manages the weight parameter and supports tensor parallelism (column and row modes), sequence parallelism, FP8-quantized weights, and Megatron-LM integration. The forward pass calls general_gemm, optionally quantizing the input and weight to FP8. With sequence parallelism, column mode all-gathers the input before the GEMM and row mode reduce-scatters the output after it. The backward pass computes the dgrad (input-gradient) and wgrad (weight-gradient) GEMMs, with support for gradient accumulation into main_grad, quantized compute, and overlap of asynchronous communication with computation. The static methods _functional_forward and _functional_backward hold the shared implementation used by both this op and the fused operations.
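
The two backward GEMMs and the accumulation into main_grad can be illustrated with a small NumPy sketch. It assumes y = x @ weight.T with weight of shape (out_features, in_features); linear_backward is a hypothetical helper, not the signature of _functional_backward, and the main_grad handling only imitates the idea of accumulating weight gradients across micro-batches.

```python
import numpy as np

def linear_backward(grad_output, x, weight, main_grad=None):
    """Return (dgrad, wgrad) for y = x @ weight.T.

    Hypothetical helper for illustration. If main_grad is given, the
    weight gradient is accumulated into it in place.
    """
    out_features, in_features = weight.shape
    dy = grad_output.reshape(-1, out_features)
    x2d = x.reshape(-1, in_features)
    # dgrad: gradient w.r.t. the input, dx = dy @ W
    dgrad = (dy @ weight).reshape(x.shape)
    # wgrad: gradient w.r.t. the weight, dW = dy^T @ x
    wgrad = dy.T @ x2d
    if main_grad is not None:
        main_grad += wgrad  # in-place accumulation across micro-batches
        wgrad = main_grad
    return dgrad, wgrad

rng = np.random.default_rng(0)
weight = rng.standard_normal((4, 3))
main_grad = np.zeros_like(weight)
micro_batches = [rng.standard_normal((5, 3)) for _ in range(2)]
for x in micro_batches:
    grad_output = np.ones((5, 4))  # gradient of the loss sum(y)
    dgrad, _ = linear_backward(grad_output, x, weight, main_grad)
```

For the loss sum(y), every row of the accumulated weight gradient equals the column sums of all inputs seen so far, which makes the accumulation easy to verify by hand.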

Usage

BasicLinear is the central operation of the ops framework, since nearly every Transformer layer contains linear transformations. Use it wherever a bias-free linear layer is needed; it encapsulates the logic for distributed, mixed-precision, and FP8-accelerated matrix multiplication.
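
To give a feel for what FP8-accelerated multiplication involves, the following NumPy sketch simulates per-tensor scaling into the E4M3 range (maximum magnitude 448, three mantissa bits) before a GEMM. All names here (simulate_e4m3, quantize, quantized_linear) are hypothetical, the rounding is emulated in float64, and Transformer Engine's actual quantization recipes and kernels are not reproduced.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite magnitude in the FP8 E4M3 format

def simulate_e4m3(v):
    """Round values to the nearest E4M3-representable number (emulated)."""
    v = np.clip(v, -E4M3_MAX, E4M3_MAX)
    _, exp = np.frexp(v)  # v = mant * 2**exp with 0.5 <= |mant| < 1
    # Four significant bits per binade; below the smallest normal value
    # (2**-6) the spacing stays fixed at 2**-9 (subnormal range).
    quantum = np.exp2(np.maximum(exp, -5) - 4.0)
    return np.round(v / quantum) * quantum

def quantize(x):
    """Per-tensor scaling: map the largest magnitude onto E4M3_MAX."""
    amax = np.abs(x).max()
    scale = E4M3_MAX / amax
    return simulate_e4m3(x * scale), 1.0 / scale

def quantized_linear(x, weight):
    """Quantize both operands, multiply, then undo the scales."""
    x_q, x_scale_inv = quantize(x)
    w_q, w_scale_inv = quantize(weight)
    return (x_q @ w_q.T) * (x_scale_inv * w_scale_inv)

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 64))
weight = rng.standard_normal((16, 64))
y_ref = x @ weight.T
y_fp8 = quantized_linear(x, weight)
print("max abs error:", np.abs(y_fp8 - y_ref).max())
```

The scale factors are applied once to the GEMM output rather than to each operand, which is the usual reason per-tensor scaling adds little overhead to the multiplication itself.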

Code Reference

Source Location

Repository
NVIDIA/TransformerEngine
File
transformer_engine/pytorch/ops/basic/basic_linear.py
Lines
1-1075

Signature

class BasicLinear(BasicOperation):
    def __init__(
        self, in_features, out_features,
        device=None, dtype=None,
        tensor_parallel_mode=None,
        tensor_parallel_group=None,
        sequence_parallel=False, ...
    ): ...

    def op_forward(self, ctx, input, ...): ...
    def op_backward(self, ctx, grad_output): ...

    @staticmethod
    def _functional_forward(input, weight, ...): ...
    @staticmethod
    def _functional_backward(grad_output, input, weight, ...): ...

Import

from transformer_engine.pytorch.ops.basic import BasicLinear

I/O Contract

Inputs

Name | Type | Required | Description
input | torch.Tensor | Yes | Input tensor of shape (..., in_features), passed when the op is called
in_features | int | Yes | Input feature dimension
out_features | int | Yes | Output feature dimension
tensor_parallel_mode | str | No | "column" or "row" tensor parallelism; defaults to None (no tensor parallelism)
tensor_parallel_group | ProcessGroup | No | Distributed process group for tensor parallelism; defaults to None
sequence_parallel | bool | No | Whether to use sequence parallelism; defaults to False

Outputs

Name | Type | Description
output | torch.Tensor | Result of linear transformation, shape (..., out_features)
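
The two tensor_parallel_mode values can be illustrated on a single process. The NumPy sketch below assumes the conventional Megatron-style partitioning (column mode shards the weight along out_features, row mode shards it along in_features); the helper names are hypothetical, and concatenation and summation stand in for the collective communication that a real multi-GPU run would perform.

```python
import numpy as np

def column_parallel_forward(x, weight, world_size):
    """Shard the weight along out_features; each rank computes a slice of y."""
    shards = np.split(weight, world_size, axis=0)
    local_outputs = [x @ w_local.T for w_local in shards]
    # Concatenation is only for checking; in practice each rank keeps
    # its own output slice.
    return np.concatenate(local_outputs, axis=-1)

def row_parallel_forward(x, weight, world_size):
    """Shard the weight along in_features; partial outputs are summed."""
    w_shards = np.split(weight, world_size, axis=1)
    x_shards = np.split(x, world_size, axis=-1)
    partials = [x_l @ w_l.T for x_l, w_l in zip(x_shards, w_shards)]
    # The sum stands in for an all-reduce across ranks (or a
    # reduce-scatter when sequence parallelism is enabled).
    return sum(partials)

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 8))         # (tokens, in_features)
weight = rng.standard_normal((12, 8))   # (out_features, in_features)
reference = x @ weight.T
y_col = column_parallel_forward(x, weight, world_size=4)
y_row = row_parallel_forward(x, weight, world_size=4)
print(np.allclose(y_col, reference), np.allclose(y_row, reference))
```

Both partitionings reproduce the unsharded result; they differ in which dimension is split and therefore in whether communication is needed on the input or on the output.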

Usage Examples

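from transformer_engine.pytorch.ops.basic import BasicLinear

# Assumed to exist already (not defined in this snippet):
#   tp_group     - an initialized torch.distributed process group for TP
#   input_tensor - a torch.Tensor of shape (..., 1024) on the op's device

# Create a basic linear op for column-parallel TP
linear = BasicLinear(
    in_features=1024,
    out_features=4096,
    tensor_parallel_mode="column",
    tensor_parallel_group=tp_group,
)

# Apply the op to the input (y = xA^T, no bias)
output = linear(input_tensor)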
