Implementation:NVIDIA TransformerEngine Ops ReduceScatter

From Leeroopedia


Field | Value
Sources | TransformerEngine
Domains | Deep_Learning, PyTorch, Distributed
Last Updated | 2026-02-07 14:00 GMT

Overview

Fusible operation that performs a reduce-scatter along the outer (first) dimension: tensors are summed across processes and each process keeps one slice of the result, as used in sequence-parallel configurations.

Description

In the forward pass, ReduceScatter sums the input tensors from all processes in the group and splits the result along the first dimension using torch.distributed.reduce_scatter_tensor, so each process receives one slice of the sum. The backward pass performs the matching all-gather along the first dimension, so each process receives the gradient for its full input. When the process group size is 1, both passes are no-ops.
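
The forward and backward semantics described above can be illustrated without a distributed setup. The following is a minimal single-process sketch, not the TransformerEngine implementation: it uses nested Python lists in place of tensors, and the names simulate_reduce_scatter and simulate_all_gather are hypothetical helpers introduced only for this illustration.

```python
from typing import List

Matrix = List[List[float]]


def simulate_reduce_scatter(per_rank_inputs: List[Matrix]) -> List[Matrix]:
    """Simulate a sum reduce-scatter along the first dimension.

    per_rank_inputs[i] is the input held by rank i. The return value is a
    list whose i-th entry is the output that rank i would receive.
    """
    world_size = len(per_rank_inputs)
    rows = len(per_rank_inputs[0])
    if rows % world_size != 0:
        raise ValueError("first dimension must be divisible by world size")
    # Step 1: element-wise sum of the inputs across all ranks.
    total = [
        [sum(vals) for vals in zip(*(inp[r] for inp in per_rank_inputs))]
        for r in range(rows)
    ]
    # Step 2: rank i keeps the i-th contiguous chunk of rows.
    chunk = rows // world_size
    return [total[i * chunk:(i + 1) * chunk] for i in range(world_size)]


def simulate_all_gather(per_rank_chunks: List[Matrix]) -> List[Matrix]:
    """Simulate an all-gather along the first dimension.

    Every rank receives the concatenation of all chunks, which is how the
    backward pass rebuilds a gradient with the full input shape.
    """
    gathered = [list(row) for chunk in per_rank_chunks for row in chunk]
    return [[list(row) for row in gathered] for _ in per_rank_chunks]


# Two simulated ranks, each holding a 2x2 input.
inputs = [[[1, 2], [3, 4]], [[10, 20], [30, 40]]]
outputs = simulate_reduce_scatter(inputs)
# outputs[0] == [[11, 22]] and outputs[1] == [[33, 44]]
```

With a single simulated rank the sum and the split both leave the input unchanged, which matches the no-op behavior for a process group of size 1.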

Usage

Used in sequence-parallel configurations within the operation fuser to reduce and distribute outputs after row-parallel GEMM operations.
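
To see why a reduction is needed after a row-parallel GEMM, consider the usual convention in which the weight is split along its input dimension, so each process computes only a partial sum of the output. The sketch below checks this on one process with plain Python lists; matmul, row_parallel_partials, and sum_partials are hypothetical helpers for illustration and do not come from TransformerEngine.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b for small nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def row_parallel_partials(x: Matrix, w: Matrix, world_size: int) -> List[Matrix]:
    """Split the inner dimension across ranks and compute one partial product per rank."""
    inner = len(w)
    if inner % world_size != 0:
        raise ValueError("inner dimension must be divisible by world size")
    shard = inner // world_size
    partials = []
    for rank in range(world_size):
        lo, hi = rank * shard, (rank + 1) * shard
        x_shard = [row[lo:hi] for row in x]  # columns of x owned by this rank
        w_shard = w[lo:hi]                   # rows of w owned by this rank
        partials.append(matmul(x_shard, w_shard))
    return partials


def sum_partials(partials: List[Matrix]) -> Matrix:
    """Element-wise sum of the per-rank partial outputs (the 'reduce' step)."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]
w = [[1, 0], [0, 1], [1, 1], [2, 0]]
partials = row_parallel_partials(x, w, world_size=2)
full = sum_partials(partials)
# full equals matmul(x, w); a reduce-scatter would then leave each rank
# with one slice of these rows instead of the whole result.
```

No single partial output equals the true product, so the sum across processes is required; scattering the summed rows is what leaves the result partitioned along the first dimension for the sequence-parallel layers that follow.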

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/ops/basic/reduce_scatter.py
Lines: 1-80

Signature

class ReduceScatter(BasicOperation):
    def __init__(self, process_group: Optional[torch.distributed.ProcessGroup] = None) -> None: ...
    def op_forward(self, ctx, input_, prev_op_grad_output_quantizer, next_op_input_quantizer) -> torch.Tensor: ...
    def op_backward(self, ctx, grad_output) -> Tuple[torch.Tensor, Tuple[()]]: ...

Import

from transformer_engine.pytorch.ops.basic.reduce_scatter import ReduceScatter

I/O Contract

Inputs

Name | Type | Required | Description
input_ | torch.Tensor | Yes | Input tensor passed to the forward pass (first dim must be divisible by world size)
process_group | ProcessGroup | No | Distributed process group, passed to the constructor (default: world group)

Outputs

Name | Type | Description
output | torch.Tensor | Reduced and scattered tensor (first dim divided by world size)
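
The shape contract in the tables above can be written down as a small helper. This is only a sketch of the arithmetic; reduce_scatter_output_shape is a hypothetical name used for illustration and is not part of TransformerEngine.

```python
from typing import Tuple


def reduce_scatter_output_shape(
    input_shape: Tuple[int, ...], world_size: int
) -> Tuple[int, ...]:
    """Return the per-process output shape for a reduce-scatter along dim 0."""
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    if not input_shape:
        raise ValueError("input must have at least one dimension")
    if input_shape[0] % world_size != 0:
        raise ValueError(
            f"first dimension {input_shape[0]} is not divisible by world size {world_size}"
        )
    # Only the first dimension shrinks; all trailing dimensions are unchanged.
    return (input_shape[0] // world_size,) + tuple(input_shape[1:])


shape = reduce_scatter_output_shape((8, 16), world_size=4)
# shape == (2, 16)
```

With world_size=1 the helper returns the input shape unchanged, consistent with the no-op case.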

Usage Examples

from transformer_engine.pytorch.ops.basic.reduce_scatter import ReduceScatter

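# tp_group is assumed to be an existing torch.distributed process group
# (for example, the tensor-parallel group); None selects the default group.
rs_op = ReduceScatter(process_group=tp_group)
# In the fuser pipeline, rs_op sums across the group and scatters along the first dimension.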
