Implementation:NVIDIA TransformerEngine TE LayerNormLinear
Overview
A concrete TransformerEngine module that fuses layer normalization with a linear transformation.
Description
te.LayerNormLinear fuses LayerNorm and Linear into a single module. It supports FP8 quantization, tensor parallelism, sequence parallelism, and optionally returns the intermediate LayerNorm output. This module is a drop-in replacement for the common pattern of torch.nn.LayerNorm followed by torch.nn.Linear, providing significant performance improvements through kernel fusion.
The module internally dispatches to optimized CUDA kernels that perform the normalization and GEMM in a single pass, avoiding the materialization of the intermediate normalized tensor in global memory.
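As a purely illustrative sketch of what the fused op computes for a single token, this is LayerNorm followed by an affine transform (the helper name and pure-Python types below are hypothetical; the real module runs optimized CUDA kernels on `torch.Tensor` batches):

```python
import math

def layernorm_linear_reference(x, gamma, beta, weight, bias, eps=1e-5):
    """Reference math for the fused op: LayerNorm over x, then a linear layer.

    x: one token's features as a list of floats.
    gamma, beta: LayerNorm scale/shift parameters (length in_features).
    weight: out_features rows of in_features columns; bias: length out_features.
    """
    # LayerNorm: normalize to zero mean and unit variance, then scale/shift.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    normed = [(v - mean) * inv_std * g + b for v, g, b in zip(x, gamma, beta)]
    # Linear: y = normed @ weight.T + bias.
    return [sum(n * w for n, w in zip(normed, row)) + bb
            for row, bb in zip(weight, bias)]
```

The fused kernel produces the same result without ever writing `normed` to global memory.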
Source
- File: transformer_engine/pytorch/module/layernorm_linear.py
- Class: LayerNormLinear
- Constructor: `__init__` at lines 1132-1162
Import
```python
from transformer_engine.pytorch import LayerNormLinear
```
or equivalently:
```python
import transformer_engine.pytorch as te
te.LayerNormLinear
```
Signature
```python
class LayerNormLinear(TransformerEngineBaseModule):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        normalization: str = "LayerNorm",
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        return_layernorm_output: bool = False,
        return_layernorm_output_gathered: bool = False,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        zero_centered_gamma: bool = False,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
        name: Optional[str] = None,
    ) -> None:
```
I/O
- Input: `inp: torch.Tensor` -- the input tensor to be normalized and linearly transformed.
- Output: `Union[torch.Tensor, Tuple[torch.Tensor, ...]]` -- the fused LayerNorm + Linear output. When `return_layernorm_output=True`, a tuple containing both the linear output and the intermediate LayerNorm output is returned. When `return_bias=True`, the bias is returned separately.
Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `in_features` | `int` | (required) | Size of each input sample. |
| `out_features` | `int` | (required) | Size of each output sample. |
| `eps` | `float` | `1e-5` | Epsilon value for numerical stability in LayerNorm. |
| `bias` | `bool` | `True` | Whether to add a learnable bias to the linear output. |
| `normalization` | `str` | `"LayerNorm"` | Type of normalization: `"LayerNorm"` or `"RMSNorm"`. |
| `return_layernorm_output` | `bool` | `False` | If `True`, also returns the intermediate LayerNorm output. |
| `return_layernorm_output_gathered` | `bool` | `False` | If `True` (with sequence parallelism), returns the gathered LayerNorm output. |
| `zero_centered_gamma` | `bool` | `False` | If `True`, the LayerNorm gamma parameter is centered at zero (gamma = 1 + learnable). |
| `parallel_mode` | `Optional[str]` | `None` | Tensor parallel mode: `"column"`, `"row"`, or `None`. |
| `sequence_parallel` | `bool` | `False` | Whether to enable sequence parallelism. |
| `tp_group` | `Optional[dist_group_type]` | `None` | Tensor parallel process group. |
| `tp_size` | `int` | `1` | Tensor parallel world size. |
| `fuse_wgrad_accumulation` | `bool` | `False` | Whether to fuse weight gradient accumulation into the WGRAD GEMM. |
| `delay_wgrad_compute` | `bool` | `False` | Whether to delay weight gradient computation (user must call `module.backward_dw`). |
| `symmetric_ar_type` | `Optional[str]` | `None` | Type of symmetric memory all-reduce: `"multimem_all_reduce"`, `"two_shot"`, or `"one_shot"`. |
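To make the `normalization="RMSNorm"` and `zero_centered_gamma` options concrete, here is a hedged pure-Python sketch of the reference math (the function name and list-based types are illustrative, not TransformerEngine API): RMSNorm skips the mean subtraction and shift of LayerNorm, and zero-centered gamma stores the scale as an offset around zero so the effective scale is `1 + gamma`.

```python
import math

def rmsnorm(x, gamma, eps=1e-5, zero_centered_gamma=False):
    """Reference RMSNorm: scale by the root-mean-square, no mean/beta.

    With zero_centered_gamma=True, the stored gamma is an offset around
    zero and the effective per-feature scale is (1 + gamma), so a
    freshly zero-initialized gamma starts as an identity scale.
    """
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    scale = [(1.0 + g) if zero_centered_gamma else g for g in gamma]
    return [v / rms * s for v, s in zip(x, scale)]
```

A zero-centered gamma of all zeros therefore matches a conventional gamma of all ones.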
Example
```python
import transformer_engine.pytorch as te

# Fused LayerNorm + Linear for QKV projection
qkv_proj = te.LayerNormLinear(
    in_features=1024,
    out_features=3072,  # 3 * hidden_size for Q, K, V
    eps=1e-5,
    normalization="LayerNorm",
)
output = qkv_proj(hidden_states)
```
With LayerNorm output returned (e.g., for residual connections):
```python
import transformer_engine.pytorch as te

qkv_proj = te.LayerNormLinear(
    in_features=1024,
    out_features=3072,
    return_layernorm_output=True,
)
linear_output, ln_output = qkv_proj(hidden_states)
```
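Downstream attention code typically unpacks the 3 * hidden_size projection above into Q, K, and V. A pure-Python sketch of that split (the helper name is hypothetical; real code would slice the `torch.Tensor` along its last dimension, e.g. with `torch.split`, or name the chunks up front via `parameters_split`):

```python
def split_qkv(fused, hidden_size):
    """Split a packed (3 * hidden_size)-wide projection into Q, K, V.

    fused is modeled here as a flat list of features for one token;
    the real output is a torch.Tensor with this layout on its last axis.
    """
    assert len(fused) == 3 * hidden_size, "expected packed Q|K|V layout"
    q = fused[:hidden_size]
    k = fused[hidden_size:2 * hidden_size]
    v = fused[2 * hidden_size:]
    return q, k, v
```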
Related
- Principle:NVIDIA_TransformerEngine_Fused_LayerNorm_Linear
- Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements
- Environment:NVIDIA_TransformerEngine_Python_PyTorch_Requirements
- Environment:NVIDIA_TransformerEngine_GPU_Compute_Capability
- Heuristic:NVIDIA_TransformerEngine_Sequence_Length_Alignment