Implementation:NVIDIA TransformerEngine TE LayerNormLinear

Overview

Concrete tool provided by TransformerEngine for a fused layer normalization plus linear transformation.

Description

te.LayerNormLinear fuses LayerNorm and Linear into a single module. It supports FP8 quantization, tensor parallelism, sequence parallelism, and optionally returns the intermediate LayerNorm output. This module is a drop-in replacement for the common pattern of torch.nn.LayerNorm followed by torch.nn.Linear, providing significant performance improvements through kernel fusion.

The module internally dispatches to optimized CUDA kernels that perform the normalization and GEMM in a single pass, avoiding the materialization of the intermediate normalized tensor in global memory.
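
The unfused pattern it replaces can be sketched as follows (a minimal sketch, assuming a CUDA device; the sizes and shapes are illustrative, and the two outputs will differ numerically because the modules are initialized independently):

import torch
import transformer_engine.pytorch as te

hidden_size, out_features = 1024, 3072

# Unfused baseline: two modules, with the normalized tensor
# materialized in global memory between them.
ln = torch.nn.LayerNorm(hidden_size, eps=1e-5).cuda()
linear = torch.nn.Linear(hidden_size, out_features).cuda()

# Fused replacement: a single module backed by fused kernels.
fused = te.LayerNormLinear(hidden_size, out_features, eps=1e-5)

x = torch.randn(8, 128, hidden_size, device="cuda")
y_unfused = linear(ln(x))   # two launches, intermediate tensor written to memory
y_fused = fused(x)          # same output shape: [8, 128, 3072]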

Source

  • File: transformer_engine/pytorch/module/layernorm_linear.py
  • Class: LayerNormLinear
  • Constructor: __init__ at lines 1132-1162

Import

from transformer_engine.pytorch import LayerNormLinear

or equivalently:

import transformer_engine.pytorch as te
te.LayerNormLinear

Signature

class LayerNormLinear(TransformerEngineBaseModule):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        normalization: str = "LayerNorm",
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        return_layernorm_output: bool = False,
        return_layernorm_output_gathered: bool = False,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        zero_centered_gamma: bool = False,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
        name: Optional[str] = None,
    ) -> None:

I/O

  • Input: inp: torch.Tensor -- the input tensor to be normalized and linearly transformed.
  • Output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] -- the fused LayerNorm + Linear output. When return_layernorm_output=True, a tuple containing both the linear output and the intermediate LayerNorm output is returned. When return_bias=True, the bias is returned separately instead of being added to the output (see the sketch below).
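
The return_bias path might look like this (a minimal sketch, assuming the module returns the pair (output, bias) in that order; the shapes are illustrative):

import torch
import transformer_engine.pytorch as te

layer = te.LayerNormLinear(1024, 1024, bias=True, return_bias=True)

x = torch.randn(4, 1024, device="cuda")
out, bias = layer(x)   # bias is returned rather than added to out
out = out + bias       # the caller applies (or fuses) the bias addition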

Key Parameters

  • in_features (int, required) -- Size of each input sample.
  • out_features (int, required) -- Size of each output sample.
  • eps (float, default 1e-5) -- Epsilon value for numerical stability in LayerNorm.
  • bias (bool, default True) -- Whether to add a learnable bias to the linear output.
  • normalization (str, default "LayerNorm") -- Type of normalization: "LayerNorm" or "RMSNorm".
  • return_layernorm_output (bool, default False) -- If True, also returns the intermediate LayerNorm output.
  • return_layernorm_output_gathered (bool, default False) -- If True (with sequence parallelism), returns the gathered LayerNorm output.
  • zero_centered_gamma (bool, default False) -- If True, the LayerNorm gamma parameter is centered at zero (gamma = 1 + learnable).
  • parallel_mode (Optional[str], default None) -- Tensor parallel mode: "column", "row", or None (see the sketch below).
  • sequence_parallel (bool, default False) -- Whether to enable sequence parallelism.
  • tp_group (Optional[dist_group_type], default None) -- Tensor parallel process group.
  • tp_size (int, default 1) -- Tensor parallel world size.
  • fuse_wgrad_accumulation (bool, default False) -- Whether to fuse weight gradient accumulation into the WGRAD GEMM.
  • delay_wgrad_compute (bool, default False) -- Whether to delay weight gradient computation (the user must call module.backward_dw).
  • symmetric_ar_type (Optional[str], default None) -- Type of symmetric memory all-reduce: "multimem_all_reduce", "two_shot", or "one_shot".
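
For the tensor-parallel parameters, a hedged sketch of column-parallel usage (assumes torch.distributed has already been initialized with one GPU per rank; the process-group setup shown here is illustrative, not prescriptive):

import torch.distributed as dist
import transformer_engine.pytorch as te

# Assumes dist.init_process_group(...) was called earlier in the program.
tp_size = dist.get_world_size()
tp_group = dist.new_group(ranks=list(range(tp_size)))

qkv_proj = te.LayerNormLinear(
    in_features=1024,
    out_features=3072,
    parallel_mode="column",   # shard out_features across the TP ranks
    tp_group=tp_group,
    tp_size=tp_size,
    sequence_parallel=True,   # expects inputs sharded along the sequence dim
)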

Example

import torch
import transformer_engine.pytorch as te

# Fused LayerNorm + Linear for QKV projection
qkv_proj = te.LayerNormLinear(
    in_features=1024,
    out_features=3072,  # 3 * hidden_size for Q, K, V
    eps=1e-5,
    normalization="LayerNorm",
)

hidden_states = torch.randn(8, 128, 1024, device="cuda")  # [batch, seq, hidden]
output = qkv_proj(hidden_states)

With LayerNorm output returned (e.g., for residual connections):

import torch
import transformer_engine.pytorch as te

qkv_proj = te.LayerNormLinear(
    in_features=1024,
    out_features=3072,
    return_layernorm_output=True,
)

hidden_states = torch.randn(8, 128, 1024, device="cuda")
linear_output, ln_output = qkv_proj(hidden_states)
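
The Description above mentions FP8 support; a minimal sketch of running the module under FP8 autocast (the recipe values are illustrative, and FP8 execution requires supporting hardware such as Hopper-class GPUs):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

qkv_proj = te.LayerNormLinear(1024, 3072)
hidden_states = torch.randn(8, 128, 1024, device="cuda")

# Illustrative recipe; tune the format and amax history for your workload.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = qkv_proj(hidden_states)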
