Implementation:NVIDIA TransformerEngine GroupedLinear
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements the GroupedLinear module, which applies multiple independent linear transformations to consecutive row slices of one input tensor through a single grouped GEMM call. It is commonly used for Mixture-of-Experts (MoE) layers.
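A plain-Python reference sketch makes the computation concrete. It splits the rows of the input by m_splits, applies the i-th weight and bias to the i-th slice, and concatenates the results. This illustrates only *what* is computed, not how TransformerEngine computes it. The function name grouped_linear_reference and the list-of-lists matrix representation are hypothetical. Weights are assumed to be stored as [out_features, in_features], as in torch.nn.Linear.

```python
from typing import List, Optional

Matrix = List[List[float]]


def grouped_linear_reference(
    inp: Matrix,
    m_splits: List[int],
    weights: List[Matrix],
    biases: Optional[List[List[float]]] = None,
) -> Matrix:
    """Hypothetical reference: y = x @ W_i^T + b_i for the rows in split i."""
    if len(m_splits) != len(weights):
        raise ValueError("need one split per weight matrix")
    if sum(m_splits) != len(inp):
        raise ValueError("m_splits must sum to the number of input rows")

    out: Matrix = []
    start = 0
    for i, (rows, w) in enumerate(zip(m_splits, weights)):
        # Rows [start, start + rows) belong to GEMM i
        for x in inp[start:start + rows]:
            # Each output feature is a dot product with one row of W_i
            y = [sum(xk * wk for xk, wk in zip(x, w_row)) for w_row in w]
            if biases is not None:
                y = [yj + bj for yj, bj in zip(y, biases[i])]
            out.append(y)
        start += rows
    return out
```

For example, take two experts with in_features=2 and out_features=1. Here m_splits=[2, 1] sends the first two rows through the first weight and the last row through the second.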
Description
GroupedLinear inherits from TransformerEngineBaseModule and manages num_gemms separate weight parameters (and optional bias parameters). The forward pass uses _GroupedLinear, a custom torch.autograd.Function. It splits the input by m_splits and, when FP8 is enabled, quantizes each split with its own per-GEMM quantizer. It then dispatches all of the per-split GEMMs to general_grouped_gemm in one grouped call. The backward pass computes data gradients and weight gradients through the same grouped GEMM infrastructure. It also supports deferred weight gradient accumulation via WeightGradStore.
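For a single split, the backward pass reduces to the usual linear-layer gradients, applied independently per GEMM. The notation below is introduced here for illustration. $X_i$ is the $m_i \times \text{in\_features}$ slice of the input selected by m_splits. $W_i$ is the $\text{out\_features} \times \text{in\_features}$ weight. $b_i$ is the bias. $G_i = \partial L / \partial Y_i$ is the matching slice of grad_output. $\mathbf{1}$ is a column of $m_i$ ones.

```latex
Y_i = X_i W_i^\top + \mathbf{1}\, b_i^\top, \qquad
\frac{\partial L}{\partial X_i} = G_i W_i, \qquad
\frac{\partial L}{\partial W_i} = G_i^\top X_i, \qquad
\frac{\partial L}{\partial b_i} = G_i^\top \mathbf{1}
```

The data gradients $\partial L / \partial X_i$ are stacked in m_splits order to form the gradient of inp. None of them depends on any $\partial L / \partial W_i$. This independence is what makes it possible, in principle, to postpone the weight-gradient GEMMs, the kind of scheduling that deferred accumulation through WeightGradStore provides.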
Usage
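GroupedLinear is a key module for MoE architectures, where tokens are routed to different experts. Rather than looping over one linear layer per expert, it issues all of the per-expert GEMMs through a single grouped GEMM call. This reduces per-expert launch overhead and can improve throughput, especially when combined with FP8.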
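GroupedLinear expects tokens to arrive already grouped by expert, with m_splits giving the size of each group. The following is a minimal plain-Python sketch of that preparation step. It assumes a top-1 router that assigns each token a single expert id. The helpers group_by_expert and unpermute are hypothetical and not TransformerEngine APIs. The sketch stable-sorts token indices by expert and counts tokens per expert. It also keeps the permutation so that outputs can be restored to the original token order.

```python
from typing import List, Sequence, Tuple


def group_by_expert(
    expert_ids: Sequence[int], num_experts: int
) -> Tuple[List[int], List[int]]:
    """Return (permutation, m_splits) for a hypothetical top-1 routing."""
    m_splits = [0] * num_experts
    for e in expert_ids:
        if not 0 <= e < num_experts:
            raise ValueError(f"expert id {e} out of range")
        m_splits[e] += 1
    # Stable sort keeps the original order of tokens within each expert
    perm = sorted(range(len(expert_ids)), key=lambda t: expert_ids[t])
    return perm, m_splits


def unpermute(rows: Sequence, perm: Sequence[int]) -> list:
    """Scatter expert-ordered rows back to the original token order."""
    restored = [None] * len(perm)
    for grouped_pos, token in enumerate(perm):
        restored[token] = rows[grouped_pos]
    return restored
```

Experts that receive no tokens get a 0 entry, so m_splits keeps exactly one entry per expert. In practice, the gather `[tokens[t] for t in perm]` and the scatter in unpermute would be tensor indexing operations rather than Python loops.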
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/module/grouped_linear.py
- Lines: 1-992
Signature
```python
class _GroupedLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, m_splits, ...): ...

    @staticmethod
    def backward(ctx, grad_output): ...


class GroupedLinear(TransformerEngineBaseModule):
    def __init__(
        self, num_gemms, in_features, out_features,
        bias=True, init_method=None, ...
    ): ...

    def forward(self, inp, m_splits, ...): ...
```
Import
```python
from transformer_engine.pytorch import GroupedLinear
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inp | torch.Tensor | Yes | Input tokens to be processed by the experts, grouped by expert in m_splits order; last dimension is in_features |
| m_splits | List[int] | Yes | Number of tokens (rows of inp) assigned to each expert; one entry per GEMM |
| num_gemms | int | Yes | Number of independent linear transformations (experts); constructor argument |
| in_features | int | Yes | Input feature dimension; constructor argument |
| out_features | int | Yes | Output feature dimension; constructor argument |
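The inputs above are linked by a few invariants. m_splits has one entry per GEMM. Its entries are non-negative. They sum to the number of tokens in inp. The helper below is a hypothetical sketch of these checks, run before calling the module. It is not a function exported by TransformerEngine. It also assumes that any leading dimensions of inp are flattened into a single token dimension.

```python
from math import prod
from typing import Sequence


def check_grouped_linear_inputs(
    inp_shape: Sequence[int],
    m_splits: Sequence[int],
    num_gemms: int,
    in_features: int,
) -> int:
    """Validate the assumed GroupedLinear input contract; return the token count."""
    if inp_shape[-1] != in_features:
        raise ValueError(f"last dim {inp_shape[-1]} != in_features {in_features}")
    # Assumption: leading dims are flattened into one token dimension
    num_tokens = prod(inp_shape[:-1])
    if len(m_splits) != num_gemms:
        raise ValueError(f"expected {num_gemms} splits, got {len(m_splits)}")
    if any(m < 0 for m in m_splits):
        raise ValueError("m_splits entries must be non-negative")
    if sum(m_splits) != num_tokens:
        raise ValueError(f"m_splits sums to {sum(m_splits)}, expected {num_tokens}")
    return num_tokens
```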
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Outputs of all grouped linear transformations, concatenated along the token dimension in m_splits order; last dimension is out_features |
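The output stacks the per-GEMM results in m_splits order. As a result, the rows belonging to GEMM i begin at the i-th exclusive prefix sum of m_splits. The following small sketch uses a hypothetical helper to locate each expert's block of output rows.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple


def expert_row_ranges(m_splits: Sequence[int]) -> List[Tuple[int, int]]:
    """Half-open [start, end) row range of each GEMM's block in the output."""
    # Exclusive prefix sum: offsets[i] is where block i begins
    offsets = [0, *accumulate(m_splits)]
    return [(offsets[i], offsets[i + 1]) for i in range(len(m_splits))]
```

For a 2-D output tensor, `output[start:end]` then selects one expert's rows. `torch.split(output, m_splits)` yields all of the blocks in a single call.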
Usage Examples
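```python
import torch
from transformer_engine.pytorch import GroupedLinear

# Create a grouped linear layer for 8 experts
grouped_linear = GroupedLinear(
    num_gemms=8,
    in_features=1024,
    out_features=4096,
    bias=True,
)

# Tokens per expert: one entry per GEMM, summing to 364
m_splits = [64, 32, 48, 56, 40, 28, 52, 44]
# Rows must already be grouped by expert, in m_splits order
tokens = torch.randn(sum(m_splits), 1024, device="cuda")

# Forward pass with per-expert token splits; output has shape [364, 4096]
output = grouped_linear(tokens, m_splits=m_splits)
```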