Implementation:NVIDIA TransformerEngine Ops BasicLinear
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization, Distributed |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Core fusible operation implementing a linear transformation (y = xA^T) without bias, serving as the fundamental GEMM building block in the ops framework.
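To make the formula concrete, the following is a minimal sketch of y = xA^T using plain Python lists. It is only an illustration of the shapes involved, not how the op is implemented; the helper name `linear_no_bias` is hypothetical, and the weight is assumed to be stored as (out_features, in_features).

```python
# Toy illustration of y = x A^T with nested lists (no GPU, no TransformerEngine).
def linear_no_bias(x, weight):
    """Return x @ weight^T.

    x:      list of rows, shape (n, in_features)
    weight: list of rows, shape (out_features, in_features)  (assumed layout)
    """
    in_features = len(weight[0])
    if any(len(row) != in_features for row in x):
        raise ValueError("last dimension of x must equal in_features")
    # Each output element is the dot product of an input row and a weight row.
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in weight]
        for row in x
    ]


x = [[1.0, 2.0, 3.0]]        # shape (1, 3): in_features = 3
A = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 1.0]]        # shape (2, 3): out_features = 2
y = linear_no_bias(x, A)
print(y)  # [[1.0, 5.0]]
```

The output has one column per weight row, which is why the last dimension changes from in_features to out_features.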
Description
Manages weight parameters with support for tensor parallelism (column and row modes), sequence parallelism, FP8 quantized weights, and Megatron-LM integration. Forward pass calls general_gemm with optional FP8 quantization of inputs/weights, all-gather for column TP, and reduce-scatter for sequence parallelism. Backward pass computes dgrad and wgrad GEMMs with support for gradient accumulation into main_grad, quantized compute, and asynchronous communication overlap. Provides static _functional_forward and _functional_backward methods used by both this op and fused operations.
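The two tensor-parallel modes can be pictured through the shape of the weight shard each rank holds. The sketch below assumes the usual Megatron-style convention (column mode splits the output dimension, row mode splits the input dimension) and a weight stored as (out_features, in_features). `local_weight_shape` and `tp_size` are hypothetical names used for illustration and are not part of the TransformerEngine API.

```python
def local_weight_shape(in_features, out_features,
                       tensor_parallel_mode=None, tp_size=1):
    """Hypothetical helper: shape of the weight shard held by one TP rank.

    Assumes column mode shards out_features and row mode shards in_features.
    """
    if tensor_parallel_mode not in (None, "column", "row"):
        raise ValueError(f"unsupported mode: {tensor_parallel_mode!r}")
    if tensor_parallel_mode is None or tp_size == 1:
        # No tensor parallelism: every rank holds the full weight.
        return (out_features, in_features)
    if tensor_parallel_mode == "column":
        if out_features % tp_size != 0:
            raise ValueError("out_features must be divisible by tp_size")
        return (out_features // tp_size, in_features)
    # Row mode: each rank sees a slice of the input features.
    if in_features % tp_size != 0:
        raise ValueError("in_features must be divisible by tp_size")
    return (out_features, in_features // tp_size)


print(local_weight_shape(1024, 4096, "column", tp_size=4))  # (1024, 1024)
print(local_weight_shape(1024, 4096, "row", tp_size=4))     # (4096, 256)
```

Under this convention, a column-parallel rank produces a slice of the output features, while a row-parallel rank produces a partial sum over the input features that must be reduced across ranks.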
Usage
BasicLinear is the central operation in the ops framework, since nearly every Transformer layer relies on linear transformations. It encapsulates the distributed (tensor- and sequence-parallel), mixed-precision, and FP8 details of the underlying matrix multiplication.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/ops/basic/basic_linear.py
- Lines: 1-1075
Signature
class BasicLinear(BasicOperation):
    def __init__(
        self, in_features, out_features,
        device=None, dtype=None,
        tensor_parallel_mode=None,
        tensor_parallel_group=None,
        sequence_parallel=False, ...
    ): ...
    def op_forward(self, ctx, input, ...): ...
    def op_backward(self, ctx, grad_output): ...
    @staticmethod
    def _functional_forward(input, weight, ...): ...
    @staticmethod
    def _functional_backward(grad_output, input, weight, ...): ...
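The dgrad and wgrad GEMMs mentioned in the Description follow directly from y = xA^T. As a sketch, assume the input is flattened to a 2D matrix X with N rows and that the loss is a scalar L; the symbols N, X, Y, and L are notation introduced here, not names from the source.

```latex
\begin{aligned}
Y &= X A^{\top},
  & X \in \mathbb{R}^{N \times \mathrm{in}},\;
    A \in \mathbb{R}^{\mathrm{out} \times \mathrm{in}} \\
\frac{\partial L}{\partial X} &= \frac{\partial L}{\partial Y}\, A
  & \text{(dgrad, shape } N \times \mathrm{in}\text{)} \\
\frac{\partial L}{\partial A} &= \left(\frac{\partial L}{\partial Y}\right)^{\top} X
  & \text{(wgrad, shape } \mathrm{out} \times \mathrm{in}\text{)}
\end{aligned}
```

With gradient accumulation, the wgrad result would be added to an existing gradient buffer rather than overwriting it.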
Import
from transformer_engine.pytorch.ops.basic import BasicLinear
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input | torch.Tensor | Yes | Input tensor of shape (..., in_features) |
| in_features | int | Yes | Input feature dimension |
| out_features | int | Yes | Output feature dimension |
| tensor_parallel_mode | str | No | "column" or "row" tensor parallelism |
| tensor_parallel_group | ProcessGroup | No | Distributed process group for TP |
| sequence_parallel | bool | No | Whether to use sequence parallelism |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Result of linear transformation, shape (..., out_features) |
Usage Examples
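from transformer_engine.pytorch.ops.basic import BasicLinear

# tp_group is assumed to be an existing torch.distributed process group
# for tensor parallelism, created elsewhere.
# Create a basic linear op for column-parallel TP
linear = BasicLinear(
    in_features=1024,
    out_features=4096,
    tensor_parallel_mode="column",
    tensor_parallel_group=tp_group,
)

# input_tensor is assumed to be a tensor of shape (..., 1024).
# Use in an operation pipeline
output = linear(input_tensor)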