Principle:Huggingface Optimum Parallel Layer Replacement
Overview
Process of replacing standard PyTorch layers with distributed-aware parallel layer implementations that partition weights and manage collective communication.
Description
After annotation, each annotated layer is replaced with its parallel counterpart. The replacement pass transforms the model from a single-device architecture into a distributed one by swapping standard layers with parallel variants:
- ColumnParallelLinear: Splits the weight matrix along the output dimension. Each GPU computes a portion of the output features. Optionally all-gathers the output to produce the full result.
- RowParallelLinear: Splits the weight matrix along the input dimension. Each GPU computes a partial sum, and an all-reduce produces the final output.
- VocabParallelEmbedding: Splits the embedding table along the vocabulary axis. Each GPU holds embeddings for a subset of token IDs; lookups for IDs outside a GPU's range are masked to zero locally, and an all-reduce sums the partial results into the full embedding output.
These replacements also insert the necessary collective communication operations (all-reduce, all-gather, scatter) as differentiable operations in the computation graph.
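The vocab-parallel lookup described above can be simulated on a single process. The sketch below (names like `vocab_parallel_lookup` and `n_shards` are illustrative, not the Optimum API) shards the table, masks out-of-range IDs, and sums the partial lookups in place of the all-reduce:

```python
import numpy as np

def vocab_parallel_lookup(table, token_ids, n_shards):
    """Single-process simulation of a vocab-parallel embedding lookup."""
    vocab, dim = table.shape
    shard_size = vocab // n_shards          # assume vocab divides evenly
    partial_sums = np.zeros((len(token_ids), dim))
    for rank in range(n_shards):
        lo, hi = rank * shard_size, (rank + 1) * shard_size
        shard = table[lo:hi]                       # this rank's table slice
        mask = (token_ids >= lo) & (token_ids < hi)
        local_ids = np.where(mask, token_ids - lo, 0)  # clamp foreign IDs
        out = shard[local_ids]
        out[~mask] = 0.0                    # mask out-of-range lookups
        partial_sums += out                 # stands in for the all-reduce
    return partial_sums

rng = np.random.default_rng(0)
table = rng.normal(size=(8, 4))
ids = np.array([0, 3, 5, 7])
# Each ID is owned by exactly one shard, so the summed partials equal
# a plain lookup into the full table.
assert np.allclose(vocab_parallel_lookup(table, ids, n_shards=2), table[ids])
```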
Usage
This pass runs automatically after ParallelLayerAnnotatePass. It consumes the layer annotations and produces the final distributed model with parallel layers in place.
Theoretical Basis
This pass implements Megatron-LM tensor parallelism. The mathematical basis for each parallel layer type is:
ColumnParallelLinear
The weight matrix A is split column-wise into N partitions [A_1, A_2, ..., A_N]:
Y_i = X * A_i
Each GPU i computes its partition independently. If gather_output is True, the results are concatenated: Y = [Y_1, Y_2, ..., Y_N].
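This identity can be checked on one process: splitting A along its output (column) dimension and concatenating the per-shard results reproduces the unsharded product exactly. A minimal numpy sketch, with the concatenation standing in for the gather:

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(2, 6))        # activations, replicated on every GPU
A = rng.normal(size=(6, 8))        # full weight, 8 output features
A_shards = np.split(A, 4, axis=1)  # A_1..A_4, one per simulated GPU

Y_shards = [X @ A_i for A_i in A_shards]   # each rank computes independently
Y = np.concatenate(Y_shards, axis=1)       # the gather_output=True case

assert np.allclose(Y, X @ A)
```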
RowParallelLinear
The weight matrix A is split row-wise, and the input X is split correspondingly:
Y = sum(X_i * A_i) via all-reduce
Each GPU computes a partial product, and the all-reduce operation sums them to produce the final output.
Differentiable Communication
Communication is made differentiable for training using custom autograd functions:
| Function | Forward Pass | Backward Pass |
|---|---|---|
| CopyToModelParallelRegion | Identity (pass through) | All-reduce gradients |
| ReduceFromModelParallelRegion | All-reduce activations | Identity (pass through) |
| ScatterToModelParallelRegion | Scatter (split input) | All-gather gradients |
| GatherFromModelParallelRegion | All-gather activations | Scatter gradients |
This ensures that gradient flow is correct during backpropagation in distributed training.
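The first table row can be verified numerically: because X is copied to every rank (identity forward), the true gradient with respect to X is the sum of each rank's local gradient, which is exactly the backward all-reduce. A single-process numpy check of this duality, using the column-parallel layer and a loss L = sum(X @ A):

```python
import numpy as np

rng = np.random.default_rng(3)
X = rng.normal(size=(2, 6))
A = rng.normal(size=(6, 8))
A_shards = np.split(A, 4, axis=1)   # column-parallel weight shards

# For L = sum(Y), dL/dY_i is all ones on each rank's output slice.
grad_Y = np.ones((2, 2))
local_grads = [grad_Y @ A_i.T for A_i in A_shards]  # per-rank dL/dX
grad_X = sum(local_grads)           # the backward all-reduce

# Reference gradient from the unsharded computation: dL/dX = 1 @ A.T
assert np.allclose(grad_X, np.ones((2, 8)) @ A.T)
```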
Metadata
| Key | Value |
|---|---|
| Source Paper | Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism |
Related
- Implemented by: Implementation:Huggingface_Optimum_ParallelLayerReplacePass_Run
- Depends on: Principle:Huggingface_Optimum_Parallel_Layer_Annotation
- Used by: Principle:Huggingface_Optimum_Sharded_Weight_Loading