Implementation:NVIDIA TransformerEngine Ops ReduceScatter
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Distributed |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Fusible operation that performs a reduce-scatter along the outer (first) dimension. It sums tensors across processes and splits the result among them, as required in sequence-parallel configurations.
Description
In the forward pass, ReduceScatter uses torch.distributed.reduce_scatter_tensor to sum tensors from all processes and split the result along the first dimension. The backward pass performs the adjoint collective, an all-gather along the same dimension, so every process receives the full-size gradient. When the process group size is 1, both passes are no-ops.
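To make the data movement concrete, here is a single-process simulation of N ranks in plain Python. Each rank's tensor is a list of rows, so dimension 0 is the outer dimension. The helper names `reduce_scatter` and `all_gather` are hypothetical. The sketch only mirrors the semantics described above and does not call torch.distributed or TransformerEngine.

```python
from typing import List

Row = List[float]
Tensor = List[Row]  # a 2-D "tensor" as a list of rows; dim 0 is the outer dimension

# Hypothetical helpers: a pure-Python stand-in for the collectives, not torch.distributed.


def reduce_scatter(inputs: List[Tensor]) -> List[Tensor]:
    """Simulated forward: sum the per-rank tensors elementwise, then give
    rank r the r-th contiguous chunk of rows along dim 0."""
    world_size = len(inputs)
    num_rows = len(inputs[0])
    if num_rows % world_size != 0:
        raise ValueError("dim 0 must be divisible by the process group size")
    # Elementwise sum across ranks, row by row
    summed = [
        [sum(vals) for vals in zip(*(t[i] for t in inputs))]
        for i in range(num_rows)
    ]
    chunk = num_rows // world_size
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def all_gather(shards: List[Tensor]) -> List[Tensor]:
    """Simulated backward: concatenate every rank's shard along dim 0 and
    give a copy of the full result to every rank."""
    gathered = [row for shard in shards for row in shard]
    return [[list(row) for row in gathered] for _ in shards]


# Two simulated ranks, each holding a 4x2 tensor
x0 = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
x1 = [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0], [70.0, 80.0]]
out = reduce_scatter([x0, x1])  # each rank gets a 2x2 shard of the sum
grads = all_gather(out)         # each rank gets the full 4x2 tensor back
```

With a single simulated rank, both helpers return the input unchanged, matching the no-op behavior for a process group of size 1.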
Usage
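Used within the operation fuser in sequence-parallel configurations. After a row-parallel GEMM, each process holds a partial sum of the full output; ReduceScatter completes the sum and shards the result along the first (sequence) dimension.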
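The toy example below assumes two hypothetical ranks and a Megatron-style row-parallel split, in which each rank holds a slice of the input features and the matching rows of the weight. It shows why a reduce-scatter follows the GEMM: the per-rank products are partial sums, and the reduce-scatter both completes the sum and leaves each rank its slice of the sequence dimension. All helpers are illustrative pure Python, not TransformerEngine APIs.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain dense matrix product a @ b."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def reduce_scatter_rows(partials: List[Matrix]) -> List[Matrix]:
    """Hypothetical stand-in for reduce-scatter: sum per-rank matrices
    elementwise, then split the rows evenly across ranks."""
    n = len(partials)
    rows, cols = len(partials[0]), len(partials[0][0])
    summed = [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]
    chunk = rows // n
    return [summed[r * chunk:(r + 1) * chunk] for r in range(n)]


# Full problem: X is (seq=4, hidden_in=2), W is (hidden_in=2, hidden_out=2)
X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
W = [[1.0, 0.5], [2.0, -1.0]]

# Row-parallel GEMM over 2 hypothetical ranks: rank r holds column r of X
# (its slice of the input features) and row r of W.
X_shards = [[[row[r]] for row in X] for r in range(2)]
W_shards = [[W[r]] for r in range(2)]
partials = [matmul(X_shards[r], W_shards[r]) for r in range(2)]  # partial sums of X @ W

# Reduce-scatter completes the GEMM (sum of partials) and shards the
# result along the sequence dimension for the sequence-parallel region.
Y_shards = reduce_scatter_rows(partials)
```

Concatenating `Y_shards` along dimension 0 recovers the unsharded product `X @ W`, which is what an all-gather would rebuild in the backward direction.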
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/ops/basic/reduce_scatter.py
- Lines: 1-80
Signature
```python
class ReduceScatter(BasicOperation):
    def __init__(self, process_group: Optional[torch.distributed.ProcessGroup] = None) -> None: ...
    def op_forward(self, ctx, input_, prev_op_grad_output_quantizer, next_op_input_quantizer) -> torch.Tensor: ...
    def op_backward(self, ctx, grad_output) -> Tuple[torch.Tensor, Tuple[()]]: ...
```
Import
```python
from transformer_engine.pytorch.ops.basic.reduce_scatter import ReduceScatter
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ | torch.Tensor | Yes | Input tensor; its first dimension must be divisible by the process group size |
| process_group | ProcessGroup | No | Process group for communication, passed to the constructor (default: world group) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Summed and scattered tensor; its first dimension is the input's first dimension divided by the process group size |
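The shape contract in the tables can be written as two small functions. These are hypothetical helpers for illustration, not part of TransformerEngine. Raising `ValueError` on a non-divisible first dimension is this sketch's own convention, and the library's actual error handling may differ.

```python
from typing import Tuple


def reduce_scatter_output_shape(input_shape: Tuple[int, ...], group_size: int) -> Tuple[int, ...]:
    """Forward shape rule (hypothetical helper): dim 0 is divided by the process group size."""
    if group_size < 1:
        raise ValueError("group_size must be positive")
    if input_shape[0] % group_size != 0:
        raise ValueError(
            f"dim 0 ({input_shape[0]}) is not divisible by group size ({group_size})"
        )
    return (input_shape[0] // group_size,) + tuple(input_shape[1:])


def all_gather_output_shape(grad_shape: Tuple[int, ...], group_size: int) -> Tuple[int, ...]:
    """Backward shape rule (hypothetical helper): dim 0 is multiplied back by the group size."""
    return (grad_shape[0] * group_size,) + tuple(grad_shape[1:])


# Example: sequence length 2048, batch 4, hidden 1024, 8 processes
shape = reduce_scatter_output_shape((2048, 4, 1024), 8)  # (256, 4, 1024)
```

A group size of 1 leaves every shape unchanged, consistent with both passes being no-ops in that case.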
Usage Examples
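```python
import torch.distributed as dist
from transformer_engine.pytorch.ops.basic.reduce_scatter import ReduceScatter

tp_group = dist.new_group()  # all ranks; assumes dist.init_process_group(...) was already called
rs_op = ReduceScatter(process_group=tp_group)
# In a fuser pipeline, rs_op sums across tp_group and scatters along the first dimension
```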