Implementation:NVIDIA TransformerEngine GroupedLinear

From Leeroopedia


Field Value
Sources TransformerEngine
Domains Deep_Learning, PyTorch, Quantization
Last Updated 2026-02-07 14:00 GMT

Overview

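Implements the GroupedLinear module, which applies num_gemms independent linear transformations (one per expert) to consecutive row blocks of a single input. All of these GEMMs are dispatched together through one grouped GEMM call rather than one GEMM call per expert. The module is commonly used for Mixture-of-Experts (MoE) layers.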
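To pin down the semantics, here is a minimal, unfused reference sketch in plain Python. It assumes a weight layout of (out_features, in_features) per expert, as in torch.nn.Linear, and it runs one GEMM per expert in a loop. This is exactly the per-expert dispatch that the grouped GEMM avoids. The names linear, grouped_linear_reference and Matrix are hypothetical illustrations, not TransformerEngine APIs.

```python
from typing import List, Sequence

Matrix = List[List[float]]  # hypothetical row-major matrix type for this sketch


def linear(x: Matrix, w: Matrix, b: Sequence[float]) -> Matrix:
    """One expert: y = x @ w^T + b, with w shaped (out_features, in_features)."""
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + bias for w_row, bias in zip(w, b)]
        for row in x
    ]


def grouped_linear_reference(
    x: Matrix,
    m_splits: Sequence[int],
    weights: Sequence[Matrix],
    biases: Sequence[Sequence[float]],
) -> Matrix:
    """Unfused reference: split rows of x by m_splits, apply expert i to split i, concatenate."""
    assert len(m_splits) == len(weights) == len(biases), "one split per expert"
    assert sum(m_splits) == len(x), "splits must cover every input row"
    out: Matrix = []
    start = 0
    for n, w, b in zip(m_splits, weights, biases):
        out.extend(linear(x[start:start + n], w, b))  # a separate GEMM per expert
        start += n
    return out
```

For example, with m_splits=[1, 2], row 0 goes through expert 0 and rows 1 and 2 go through expert 1. The results come back in input row order.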

Description

GroupedLinear inherits from TransformerEngineBaseModule and manages num_gemms separate weight (and optional bias) parameters. The forward pass uses _GroupedLinear, a custom torch.autograd.Function that splits the input by m_splits, quantizes each split with per-GEMM quantizers when FP8 is enabled, and dispatches to general_grouped_gemm for fused execution. The backward pass computes data gradients and weight gradients through the same grouped GEMM infrastructure, with support for deferred weight gradient accumulation via WeightGradStore.
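The per-split gradients that the backward pass has to produce can be written out directly. The following is a plain-Python sketch of the math only, not of TransformerEngine's kernels. For y_i = x_i @ W_i^T + b_i, the data gradient dx_i = dy_i @ W_i is needed immediately by the preceding layer. The weight gradient dW_i = dy_i^T @ x_i is consumed only by the optimizer. Because the rest of backpropagation does not depend on the weight gradient, its computation can be postponed, and deferred accumulation relies on exactly this property. The function name grouped_linear_backward_reference is hypothetical.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]  # hypothetical row-major matrix type for this sketch


def grouped_linear_backward_reference(
    x: Matrix, dy: Matrix, m_splits: Sequence[int], weights: Sequence[Matrix]
) -> Tuple[Matrix, List[Matrix], List[List[float]]]:
    """Gradients of y_i = x_i @ W_i^T + b_i for each row split i.

    dgrad: dx_i = dy_i @ W_i     (feeds the previous layer right away)
    wgrad: dW_i = dy_i^T @ x_i   (only needed by the optimizer, so it can be deferred)
    bgrad: db_i = column sums of dy_i
    """
    dx: Matrix = []
    dws: List[Matrix] = []
    dbs: List[List[float]] = []
    start = 0
    for n, w in zip(m_splits, weights):
        x_i, dy_i = x[start:start + n], dy[start:start + n]
        out_f, in_f = len(w), len(w[0])
        # dgrad rows for this split, in the same order as the input rows
        dx.extend([[sum(g[o] * w[o][k] for o in range(out_f)) for k in range(in_f)] for g in dy_i])
        # wgrad: (out_features, in_features), reduced over the split's rows
        dws.append([[sum(g[o] * r[k] for g, r in zip(dy_i, x_i)) for k in range(in_f)]
                    for o in range(out_f)])
        # bias gradient: sum of output gradients over the split's rows
        dbs.append([sum(g[o] for g in dy_i) for o in range(out_f)])
        start += n
    return dx, dws, dbs
```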

Usage

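Key module for MoE architectures, where tokens are routed to different experts. All expert GEMMs are issued through one grouped GEMM call instead of launching a separate GEMM kernel per expert, which reduces per-expert launch and dispatch overhead. This overhead matters most when many experts each receive only a few tokens. FP8 execution is supported, with a separate quantizer for each GEMM.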
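GroupedLinear expects each expert's tokens to occupy a contiguous block of rows, so a router's per-token expert choice has to be turned into a row permutation plus m_splits. The sketch below assumes top-1 routing, meaning exactly one expert id per token. Production MoE stacks often use top-k routing and dedicated permutation kernels instead. The helper names permute_by_expert and unpermute are hypothetical. An expert that receives no tokens still gets a 0 entry, so that len(m_splits) stays equal to num_gemms.

```python
from typing import List, Sequence, Tuple


def permute_by_expert(expert_ids: Sequence[int], num_experts: int) -> Tuple[List[int], List[int]]:
    """Return (perm, m_splits) for top-1 routing.

    perm[j] is the original index of the token placed at grouped row j;
    m_splits[e] is the number of tokens routed to expert e (0 if none).
    """
    m_splits = [0] * num_experts
    for e in expert_ids:
        m_splits[e] += 1
    # Python's sort is stable, so tokens keep their relative order within an expert
    perm = sorted(range(len(expert_ids)), key=lambda t: expert_ids[t])
    return perm, m_splits


def unpermute(rows: Sequence, perm: Sequence[int]) -> list:
    """Scatter grouped rows back to the original token order."""
    out = [None] * len(perm)
    for j, t in enumerate(perm):
        out[t] = rows[j]
    return out
```

The grouped input is then built as [tokens[t] for t in perm] and passed to GroupedLinear together with m_splits. The layer's output rows are mapped back with unpermute.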

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/module/grouped_linear.py
Lines: 1-992

Signature

class _GroupedLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, m_splits, ...): ...
    @staticmethod
    def backward(ctx, grad_output): ...

class GroupedLinear(TransformerEngineBaseModule):
    def __init__(
        self, num_gemms, in_features, out_features,
        bias=True, init_method=None, ...
    ): ...
    def forward(self, inp, m_splits, ...): ...

Import

from transformer_engine.pytorch import GroupedLinear

I/O Contract

Inputs

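Name Type Required Description
inp torch.Tensor Yes Forward input: tokens grouped contiguously by expert; last dimension is in_features
m_splits List[int] Yes Forward input: number of rows (tokens) assigned to each expert; len(m_splits) must equal num_gemms and sum(m_splits) must equal the number of input rows
num_gemms int Yes Constructor argument: number of independent linear transformations (experts)
in_features int Yes Constructor argument: input feature dimension
out_features int Yes Constructor argument: output feature dimension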
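The input contract above can be checked before calling the layer. This is a hypothetical pre-flight helper, not part of TransformerEngine. It assumes that any leading dimensions of inp are flattened into rows, so the row count is the product of all dimensions except the last one.

```python
from typing import Sequence


def check_grouped_linear_inputs(
    inp_shape: Sequence[int], m_splits: Sequence[int], num_gemms: int, in_features: int
) -> None:
    """Raise ValueError if shapes violate the GroupedLinear input contract (sketch)."""
    if inp_shape[-1] != in_features:
        raise ValueError(f"last dim {inp_shape[-1]} != in_features {in_features}")
    if len(m_splits) != num_gemms:
        raise ValueError(f"len(m_splits)={len(m_splits)} != num_gemms={num_gemms}")
    if any(m < 0 for m in m_splits):
        raise ValueError("m_splits entries must be non-negative")
    # Assumption: leading dimensions are flattened into rows before splitting
    num_rows = 1
    for d in inp_shape[:-1]:
        num_rows *= d
    if sum(m_splits) != num_rows:
        raise ValueError(f"sum(m_splits)={sum(m_splits)} != number of input rows {num_rows}")
```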

Outputs

Name Type Description
output torch.Tensor Per-expert outputs concatenated along the row dimension, in the same row order as inp; last dimension is out_features
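Because the output keeps the row layout of the input, expert i's results occupy a fixed half-open row range that follows from the prefix sums of m_splits. A small sketch follows. The helper name expert_row_ranges is hypothetical. With a real output tensor, expert i's block would be output[start:end].

```python
from itertools import accumulate
from typing import List, Sequence, Tuple


def expert_row_ranges(m_splits: Sequence[int]) -> List[Tuple[int, int]]:
    """Half-open [start, end) row range of each expert's block in the concatenated output."""
    ends = list(accumulate(m_splits))  # running totals give each block's end row
    starts = [0] + ends[:-1]           # each block starts where the previous one ended
    return list(zip(starts, ends))
```

For m_splits=[64, 32, 48, 56, 40, 28, 52, 44], expert 2 owns rows 96 to 143, and the last expert ends at row 364.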

Usage Examples

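import torch
from transformer_engine.pytorch import GroupedLinear

# Create a grouped linear layer for 8 experts
# (parameters are created on the CUDA device by default)
grouped_linear = GroupedLinear(
    num_gemms=8,
    in_features=1024,
    out_features=4096,
    bias=True,
)

# Tokens routed to each expert; len(m_splits) must equal num_gemms
m_splits = [64, 32, 48, 56, 40, 28, 52, 44]

# Input rows must already be grouped by expert; sum(m_splits) == 364 rows
tokens = torch.randn(sum(m_splits), 1024, device="cuda")

# Forward pass with per-expert token splits
# output: one row per input row, out_features (4096) columns
output = grouped_linear(tokens, m_splits=m_splits)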
