
Implementation:Deepspeedai DeepSpeed LinearAllreduce Forward

From Leeroopedia


Overview

A concrete tool from the DeepSpeed library for executing tensor-parallel linear layers with automatic AllReduce communication.

Implementation Type

Class (nn.Module subclasses for TP linear operations)

Detailed Description

LinearAllreduce is the row-parallel linear layer variant that performs AllReduce after its forward computation. LinearLayer is the column-parallel variant. Both inherit from TensorParallel_Layer (which itself inherits from nn.Module) and are transparent replacements for nn.Linear.

LinearAllreduce (row-parallel):

  • __init__(module, mp_group, **kwargs): Takes the original nn.Linear module, extracts its weight and bias, calls _tp_partition() to slice the weight along dim=-1 (columns) for each TP rank, and configures the weight as a TP parameter with gradient support.
  • forward(input): Computes output = torch.matmul(input, weight.T), then applies RowParallel.apply(mp_group, output, not is_training_mode()) which performs the AllReduce. In training mode, the AllReduce is done through autograd-compatible custom functions. Bias is added via non-inplace addition in training mode.
  • gather_params(params_list): AllGathers the partitioned weight across TP ranks by transposing, calling dist.all_gather_into_tensor(), and transposing back. Stores the partition in data_partition for later restoration.
  • _tp_partition(params_list): In training mode, uses torch.chunk() for even partitioning along dim=-1. In inference mode, uses uneven partitioning via get_shard_size_list().
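The row-parallel math above can be simulated on a single process. This is a hedged sketch, not the DeepSpeed implementation: `tp_size`, `weight_shards`, and `partials` are illustrative names, and a plain Python `sum` stands in for the RowParallel AllReduce across ranks.

```python
import torch

torch.manual_seed(0)
tp_size, in_features, out_features = 2, 8, 4

full_weight = torch.randn(out_features, in_features)  # nn.Linear layout [out, in]
x = torch.randn(3, in_features)

# _tp_partition slices the weight along dim=-1 (the input dimension);
# the matching slice of the input lives on the same rank.
weight_shards = torch.chunk(full_weight, tp_size, dim=-1)
input_shards = torch.chunk(x, tp_size, dim=-1)

# Each rank computes a partial matmul; the AllReduce sums the partials.
partials = [torch.matmul(xi, wi.T) for xi, wi in zip(input_shards, weight_shards)]
output = sum(partials)  # stands in for RowParallel's AllReduce

# The summed partials equal the unpartitioned matmul.
assert torch.allclose(output, torch.matmul(x, full_weight.T), atol=1e-5)
```

This shows why the AllReduce must come after the matmul: each rank only sees part of the input dimension, so its local result is an incomplete sum over input features.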

LinearLayer (column-parallel):

  • __init__(module, mp_group, skip_partition, **kwargs): Similar to LinearAllreduce but partitions weight along dim=0 (rows). Both weight and bias are TP parameters (bias is also partitioned).
  • forward(input): When tp_overlap_comm is False, applies ColumnParallel.apply(mp_group, input) (identity in forward, AllReduce in backward), then computes output = torch.matmul(input, weight.T). When overlap is enabled, uses AsyncColumnParallel.apply() for overlapped communication.
  • gather_params(params_list): AllGathers along dim=0 for all parameters (including bias).
  • _tp_partition(params_list): In training mode, uses torch.chunk() for even partitioning along dim=0.
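The column-parallel case can be sketched the same way. Again a hedged single-process simulation with illustrative names: each "rank" sees the full input (ColumnParallel is the identity in forward) and produces its own slice of the output features, so no forward AllReduce is needed.

```python
import torch

torch.manual_seed(0)
tp_size, in_features, out_features = 2, 8, 4

full_weight = torch.randn(out_features, in_features)
full_bias = torch.randn(out_features)
x = torch.randn(3, in_features)

# _tp_partition chunks both weight and bias along dim=0 (the output dimension).
weight_shards = torch.chunk(full_weight, tp_size, dim=0)
bias_shards = torch.chunk(full_bias, tp_size, dim=0)

# Every rank consumes the full input and emits a slice of the output features.
out_shards = [torch.matmul(x, w.T) + b for w, b in zip(weight_shards, bias_shards)]

# Concatenating the per-rank slices recovers the unpartitioned result.
output = torch.cat(out_shards, dim=-1)
assert torch.allclose(output, torch.matmul(x, full_weight.T) + full_bias, atol=1e-5)
```

This also shows why the bias is partitioned here but not in the row-parallel layer: each output-feature slice owns its own bias entries.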

GatherReplacedLayerParams:

  • A context manager that temporarily gathers full parameters from all TP ranks.
  • __init__(params, module, enabled): Accepts an iterable of parameters or a single parameter. Checks if any parameter has the ds_is_replaced_module attribute; disables if not.
  • __enter__(): Calls params[0].gather_params(params) to AllGather all parameters.
  • __exit__(): Calls params[0]._tp_partition(params) to re-partition parameters back to their sharded form.
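The gather-on-enter / re-partition-on-exit pattern can be sketched without a distributed setup. This is a hedged, hypothetical `gather_params` helper, not the DeepSpeed API: `shards` stands in for the per-rank partitions, `torch.cat` replaces dist.all_gather_into_tensor(), and the saved local shard plays the role of data_partition.

```python
import torch
from contextlib import contextmanager

@contextmanager
def gather_params(param, shards, dim):
    # __enter__: remember the local shard, then expose the full parameter.
    saved = param.data                        # analogous to data_partition
    param.data = torch.cat(shards, dim=dim)   # stands in for the AllGather
    try:
        yield param
    finally:
        # __exit__: re-partition back to the sharded form.
        param.data = saved

full = torch.randn(4, 8)
shards = list(torch.chunk(full, 2, dim=0))
param = torch.nn.Parameter(shards[0].clone())  # this "rank" holds shard 0

with gather_params(param, shards, dim=0) as p:
    assert p.shape == (4, 8)   # full parameter visible inside the context
assert param.shape == (2, 8)   # sharded again after exit
```

Swapping a parameter's `.data` in place like this is the same trick that lets state_dict() see full tensors inside the context while the module keeps holding the same Parameter objects.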

Code Reference

  • Repository: https://github.com/deepspeedai/DeepSpeed
  • File: deepspeed/module_inject/layers.py
  • Lines: L389-548 (LinearAllreduce and LinearLayer classes), L327-387 (GatherReplacedLayerParams)
  • LinearAllreduce.forward: L403-408 -- performs matmul then RowParallel AllReduce
  • LinearLayer.forward: L479-489 -- applies ColumnParallel then matmul
  • GatherReplacedLayerParams: L327-387 -- context manager for parameter gathering
  • Import: from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer, GatherReplacedLayerParams

Parameters

LinearAllreduce constructor:

| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| module | nn.Linear | Yes | | The original linear layer to replace |
| mp_group | ProcessGroup | Yes | | The TP communication group |
| **kwargs | keyword arguments | No | | Additional args (e.g., name for shard tracking) |

LinearLayer constructor:

| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| module | nn.Linear | Yes | | The original linear layer to replace |
| mp_group | ProcessGroup | No | None | The TP communication group |
| skip_partition | bool | No | False | Skip initial weight partitioning (for the from_weights factory) |
| **kwargs | keyword arguments | No | | Additional args (e.g., name for shard tracking) |

GatherReplacedLayerParams constructor:

| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| params | Iterable[torch.Tensor] or torch.Tensor | Yes | | Parameters to gather |
| module | torch.nn.Module | Yes | | Module owning the parameters |
| enabled | bool | No | True | Enable/disable the gathering behavior |

I/O

| Direction | Name | Type | Description |
|---|---|---|---|
| Input | input | torch.Tensor | Standard forward-pass input tensor |
| Output | output | torch.Tensor | Output with AllReduce applied (LinearAllreduce) or the partitioned output slice (LinearLayer) |

Usage Example

# Standard training loop -- TP is transparent
for batch in dataloader:
    outputs = engine(batch["input_ids"])
    loss = criterion(outputs, batch["labels"])
    engine.backward(loss)
    engine.step()

# Gathering full parameters for custom operations
from deepspeed.module_inject.layers import GatherReplacedLayerParams

with GatherReplacedLayerParams(
    list(engine.module.parameters(recurse=False)),
    engine.module,
    enabled=True
):
    full_state = engine.module.state_dict()

Knowledge Sources

Relationships

Principle:Deepspeedai_DeepSpeed_Tensor_Parallel_Training

Metadata

  • Workflow: AutoTP_Training
  • Type: Implementation
  • Last Updated: 2026-02-09 00:00 GMT
