
Implementation:NVIDIA TransformerEngine TE Linear

From Leeroopedia


Field Value
Sources TransformerEngine, FP8 Formats for Deep Learning
Domains Deep_Learning, Optimization
Last Updated 2026-02-07 14:00 GMT

Overview

te.Linear is the FP8-capable linear layer provided by NVIDIA's TransformerEngine library: a drop-in replacement for torch.nn.Linear that adds FP8/FP4 quantization, tensor parallelism, and sequence parallelism support.

Description

te.Linear applies the affine transformation y = xA^T + b to incoming data, identical to torch.nn.Linear. On NVIDIA GPUs, it replaces the standard cuBLAS GEMM with an FP8-aware GEMM that leverages Tensor Cores on Hopper and later architectures. The class inherits from TransformerEngineBaseModule, which provides FP8 recipe management, scaling factor tracking, and checkpoint compatibility for FP8 metadata.

Key capabilities beyond standard torch.nn.Linear:

  • FP8 quantization: When used inside a te.fp8_autocast() context, activations and weights are automatically quantized to FP8 (E4M3 forward, E5M2 backward) with managed per-tensor scaling factors.
  • Tensor parallelism: The parallel_mode parameter supports "column" (splits out_features across TP ranks) and "row" (splits in_features across TP ranks) modes, with built-in collective communication.
  • Sequence parallelism: When sequence_parallel=True and tensor parallelism is active, the sequence dimension is distributed across TP ranks to reduce activation memory.
  • Fused weight gradient accumulation: The fuse_wgrad_accumulation option fuses gradient computation and accumulation into a single operation when main_grad buffers are available.
  • Communication-computation overlap: Multiple ub_overlap_* options enable overlapping NCCL collectives with GEMM computation for latency hiding.
  • Parameter splitting: The parameters_split option allows splitting the weight and bias along dim 0 into multiple named PyTorch parameters, useful for QKV projections.
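The parameters_split semantics described above can be sketched in plain Python. This is an illustrative helper, not part of TransformerEngine: it computes the per-parameter shapes that result from splitting the fused weight along dim 0, for both the tuple form (equal chunks) and the dict form (explicit chunk sizes).

```python
def split_param_shapes(out_features, in_features, parameters_split):
    """Compute per-parameter weight shapes for a dim-0 split of the fused
    weight, mirroring the documented parameters_split semantics: a tuple
    of names yields equal chunks; a dict maps each name to its size.
    (Illustrative sketch, not TransformerEngine code.)"""
    if isinstance(parameters_split, dict):
        sizes = dict(parameters_split)
    else:  # tuple of names -> equal split along dim 0
        chunk, rem = divmod(out_features, len(parameters_split))
        assert rem == 0, "out_features must divide evenly among the names"
        sizes = {name: chunk for name in parameters_split}
    assert sum(sizes.values()) == out_features, "chunks must cover dim 0"
    return {name: (size, in_features) for name, size in sizes.items()}

# A fused QKV projection for hidden size 768 - one layer with
# out_features = 3 * 768 - exposed as three named (768, 768) parameters:
qkv_shapes = split_param_shapes(3 * 768, 768, ("q", "k", "v"))
```

The dict form covers uneven splits, e.g. grouped-query attention where the K/V chunks are smaller than Q.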

Usage

Import te.Linear when building or converting models for FP8 training on NVIDIA GPUs. It serves as a direct replacement for torch.nn.Linear with additional parallelism and optimization options.

Code Reference

Source Location

Repository
NVIDIA/TransformerEngine
File
transformer_engine/pytorch/module/linear.py
Class
Linear
Lines
__init__ at L1073--1100

Signature

class Linear(TransformerEngineBaseModule):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
        save_original_input: bool = False,
        name: Optional[str] = None,
    ) -> None:

Import

from transformer_engine.pytorch import Linear

# or equivalently:
import transformer_engine.pytorch as te
te.Linear

I/O Contract

Inputs

Name Type Required Description
inp torch.Tensor Yes Input tensor of arbitrary shape with last dimension equal to in_features

Outputs

Name Type Description
output torch.Tensor Result of y = xA^T + b, same shape as input except last dimension is out_features
bias (optional) torch.Tensor Returned only when return_bias=True; the bias vector of shape [out_features] for downstream fusion
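The I/O contract above is the same as torch.nn.Linear's and can be expressed as a small shape function (a sketch for illustration, not library code): leading dimensions pass through unchanged, and the last dimension must equal in_features and becomes out_features.

```python
def linear_output_shape(input_shape, in_features, out_features):
    """Shape contract of the forward pass: any leading dimensions pass
    through; the last dimension must equal in_features and becomes
    out_features. (Illustrative sketch, not TransformerEngine code.)"""
    *leading, last = input_shape
    if last != in_features:
        raise ValueError(f"last dim {last} != in_features {in_features}")
    return (*leading, out_features)

# A [batch, seq, hidden] activation through a 768 -> 3072 layer:
out_shape = linear_output_shape((8, 512, 768), 768, 3072)
```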

Key Parameters

Parameter Type Default Description
in_features int required Size of each input sample (last dimension of input tensor)
out_features int required Size of each output sample (last dimension of output tensor)
bias bool True If False, the layer does not learn an additive bias
parallel_mode None / "column" / "row" None Tensor parallel mode: "column" splits output features, "row" splits input features, None disables TP
sequence_parallel bool False Distributes the sequence dimension across TP ranks when TP is active
tp_group ProcessGroup / None None Tensor parallel process group
tp_size int 1 Tensor parallel world size (used when tp_group is not yet formed)
init_method Callable / None None Custom weight initializer; defaults to torch.nn.init.normal_(mean=0.0, std=0.023)
params_dtype torch.dtype / None default dtype Data type for allocated parameters
return_bias bool False If True, returns the bias separately for downstream fusion instead of adding it in the forward pass
fuse_wgrad_accumulation bool False Fuses weight gradient creation and accumulation when main_grad is available
device torch.device / str "cuda" Device on which to allocate parameters
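The interaction between parallel_mode and the local parameter shapes can be sketched as follows. This is a plain-Python illustration of the documented sharding behavior (weight stored as (out_features, in_features), as in torch.nn.Linear), not TransformerEngine code:

```python
def local_weight_shape(out_features, in_features, parallel_mode, tp_size):
    """Per-rank weight shape under the tensor-parallel modes described
    above: "column" shards out_features, "row" shards in_features, and
    None keeps the full weight on every rank. (Illustrative sketch.)"""
    if parallel_mode == "column":
        assert out_features % tp_size == 0
        return (out_features // tp_size, in_features)
    if parallel_mode == "row":
        assert in_features % tp_size == 0
        return (out_features, in_features // tp_size)
    return (out_features, in_features)

# A 768 -> 3072 column-parallel layer on 4 TP ranks stores a
# (768, 768) weight shard per rank:
shard = local_weight_shape(3072, 768, "column", 4)
```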

Usage Examples

Basic Drop-in Replacement

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Before: standard PyTorch
# linear = torch.nn.Linear(768, 3072)

# After: TransformerEngine drop-in replacement
linear = te.Linear(768, 3072)

# Use with FP8 autocast for FP8 acceleration
fp8_recipe = DelayedScaling()
input_tensor = torch.randn(8, 512, 768, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = linear(input_tensor)
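Conceptually, the DelayedScaling recipe derives each tensor's FP8 scale from a rolling history of absolute-maximum values (amax), so that the observed range maps near the format's representable maximum: 448 for E4M3 (forward) and 57344 for E5M2 (backward). The plain-Python sketch below illustrates that idea; the margin headroom term and the exact update rule are simplified assumptions, not TransformerEngine's implementation.

```python
# Representable maxima of the two FP8 formats used by the recipe.
E4M3_MAX, E5M2_MAX = 448.0, 57344.0

def delayed_scale(amax_history, fp8_max, margin=0):
    """Sketch of delayed scaling: take the max amax over the history
    window, then choose a scale so amax maps near fp8_max, backed off
    by 2**margin of headroom. (Simplified illustration, not TE code.)"""
    amax = max(amax_history)
    if amax == 0.0:
        return 1.0  # nothing observed yet; leave values unscaled
    return (fp8_max / amax) / (2 ** margin)

# Forward activations whose recent amax peaked at 3.5:
scale = delayed_scale([2.1, 3.5, 0.9], E4M3_MAX)
# Quantize as q = round(x * scale); dequantize as x ~ q / scale.
```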

Tensor-Parallel Column Linear

import transformer_engine.pytorch as te

# Column-parallel: splits out_features across TP ranks.
# tp_group is assumed to be an existing torch.distributed process group.
column_linear = te.Linear(
    in_features=768,
    out_features=3072,
    bias=True,
    parallel_mode="column",
    sequence_parallel=True,
    tp_group=tp_group,
)

Tensor-Parallel Row Linear

import transformer_engine.pytorch as te

# Row-parallel: splits in_features across TP ranks.
# tp_group is assumed to be an existing torch.distributed process group.
row_linear = te.Linear(
    in_features=3072,
    out_features=768,
    bias=True,
    parallel_mode="row",
    sequence_parallel=True,
    tp_group=tp_group,
)
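Pairing a column-parallel layer with a row-parallel layer (the standard transformer MLP pattern) needs only one all-reduce, because each rank's column shard of the first weight lines up with its row shard of the second. The tiny plain-Python sketch below checks this numerically with list-based matrices (weights written as (in, out) for brevity, unlike the (out, in) storage of the real layers):

```python
def matmul(A, B):
    """Naive list-of-lists matrix multiply, for illustration only."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

x  = [[1.0, 2.0]]                       # [1, in=2]
W1 = [[1.0, 0.0, 1.0, 2.0],
      [0.0, 1.0, 1.0, 0.0]]             # (in=2, hidden=4): column-parallel
W2 = [[1.0], [1.0], [1.0], [1.0]]       # (hidden=4, out=1): row-parallel

# Dense (single-device) reference: y = (x @ W1) @ W2
dense = matmul(matmul(x, W1), W2)[0][0]

# Two-rank tensor parallelism: each rank keeps a column shard of W1 and
# the matching row shard of W2, computes a partial output, and the
# partials are summed - the single all-reduce of the row-parallel layer.
tp = 2
partials = []
for r in range(tp):
    W1_shard = [row[r * 2:(r + 1) * 2] for row in W1]  # split hidden dim
    W2_shard = W2[r * 2:(r + 1) * 2]                   # split hidden dim
    partials.append(matmul(matmul(x, W1_shard), W2_shard)[0][0])
reduced = sum(partials)
```

With an elementwise activation between the two layers the equivalence still holds, since each hidden column is owned entirely by one rank.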
