

Implementation:NVIDIA TransformerEngine Ops Linear

From Leeroopedia


Field          Value
Sources        TransformerEngine
Domains        Deep_Learning, PyTorch, Distributed
Last Updated   2026-02-07 14:00 GMT

Overview

Drop-in replacement for torch.nn.Linear, implemented as a FusedOperation, with support for tensor parallelism, sequence parallelism, and Megatron-LM-style weight-gradient (wgrad) fusion.
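Megatron-LM keeps an FP32 main_grad buffer on each parameter and accumulates weight gradients into it across microbatches. The accumulate_into_main_grad option is intended to have Linear add its weight gradient into the weight's externally allocated main_grad instead of relying on autograd's .grad. A minimal NumPy sketch of that arithmetic only, assuming a weight of shape (out_features, in_features) as in torch.nn.Linear; linear_wgrad and accumulate_into_main_grad here are hypothetical helpers, and no GEMM fusion is modeled:

```python
import numpy as np

# Hypothetical helpers illustrating Megatron-LM-style wgrad accumulation;
# this is not TransformerEngine code. Weight shape is (out_features, in_features).

def linear_wgrad(x: np.ndarray, dy: np.ndarray) -> np.ndarray:
    """Weight gradient of y = x @ W.T, i.e. dW = dy.T @ x with shape (out, in)."""
    return dy.T @ x

def accumulate_into_main_grad(main_grad: np.ndarray, x: np.ndarray, dy: np.ndarray) -> None:
    # Add this microbatch's wgrad in place into the FP32 buffer,
    # rather than creating a fresh .grad tensor for autograd to sum.
    main_grad += linear_wgrad(x, dy).astype(np.float32)

rng = np.random.default_rng(0)
in_features, out_features, tokens = 4, 3, 5
main_grad = np.zeros((out_features, in_features), dtype=np.float32)

microbatches = [
    (rng.standard_normal((tokens, in_features)), rng.standard_normal((tokens, out_features)))
    for _ in range(3)
]
for x, dy in microbatches:
    accumulate_into_main_grad(main_grad, x, dy)

# Matches summing the per-microbatch weight gradients separately.
expected = sum(linear_wgrad(x, dy) for x, dy in microbatches)
print(np.allclose(main_grad, expected, atol=1e-5))  # True
```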

Description

Linear is a FusedOperation that composes the basic operations BasicLinear, Bias, AllReduce/ReduceScatter, and AllGather as needed for the tensor-parallel configuration. For row tensor parallelism it arranges GEMM + bias + reduction, where the reduction is an all-reduce, or a reduce-scatter when sequence parallelism is enabled. For column or no tensor parallelism it arranges (gather +) GEMM + bias, where the input all-gather is only needed for column parallelism combined with sequence parallelism. The class registers the weight and bias parameters directly on itself, so standard PyTorch APIs such as parameters() and named_parameters() see them, while the computation is delegated to the basic operations. Custom state_dict and _load_from_state_dict methods keep checkpoints backward compatible.
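The arrangement described above can be modeled as a function that returns the chain of basic-op names for a given configuration. This is a hypothetical sketch (plan_linear_ops is not a TransformerEngine function) that follows this page's description; in particular, whether the real class emits a separate AllGather op or performs the gather inside BasicLinear is not shown here, and tying ReduceScatter to sequence_parallel is an assumption based on standard Megatron-LM practice:

```python
from typing import List, Optional

def plan_linear_ops(tensor_parallel_mode: Optional[str],
                    sequence_parallel: bool = False,
                    bias: bool = True) -> List[str]:
    """Hypothetical helper: names of the basic ops chained for each TP mode."""
    if tensor_parallel_mode not in (None, "column", "row"):
        raise ValueError(f"unsupported tensor_parallel_mode: {tensor_parallel_mode!r}")
    ops: List[str] = []
    if tensor_parallel_mode == "row":
        # Row TP: GEMM on the local input shard, bias, then combine partial sums.
        ops.append("BasicLinear")
        if bias:
            ops.append("Bias")
        ops.append("ReduceScatter" if sequence_parallel else "AllReduce")
    else:
        # Column TP or no TP: gather the sequence-sharded input first if needed.
        if tensor_parallel_mode == "column" and sequence_parallel:
            ops.append("AllGather")
        ops.append("BasicLinear")
        if bias:
            ops.append("Bias")
    return ops

print(plan_linear_ops("row", sequence_parallel=True))
# ['BasicLinear', 'Bias', 'ReduceScatter']
print(plan_linear_ops("column", sequence_parallel=True, bias=False))
# ['AllGather', 'BasicLinear']
```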

Usage

Use as a drop-in replacement for torch.nn.Linear with optional tensor parallelism and FP8 quantization support.

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/ops/linear.py
Lines: 1-211

Signature

class Linear(FusedOperation):
    def __init__(self, in_features, out_features, *, bias=True, device=None, dtype=None, tensor_parallel_mode=None, tensor_parallel_group=None, sequence_parallel=False, rng_state_tracker_function=None, accumulate_into_main_grad=False) -> None: ...
    def register_parameter(self, name, param) -> None: ...
    def state_dict(self, *, prefix="", **kwargs) -> dict: ...
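In PyTorch, a module's _load_from_state_dict receives the full state dict together with the module's key prefix, so one common way to keep old checkpoints loadable is to rename keys before the default loading logic runs. A plain-Python sketch of that remapping technique; the legacy key layout and upgrade_state_dict are hypothetical and are not taken from TransformerEngine, whose actual key names may differ:

```python
from typing import Dict

# Hypothetical key names, for illustration only: suppose older checkpoints
# stored parameters under nested basic ops ("basic_ops.0.weight"), while the
# wrapper now exposes flat names ("weight").
LEGACY_TO_FLAT = {"basic_ops.0.weight": "weight", "basic_ops.1.bias": "bias"}

def upgrade_state_dict(state_dict: Dict[str, object], prefix: str = "") -> Dict[str, object]:
    """Return a copy with legacy keys under `prefix` renamed to the flat names."""
    upgraded: Dict[str, object] = {}
    for key, value in state_dict.items():
        local = key[len(prefix):] if key.startswith(prefix) else None
        if local in LEGACY_TO_FLAT:
            key = prefix + LEGACY_TO_FLAT[local]
        upgraded[key] = value
    return upgraded

old = {"encoder.fc.basic_ops.0.weight": "W", "encoder.fc.basic_ops.1.bias": "b", "other": 1}
print(upgrade_state_dict(old, prefix="encoder.fc."))
# {'encoder.fc.weight': 'W', 'encoder.fc.bias': 'b', 'other': 1}
```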

Import

from transformer_engine.pytorch.ops.linear import Linear

I/O Contract

Inputs

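Name                       Type                 Required  Description
in_features                int                  Yes       Size of the input tensor's inner (last) dimension
out_features               int                  Yes       Size of the output tensor's inner (last) dimension
bias                       bool                 No        Apply an additive bias (default True)
tensor_parallel_mode       str or None          No        None, "column", or "row" (default None)
tensor_parallel_group      ProcessGroup or None No        Process group used for tensor parallelism
sequence_parallel          bool                 No        Distribute activations along the sequence dimension (default False)
accumulate_into_main_grad  bool                 No        Accumulate weight gradients into the weight's main_grad for Megatron-LM wgrad fusion (default False)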
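A small bookkeeping sketch of how tensor_parallel_mode shards the weight across ranks, assuming in_features and out_features are the full (unsharded) sizes and that each rank holds an equal slice. local_weight_shape is a hypothetical helper, not part of TransformerEngine:

```python
from typing import Optional, Tuple

def _split(size: int, tp_size: int) -> int:
    if size % tp_size:
        raise ValueError(f"{size} is not divisible by tensor-parallel size {tp_size}")
    return size // tp_size

def local_weight_shape(in_features: int, out_features: int,
                       tensor_parallel_mode: Optional[str], tp_size: int) -> Tuple[int, int]:
    """Shape (out, in) of one rank's weight shard, assuming global feature sizes."""
    if tensor_parallel_mode == "column":
        # Column TP splits output features: each rank produces a slice of y.
        return _split(out_features, tp_size), in_features
    if tensor_parallel_mode == "row":
        # Row TP splits input features: each rank consumes a slice of x.
        return out_features, _split(in_features, tp_size)
    # No TP: every rank holds the full weight.
    return out_features, in_features

print(local_weight_shape(4096, 11008, "column", 4))  # (2752, 4096)
print(local_weight_shape(11008, 4096, "row", 4))     # (4096, 2752)
```

Pairing a column-parallel layer with a row-parallel one, as in a transformer MLP, lets the sharded output of the first feed the sharded input of the second without communication in between.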

Outputs

Name    Type          Description
output  torch.Tensor  y = x A^T + b, where A is the weight (out_features × in_features) and b is the bias
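Column and row tensor parallelism both reproduce y = x A^T + b once the per-rank results are combined. A single-process NumPy simulation of that math, with the collectives replaced by concatenate and sum; it does not model TransformerEngine's actual op ordering, only the requirement that the bias contributes exactly once to the reduced result (here it is added after the sum for simplicity):

```python
import numpy as np

rng = np.random.default_rng(0)
tp_size, tokens, in_features, out_features = 2, 6, 8, 4
x = rng.standard_normal((tokens, in_features))
A = rng.standard_normal((out_features, in_features))  # weight, as in torch.nn.Linear
b = rng.standard_normal(out_features)
y_ref = x @ A.T + b  # y = x A^T + b

# Column TP: split A and b along output features; each rank computes a slice of y,
# and concatenating the slices plays the role of gathering the output.
A_cols = np.split(A, tp_size, axis=0)
b_cols = np.split(b, tp_size)
y_col = np.concatenate([x @ A_i.T + b_i for A_i, b_i in zip(A_cols, b_cols)], axis=-1)

# Row TP: split A and x along input features; summing the partial GEMMs plays
# the role of the all-reduce, and the bias is added once.
A_rows = np.split(A, tp_size, axis=1)
x_rows = np.split(x, tp_size, axis=-1)
partials = [x_i @ A_i.T for x_i, A_i in zip(x_rows, A_rows)]
y_row = sum(partials) + b

# Row TP with sequence parallelism: a reduce-scatter leaves each rank one block of tokens.
y_row_sp = [blk + b for blk in np.split(sum(partials), tp_size, axis=0)]

print(np.allclose(y_col, y_ref), np.allclose(y_row, y_ref))  # True True
```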

Usage Examples

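import torch
import torch.distributed as dist
from transformer_engine.pytorch.ops.linear import Linear
dist.init_process_group(backend="nccl")  # assumes a torchrun launch, one GPU per rank
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
tp_group = dist.new_group()  # tensor-parallel group spanning all ranks
linear = Linear(4096, 4096, bias=True, tensor_parallel_mode="column", tensor_parallel_group=tp_group)
output = linear(torch.randn(8, 4096, device="cuda"))  # column TP: output sharded along last dim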
