Implementation:NVIDIA TransformerEngine Ops AllReduce
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Distributed |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Fusible operation that performs all-reduce (sum) on tensors across all processes, assuming redundant computation on all ranks for the backward pass.
Description
AllReduce sums tensors from all processes in the forward pass using torch.distributed.all_reduce. Since the output is assumed to be used in operations redundantly computed on all processes, gradients are already identical between processes and the backward pass is a simple pass-through (identity). When the process group size is 1, both passes are no-ops.
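The forward and backward behavior described above can be approximated with a plain `torch.autograd.Function`. This is a minimal sketch, not the TransformerEngine implementation: the class name `AllReduceSketch` is hypothetical, it does not use `BasicOperation` or quantizers, and it assumes that an uninitialized `torch.distributed` should be treated like a single-process group.

```python
import torch
import torch.distributed as dist


class AllReduceSketch(torch.autograd.Function):
    """Hypothetical stand-in: sum across ranks in forward, identity in backward."""

    @staticmethod
    def forward(ctx, input_, group):
        # Work on a copy so the caller's tensor is not modified in place.
        out = input_.contiguous().clone()
        # Assumption: without an initialized process group, behave like world size 1.
        initialized = dist.is_available() and dist.is_initialized()
        if initialized and dist.get_world_size(group) > 1:
            # In-place sum over all ranks in the group.
            dist.all_reduce(out, op=dist.ReduceOp.SUM, group=group)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # Downstream ops are assumed to be computed redundantly on every rank,
        # so grad_output is already identical everywhere: pass it through.
        # The second return value is the (non-existent) gradient for `group`.
        return grad_output, None


x = torch.ones(2, 3, requires_grad=True)
y = AllReduceSketch.apply(x, None)  # None selects the default (world) group
y.sum().backward()
```

Run in a single process, the forward pass returns a copy of the input and `x.grad` equals the upstream gradient, which matches the no-op behavior described for a process group of size 1.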
Usage
Used in row-parallel tensor-parallel configurations within the operation fuser, where each rank computes a partial GEMM result and the partial results are summed across ranks.
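To see why a sum is the right reduction for row parallelism, the following single-process sketch simulates the ranks with a Python list. The shapes, the `world_size` value, and the `(in_features, out_features)` weight layout are illustrative assumptions, and the final `sum` stands in for what an all-reduce would compute across real ranks.

```python
import torch

torch.manual_seed(0)
world_size = 4  # hypothetical number of tensor-parallel ranks

x = torch.randn(8, 16, dtype=torch.float64)   # activations: (batch, in_features)
w = torch.randn(16, 32, dtype=torch.float64)  # weight: (in_features, out_features)

# Row parallelism shards the inner (in_features) dimension across ranks.
x_shards = x.chunk(world_size, dim=1)
w_shards = w.chunk(world_size, dim=0)

# Each simulated rank computes a partial GEMM of full output shape.
partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]

# Summing the partial results is what the all-reduce provides on real ranks.
reduced = torch.stack(partials).sum(dim=0)

assert torch.allclose(reduced, x @ w)
```

Each partial result already has the full output shape, so the reduction needs no concatenation, only an element-wise sum.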
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/ops/basic/all_reduce.py
- Lines: 1-63
Signature
class AllReduce(BasicOperation):
    def __init__(self, process_group=None, reduce_in_backward=True) -> None: ...
    def op_forward(self, ctx, input_, prev_op_grad_output_quantizer, next_op_input_quantizer) -> torch.Tensor: ...
    def op_backward(self, ctx, grad_output) -> Tuple[torch.Tensor, Tuple[()]]: ...
Import
from transformer_engine.pytorch.ops.basic.all_reduce import AllReduce
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ | torch.Tensor | Yes | Partial tensor to reduce |
| process_group | ProcessGroup | No | Distributed process group (default: world group) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Sum-reduced tensor across all processes |
Usage Examples
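from transformer_engine.pytorch.ops.basic.all_reduce import AllReduce

# tp_group is assumed to be an existing torch.distributed process group
# covering the tensor-parallel ranks.
all_reduce_op = AllReduce(process_group=tp_group)
# In the fuser pipeline, this op sums the partial results from all ranks.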