Implementation:NVIDIA TransformerEngine TE Linear
| Field | Value |
|---|---|
| Sources | TransformerEngine, FP8 Formats for Deep Learning |
| Domains | Deep_Learning, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
`te.Linear` is the FP8-capable linear layer provided by NVIDIA's TransformerEngine library. It is a drop-in replacement for `torch.nn.Linear` that adds FP8/FP4 quantization, tensor parallelism, and sequence parallelism support.
Description
`te.Linear` applies the affine transformation y = xA^T + b to incoming data, identical to `torch.nn.Linear`. On NVIDIA GPUs it replaces the standard cuBLAS GEMM with an FP8-aware GEMM that leverages Tensor Cores on Hopper and later architectures. The class inherits from `TransformerEngineBaseModule`, which provides FP8 recipe management, scaling-factor tracking, and checkpoint compatibility for FP8 metadata.
Key capabilities beyond standard `torch.nn.Linear`:
- FP8 quantization: When used inside a `te.fp8_autocast()` context, activations and weights are automatically quantized to FP8 (E4M3 forward, E5M2 backward) with managed per-tensor scaling factors.
- Tensor parallelism: The `parallel_mode` parameter supports `"column"` (splits `out_features` across TP ranks) and `"row"` (splits `in_features` across TP ranks) modes, with built-in collective communication.
- Sequence parallelism: When `sequence_parallel=True` and tensor parallelism is active, the sequence dimension is distributed across TP ranks to reduce activation memory.
- Fused weight gradient accumulation: The `fuse_wgrad_accumulation` option fuses gradient computation and accumulation into a single operation when `main_grad` buffers are available.
- Communication-computation overlap: Multiple `ub_overlap_*` options enable overlapping NCCL collectives with GEMM computation for latency hiding.
- Parameter splitting: The `parameters_split` option allows splitting the weight and bias along dim 0 into multiple named PyTorch parameters, useful for QKV projections.
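The per-tensor scaling mentioned in the FP8 bullet follows a simple pattern. The sketch below is plain Python that mimics, rather than calls, TransformerEngine's delayed-scaling recipe: a scale is derived from a recent amax (max absolute value) history, and tensors are quantized by multiply-and-clamp onto the FP8 range (448.0 is the largest normal E4M3 value; actual FP8 rounding is omitted here):

```python
# Illustrative sketch of FP8 per-tensor delayed scaling.
# NOT TransformerEngine internals; plain Python for clarity.

E4M3_MAX = 448.0  # largest normal value representable in E4M3

def compute_scale(amax_history, fp8_max=E4M3_MAX, margin=0):
    """Delayed scaling: derive the scale from recent max-abs values."""
    amax = max(amax_history)
    # The scale maps the observed dynamic range onto the FP8 representable range.
    return (fp8_max / amax) / (2 ** margin) if amax > 0 else 1.0

def quantize(values, scale, fp8_max=E4M3_MAX):
    """Scale, then clamp to the FP8 range (FP8 rounding omitted)."""
    return [max(-fp8_max, min(fp8_max, v * scale)) for v in values]

def dequantize(values, scale):
    return [v / scale for v in values]

amax_history = [2.0, 3.5, 4.0]        # amaxes recorded over recent iterations
scale = compute_scale(amax_history)   # 448.0 / 4.0 = 112.0
x = [0.5, -2.0, 4.0]
x_q = quantize(x, scale)              # values now span the FP8 range
x_back = dequantize(x_q, scale)       # recovers the original values
```

In the real library the amax history length, margin, and recipe choice are configured via `DelayedScaling` and applied automatically inside `te.fp8_autocast()`.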
Usage
Import `te.Linear` when building or converting models for FP8 training on NVIDIA GPUs. It serves as a direct replacement for `torch.nn.Linear` with additional parallelism and optimization options.
Code Reference
Source Location
- Repository: `NVIDIA/TransformerEngine`
- File: `transformer_engine/pytorch/module/linear.py`
- Class: `Linear`
- Lines: `__init__` at L1073--1100
Signature
```python
class Linear(TransformerEngineBaseModule):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
        save_original_input: bool = False,
        name: Optional[str] = None,
    ) -> None:
```
Import
```python
from transformer_engine.pytorch import Linear
# or equivalently:
import transformer_engine.pytorch as te
te.Linear
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| `inp` | `torch.Tensor` | Yes | Input tensor of arbitrary shape with last dimension equal to `in_features` |
Outputs
| Name | Type | Description |
|---|---|---|
| output | `torch.Tensor` | Result of y = xA^T + b; same shape as input except the last dimension is `out_features` |
| bias (optional) | `torch.Tensor` | Returned only when `return_bias=True`; the bias vector of shape `[out_features]` for downstream fusion |
Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `in_features` | `int` | required | Size of each input sample (last dimension of the input tensor) |
| `out_features` | `int` | required | Size of each output sample (last dimension of the output tensor) |
| `bias` | `bool` | `True` | If `False`, the layer does not learn an additive bias |
| `parallel_mode` | `None` / `"column"` / `"row"` | `None` | Tensor parallel mode: `"column"` splits output features, `"row"` splits input features, `None` disables TP |
| `sequence_parallel` | `bool` | `False` | Distributes the sequence dimension across TP ranks when TP is active |
| `tp_group` | `ProcessGroup` / `None` | `None` | Tensor parallel process group |
| `tp_size` | `int` | `1` | Tensor parallel world size (used when `tp_group` is not yet formed) |
| `init_method` | `Callable` / `None` | `None` | Custom weight initializer; defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)` |
| `params_dtype` | `torch.dtype` / `None` | default dtype | Data type for allocated parameters |
| `return_bias` | `bool` | `False` | If `True`, returns the bias separately for downstream fusion instead of adding it in the forward pass |
| `fuse_wgrad_accumulation` | `bool` | `False` | Fuses weight gradient creation and accumulation when `main_grad` is available |
| `device` | `torch.device` / `str` | `"cuda"` | Device on which to allocate parameters |
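The `parallel_mode` semantics in the table can be verified with a small numpy sketch. This is illustrative math only, not TE code and with no actual NCCL communication: column parallelism shards the weight along `out_features` and concatenates rank-local outputs, while row parallelism shards along `in_features` and sums the rank-local partial products (the role of the all-reduce):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))   # [batch, in_features]
A = rng.standard_normal((6, 8))   # weight [out_features, in_features]; y = x @ A.T
y_ref = x @ A.T                   # the single-GPU reference result

tp_size = 2  # pretend tensor-parallel world size

# Column parallel: each rank owns a slice of out_features (rows of A);
# rank-local outputs are concatenated along the feature dimension (all-gather).
col_shards = np.split(A, tp_size, axis=0)
y_col = np.concatenate([x @ shard.T for shard in col_shards], axis=1)

# Row parallel: each rank owns a slice of in_features (columns of A) plus the
# matching input slice; partial outputs are summed (all-reduce / reduce-scatter).
row_shards = np.split(A, tp_size, axis=1)
x_shards = np.split(x, tp_size, axis=1)
y_row = sum(xs @ ws.T for xs, ws in zip(x_shards, row_shards))

assert np.allclose(y_ref, y_col) and np.allclose(y_ref, y_row)
```

This is why column-parallel and row-parallel layers are typically paired (e.g. MLP up-projection then down-projection): the concatenated layout produced by one is exactly the sharded input the other consumes.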
Usage Examples
Basic Drop-in Replacement
```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Before: standard PyTorch
# linear = torch.nn.Linear(768, 3072)

# After: TransformerEngine drop-in replacement
linear = te.Linear(768, 3072)

# Use with FP8 autocast for FP8 acceleration
input_tensor = torch.randn(16, 768, device="cuda")
fp8_recipe = DelayedScaling()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = linear(input_tensor)
```
Tensor-Parallel Column Linear
```python
import transformer_engine.pytorch as te

# Column-parallel: splits out_features across TP ranks.
# tp_group is an already-initialized torch.distributed process group.
column_linear = te.Linear(
    in_features=768,
    out_features=3072,
    bias=True,
    parallel_mode="column",
    sequence_parallel=True,
    tp_group=tp_group,
)
```
Tensor-Parallel Row Linear
```python
import transformer_engine.pytorch as te

# Row-parallel: splits in_features across TP ranks.
# tp_group is an already-initialized torch.distributed process group.
row_linear = te.Linear(
    in_features=3072,
    out_features=768,
    bias=True,
    parallel_mode="row",
    sequence_parallel=True,
    tp_group=tp_group,
)
```
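The `parameters_split` option from the capabilities list pairs naturally with these parallel layers for fused QKV projections. The numpy sketch below only illustrates the dim-0 split it performs; the names `q`/`k`/`v` are hypothetical labels chosen for this example, and real `te.Linear` registers the splits as separate PyTorch parameters:

```python
import numpy as np

hidden = 16

# One fused projection weight covering Q, K, and V stacked along dim 0,
# as a single te.Linear(hidden, 3 * hidden) would hold it.
fused_weight = np.arange(3 * hidden * hidden, dtype=np.float32).reshape(3 * hidden, hidden)

# parameters_split given as a dict maps names to per-split sizes along dim 0.
parameters_split = {"q": hidden, "k": hidden, "v": hidden}

named_params, offset = {}, 0
for pname, size in parameters_split.items():
    named_params[pname] = fused_weight[offset:offset + size]
    offset += size

# The named splits tile the fused weight exactly, with no copy of the data.
assert all(p.shape == (hidden, hidden) for p in named_params.values())
assert np.array_equal(np.concatenate(list(named_params.values()), axis=0), fused_weight)
```

Splitting this way keeps one fused GEMM for speed while exposing per-projection parameters for checkpointing and optimizer state.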
Related Pages
- Principle:NVIDIA_TransformerEngine_Drop_In_Linear_Replacement
- Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements
- Environment:NVIDIA_TransformerEngine_Python_PyTorch_Requirements
- Environment:NVIDIA_TransformerEngine_GPU_Compute_Capability
- Heuristic:NVIDIA_TransformerEngine_Sequence_Length_Alignment
- Heuristic:NVIDIA_TransformerEngine_Build_Optimization_Tips