Implementation: Deepspeedai DeepSpeed LinearAllreduce Forward
Overview
Concrete implementation of tensor-parallel linear layers with automatic AllReduce communication, provided by the DeepSpeed library.
Implementation Type
Class (nn.Module subclasses for TP linear operations)
Detailed Description
LinearAllreduce is the row-parallel linear layer variant that performs AllReduce after its forward computation. LinearLayer is the column-parallel variant. Both inherit from TensorParallel_Layer (which itself inherits from nn.Module) and are transparent replacements for nn.Linear.
LinearAllreduce (row-parallel):
- `__init__(module, mp_group, **kwargs)`: Takes the original `nn.Linear` module, extracts its weight and bias, calls `_tp_partition()` to slice the weight along dim=-1 (columns) for each TP rank, and configures the weight as a TP parameter with gradient support.
- `forward(input)`: Computes `output = torch.matmul(input, weight.T)`, then applies `RowParallel.apply(mp_group, output, not is_training_mode())`, which performs the AllReduce. In training mode, the AllReduce is done through autograd-compatible custom functions, and the bias is added via non-inplace addition.
- `gather_params(params_list)`: AllGathers the partitioned weight across TP ranks by transposing, calling `dist.all_gather_into_tensor()`, and transposing back. Stores the partition in `data_partition` for later restoration.
- `_tp_partition(params_list)`: In training mode, uses `torch.chunk()` for even partitioning along dim=-1. In inference mode, uses uneven partitioning via `get_shard_size_list()`.
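The row-parallel arithmetic can be checked on a single process: each rank holds a dim=-1 slice of the weight and the matching slice of the input, and summing the per-rank partial matmuls reproduces what the AllReduce computes. A minimal sketch with plain tensors (no `mp_group` or collectives; the local `sum()` stands in for the AllReduce):

```python
import torch

torch.manual_seed(0)
tp = 2                             # simulated TP degree
x = torch.randn(4, 8)              # (batch, in_features)
w = torch.randn(6, 8)              # nn.Linear weight: (out_features, in_features)

# Reference: the unsharded linear layer
full = x @ w.T

# Row-parallel sharding: chunk weight (and the matching input slice) along dim=-1,
# mirroring what _tp_partition() does in training mode
w_shards = torch.chunk(w, tp, dim=-1)
x_shards = torch.chunk(x, tp, dim=-1)

# Each "rank" computes a partial matmul; summing the partials emulates
# the AllReduce performed by RowParallel.apply in LinearAllreduce.forward
partials = [xs @ ws.T for xs, ws in zip(x_shards, w_shards)]
reduced = sum(partials)

assert torch.allclose(full, reduced, atol=1e-5)
```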
LinearLayer (column-parallel):
- `__init__(module, mp_group, skip_partition, **kwargs)`: Similar to LinearAllreduce but partitions the weight along dim=0 (rows). Both weight and bias are TP parameters (the bias is also partitioned).
- `forward(input)`: When `tp_overlap_comm` is False, applies `ColumnParallel.apply(mp_group, input)` (identity in forward, AllReduce in backward), then computes `output = torch.matmul(input, weight.T)`. When overlap is enabled, uses `AsyncColumnParallel.apply()` for overlapped communication.
- `gather_params(params_list)`: AllGathers along dim=0 for all parameters (including bias).
- `_tp_partition(params_list)`: In training mode, uses `torch.chunk()` for even partitioning along dim=0.
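The column-parallel case needs no reduction in the forward pass: each rank's dim=0 weight and bias shard produces a slice of the output features, and concatenating the slices recovers the full output. A single-process sketch of that invariant (plain tensors, no process group):

```python
import torch

torch.manual_seed(0)
tp = 2
x = torch.randn(4, 8)              # (batch, in_features)
w = torch.randn(6, 8)              # (out_features, in_features)
b = torch.randn(6)

# Reference: the unsharded linear layer
full = x @ w.T + b

# Column-parallel sharding: chunk weight AND bias along dim=0 (output features),
# mirroring _tp_partition() in training mode
w_shards = torch.chunk(w, tp, dim=0)
b_shards = torch.chunk(b, tp, dim=0)

# Each "rank" produces a slice of the output; no AllReduce in forward
outs = [x @ ws.T + bs for ws, bs in zip(w_shards, b_shards)]

assert torch.allclose(full, torch.cat(outs, dim=-1), atol=1e-5)
```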
GatherReplacedLayerParams:
- A context manager that temporarily gathers full parameters from all TP ranks.
- `__init__(params, module, enabled)`: Accepts an iterable of parameters or a single parameter. Checks if any parameter has the `ds_is_replaced_module` attribute; disables itself if none do.
- `__enter__()`: Calls `params[0].gather_params(params)` to AllGather all parameters.
- `__exit__()`: Calls `params[0]._tp_partition(params)` to re-partition parameters back to their sharded form.
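The gather-then-repartition lifecycle can be sketched single-process. The class below is a hypothetical stand-in, not the DeepSpeed implementation: `torch.cat` on enter substitutes for `dist.all_gather_into_tensor()`, and `torch.chunk` on exit substitutes for `_tp_partition()`:

```python
import torch

class GatherSketch:
    """Hypothetical single-process analogue of GatherReplacedLayerParams:
    rebuilds the full tensor from shards on __enter__ (standing in for the
    AllGather) and re-chunks it on __exit__ (standing in for _tp_partition)."""

    def __init__(self, shards, dim=-1):
        self.shards = shards
        self.dim = dim
        self.full = None

    def __enter__(self):
        # "AllGather": concatenate the shards into the full parameter
        self.full = torch.cat(self.shards, dim=self.dim)
        return self.full

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Re-partition back to the sharded form on exit
        self.shards = list(torch.chunk(self.full, len(self.shards), dim=self.dim))
        self.full = None
        return False

w = torch.randn(6, 8)
shards = list(torch.chunk(w, 2, dim=-1))       # two (6, 4) shards
with GatherSketch(shards, dim=-1) as full:
    assert full.shape == (6, 8)                # full view available inside the block
assert shards[0].shape == (6, 4)               # sharded again after exit
```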
Code Reference
- Repository: https://github.com/deepspeedai/DeepSpeed
- File: deepspeed/module_inject/layers.py
  - Lines: L389-548 (LinearAllreduce and LinearLayer classes), L327-387 (GatherReplacedLayerParams)
- LinearAllreduce.forward: L403-408 -- performs matmul then RowParallel AllReduce
- LinearLayer.forward: L479-489 -- applies ColumnParallel then matmul
- GatherReplacedLayerParams: L327-387 -- context manager for parameter gathering
- Import:
from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer, GatherReplacedLayerParams
Parameters
LinearAllreduce constructor:
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| module | nn.Linear | Yes | — | The original linear layer to replace |
| mp_group | ProcessGroup | Yes | — | The TP communication group |
| **kwargs | keyword arguments | No | — | Additional args (e.g., name for shard tracking) |
LinearLayer constructor:
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| module | nn.Linear | Yes | — | The original linear layer to replace |
| mp_group | ProcessGroup | No | None | The TP communication group |
| skip_partition | bool | No | False | Skip initial weight partitioning (for from_weights factory) |
| **kwargs | keyword arguments | No | — | Additional args (e.g., name for shard tracking) |
GatherReplacedLayerParams constructor:
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| params | Iterable[torch.Tensor] or torch.Tensor | Yes | — | Parameters to gather |
| module | torch.nn.Module | Yes | — | Module owning the parameters |
| enabled | bool | No | True | Enable/disable the gathering behavior |
I/O
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | input | torch.Tensor | Standard forward pass input tensor |
| Output | output | torch.Tensor | Output with AllReduce applied (LinearAllreduce) or partitioned output (LinearLayer) |
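Composing the two layer types gives the standard tensor-parallel MLP pattern: a column-parallel layer (LinearLayer) produces a sharded intermediate, and the following row-parallel layer (LinearAllreduce) consumes it, with a single AllReduce at the end. A single-process sketch under those assumptions, where the final `sum()` stands in for the AllReduce:

```python
import torch

torch.manual_seed(0)
tp = 2
x = torch.randn(4, 8)              # (batch, hidden)
w1 = torch.randn(16, 8)            # column-parallel weight (LinearLayer role)
w2 = torch.randn(8, 16)            # row-parallel weight (LinearAllreduce role)

# Reference: unsharded two-layer MLP
full = torch.relu(x @ w1.T) @ w2.T

w1_shards = torch.chunk(w1, tp, dim=0)     # LinearLayer: split along dim=0
w2_shards = torch.chunk(w2, tp, dim=-1)    # LinearAllreduce: split along dim=-1

# Each "rank": column-parallel matmul -> activation -> row-parallel matmul.
# The intermediate activation stays sharded; only the final partial outputs
# are summed, emulating the single AllReduce in LinearAllreduce.forward.
partials = [torch.relu(x @ a.T) @ b.T for a, b in zip(w1_shards, w2_shards)]

assert torch.allclose(full, sum(partials), atol=1e-4)
```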
Usage Example
# Standard training loop -- TP is transparent
for batch in dataloader:
    outputs = engine(batch["input_ids"])
    loss = criterion(outputs, batch["labels"])
    engine.backward(loss)
    engine.step()
# Gathering full parameters for custom operations
from deepspeed.module_inject.layers import GatherReplacedLayerParams
with GatherReplacedLayerParams(
    list(engine.module.parameters(recurse=False)),
    engine.module,
    enabled=True,
):
    full_state = engine.module.state_dict()
Knowledge Sources
Relationships
Principle:Deepspeedai_DeepSpeed_Tensor_Parallel_Training
Metadata
- Workflow: AutoTP_Training
- Type: Implementation
- Last Updated: 2026-02-09 00:00 GMT