Implementation:NVIDIA TransformerEngine Ops Linear
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Distributed |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Drop-in replacement for torch.nn.Linear, implemented as a FusedOperation, with support for tensor parallelism, sequence parallelism, and Megatron-LM weight-gradient (wgrad) fusion.
Description
Linear is a FusedOperation that composes BasicLinear, Bias, AllReduce/ReduceScatter, and AllGather operations as needed, based on the tensor-parallel configuration. For row tensor parallelism, it arranges GEMM + bias + reduction, where the reduction is an all-reduce, or a reduce-scatter when sequence parallelism is enabled. For column or no tensor parallelism, it arranges (gather +) GEMM + bias, where the gather applies to sequence-parallel input in column mode. The class registers the weight and bias parameters directly, for compatibility with standard PyTorch APIs, while delegating computation to the basic operations. Custom state_dict and _load_from_state_dict methods keep checkpoints backward-compatible.
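The mode-dependent composition can be illustrated with a small pure-Python sketch. `plan_linear_ops` is a hypothetical helper, not part of TransformerEngine: it only returns the names of the basic ops in the order the description lists them, and it assumes the conventional Megatron-LM mapping of sequence parallelism to collectives (reduce-scatter instead of all-reduce in row mode, an all-gather of the input in column mode).

```python
from typing import List, Optional


def plan_linear_ops(
    tensor_parallel_mode: Optional[str],
    sequence_parallel: bool = False,
    bias: bool = True,
) -> List[str]:
    """Hypothetical helper: list the basic ops a Linear would chain, in order.

    Mirrors the arrangement described above; it is not TransformerEngine code.
    """
    if tensor_parallel_mode not in (None, "column", "row"):
        raise ValueError(f"unsupported tensor_parallel_mode: {tensor_parallel_mode!r}")
    ops: List[str] = []
    if tensor_parallel_mode == "row":
        # Row TP: GEMM on the local input shard, optional bias, then a reduction.
        ops.append("BasicLinear")
        if bias:
            ops.append("Bias")
        # Assumption: sequence parallelism swaps the all-reduce for a reduce-scatter.
        ops.append("ReduceScatter" if sequence_parallel else "AllReduce")
    else:
        # Column TP or no TP: optional gather of sequence-sharded input, GEMM, bias.
        if tensor_parallel_mode == "column" and sequence_parallel:
            ops.append("AllGather")
        ops.append("BasicLinear")
        if bias:
            ops.append("Bias")
    return ops


print(plan_linear_ops("row"))                             # ['BasicLinear', 'Bias', 'AllReduce']
print(plan_linear_ops("column", sequence_parallel=True))  # ['AllGather', 'BasicLinear', 'Bias']
```

Returning op names rather than building modules keeps the sketch runnable without a GPU or an initialized process group.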
Usage
Use as a drop-in replacement for torch.nn.Linear with optional tensor parallelism and FP8 quantization support.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/ops/linear.py
- Lines: 1-211
Signature
```python
class Linear(FusedOperation):
    def __init__(self, in_features, out_features, *, bias=True, device=None, dtype=None, tensor_parallel_mode=None, tensor_parallel_group=None, sequence_parallel=False, rng_state_tracker_function=None, accumulate_into_main_grad=False) -> None: ...
    def register_parameter(self, name, param) -> None: ...
    def state_dict(self, *, prefix="", **kwargs) -> dict: ...
```
Import
```python
from transformer_engine.pytorch.ops.linear import Linear
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| in_features | int | Yes | Inner dimension of input tensor |
| out_features | int | Yes | Inner dimension of output tensor |
| bias | bool | No | Apply additive bias (default True) |
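| device | torch.device or str | No | Device on which to allocate the parameters |
| dtype | torch.dtype | No | Parameter datatype |
| tensor_parallel_mode | str or None | No | None, "column", or "row" (default None) |
| tensor_parallel_group | torch.distributed.ProcessGroup | No | Process group used for tensor parallelism |
| sequence_parallel | bool | No | Shard the input or output along the sequence dimension together with tensor parallelism (default False) |
| rng_state_tracker_function | callable | No | Returns the CUDA RNG states tracker used for model-parallel weight initialization |
| accumulate_into_main_grad | bool | No | Accumulate weight gradients into the weight's main_grad attribute (Megatron-LM wgrad fusion; default False) |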
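
With tensor parallelism, each rank holds only a shard of the weight. The sketch below assumes the standard Megatron-LM convention (column mode splits `out_features`, row mode splits `in_features`, both evenly across the group) and uses a hypothetical helper, `local_weight_shape`, that is not a TransformerEngine function.

```python
from typing import Optional, Tuple


def local_weight_shape(
    in_features: int,
    out_features: int,
    tensor_parallel_mode: Optional[str],
    world_size: int,
) -> Tuple[int, int]:
    """Hypothetical helper: per-rank weight shape (out, in), as in torch.nn.Linear."""
    if tensor_parallel_mode is None:
        return (out_features, in_features)
    if tensor_parallel_mode == "column":
        split, name = out_features, "out_features"
    elif tensor_parallel_mode == "row":
        split, name = in_features, "in_features"
    else:
        raise ValueError(f"unsupported tensor_parallel_mode: {tensor_parallel_mode!r}")
    # Assumption: the split dimension must divide evenly across ranks.
    if split % world_size != 0:
        raise ValueError(f"{name}={split} is not divisible by world_size={world_size}")
    if tensor_parallel_mode == "column":
        return (out_features // world_size, in_features)
    return (out_features, in_features // world_size)


print(local_weight_shape(4096, 4096, "column", 8))  # (512, 4096)
print(local_weight_shape(4096, 1024, "row", 4))     # (1024, 1024)
```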
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | y = x A^T + b, where A is the weight and b is the bias |
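
Why row mode needs a reduction while column mode does not follows from y = x A^T + b. The pure-Python sketch below (toy numbers, no TransformerEngine calls) checks that splitting A by output rows gives output slices that concatenate to y, while splitting x and A by input columns gives partial products that must be summed, which is the role the all-reduce plays, with the bias contributing only once.

```python
from typing import List

Matrix = List[List[float]]


def matmul_t(x: Matrix, a: Matrix) -> Matrix:
    """Return x @ a^T, where x is (n, k) and a is (m, k)."""
    return [[sum(xi * ai for xi, ai in zip(row, arow)) for arow in a] for row in x]


def add_bias(y: Matrix, b: List[float]) -> Matrix:
    """Add bias vector b to every row of y."""
    return [[v + bj for v, bj in zip(row, b)] for row in y]


x = [[1.0, 2.0, 3.0, 4.0]]   # one token, in_features = 4
A = [[1.0, 0.0, 1.0, 0.0],   # weight, out_features = 2
     [0.0, 1.0, 0.0, 1.0]]
b = [0.5, -0.5]

full = add_bias(matmul_t(x, A), b)  # y = x A^T + b

# Column split: each "rank" owns one output row of A; results are concatenated.
col = [add_bias(matmul_t(x, [A[i]]), [b[i]])[0][0] for i in range(2)]

# Row split: each "rank" owns half of the input features; partial sums are
# added together (the all-reduce), and the bias is added once to the result.
x0, x1 = [r[:2] for r in x], [r[2:] for r in x]
A0, A1 = [r[:2] for r in A], [r[2:] for r in A]
p0, p1 = matmul_t(x0, A0), matmul_t(x1, A1)
row = add_bias([[u + v for u, v in zip(p0[0], p1[0])]], b)

print(full, [col], row)  # [[4.5, 5.5]] [[4.5, 5.5]] [[4.5, 5.5]]
```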
Usage Examples
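```python
import torch
from transformer_engine.pytorch.ops.linear import Linear
tp_group = torch.distributed.new_group()  # assumes torch.distributed is already initialized (e.g. via torchrun)
linear = Linear(4096, 4096, bias=True, device="cuda", tensor_parallel_mode="column", tensor_parallel_group=tp_group)
input_tensor = torch.randn(8, 4096, device="cuda")  # full input on every rank
output = linear(input_tensor)  # column mode: each rank holds a (8, 4096 // tp_size) output shard
```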