Principle:Huggingface Optimum Parallel Axis Solving
Overview
Algorithm for automatically determining which tensor dimensions can be partitioned across GPUs by analyzing data flow through the computation graph.
Description
The parallel axis solver analyzes the model's FX computation graph to determine, for each tensor operation, which axis (if any) can be split across tensor-parallel ranks. It operates in three phases:
- Decomposition: High-level operations are decomposed into ATen primitives via `decompose_and_functionalize`, providing a uniform representation for analysis. Operations like `F.scaled_dot_product_attention` and `F.cross_entropy` are preserved as leaf functions.
- Forward propagation: Parallel axis information is propagated forward through the graph using operation-specific rules from the `op_registry`. Each ATen operator defines how parallel axes transform through it (e.g., for matrix multiplication, the output parallel axis depends on which input dimension is partitioned).
- Backtracking: When conflicts arise (an operation receives inputs with incompatible parallel axis assignments), the solver backtracks to find a globally consistent assignment.
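The propagate-then-backtrack loop can be sketched in miniature. This is a toy illustration, not Optimum's implementation: the graph encoding, the `propagate` rules, and the exhaustive search over input assignments are all simplifying assumptions.

```python
from itertools import product

# Toy dataflow graph in topological order; each entry is
# (name, op, input_names). The ops stand in for ATen primitives.
GRAPH = [
    ("x", "input", []),
    ("w", "input", []),
    ("y", "mm",    ["x", "w"]),   # y = x @ w
    ("z", "add",   ["y", "x"]),   # residual-style add forces y and x to agree
]

# Candidate parallel axes for each graph input (None = replicated).
# dim 1 is listed first for x so the solver has something to backtrack from.
INPUT_CANDIDATES = {"x": [1, 0], "w": [None]}

def propagate(op, in_axes):
    """Forward rule: output axis implied by the input axes, or raise on conflict."""
    if op == "mm":
        a_axis, _ = in_axes
        if a_axis == 1:            # contraction dim split -> output replicated
            return None
        if a_axis == 0:            # row-parallel A -> output split on dim 0
            return 0
        return None
    if op == "add":
        a_axis, b_axis = in_axes
        if a_axis != b_axis:       # elementwise inputs must carry the same axis
            raise ValueError("incompatible parallel axes")
        return a_axis
    raise ValueError(f"no rule for {op}")

def solve(graph, input_candidates):
    """Try input assignments; backtrack whenever propagation hits a conflict."""
    names = list(input_candidates)
    for choice in product(*(input_candidates[n] for n in names)):
        axes = dict(zip(names, choice))
        try:
            for name, op, ins in graph:
                if op != "input":
                    axes[name] = propagate(op, [axes[i] for i in ins])
        except ValueError:
            continue               # contradiction: try the next assignment
        return axes                # globally consistent assignment found
    return None
```

Here the first assignment (x split on dim 1) fails at the `add` node, so the solver backtracks and settles on x split on dim 0, which propagates cleanly through the whole graph.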
Usage
Use as the first pass in the parallelization pipeline, after FX tracing. The solver annotates each node in the FX graph with its parallel axis assignment, which is consumed by the subsequent ParallelLayerAnnotatePass.
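The handoff between passes can be pictured as writing into per-node metadata. This sketch uses plain dicts and an assumed key name (`"parallel_axis"`); real `torch.fx` nodes expose a comparable `node.meta` dict, but the details here are illustrative.

```python
# Minimal stand-in for FX graph nodes: each carries a metadata dict.
nodes = [
    {"name": "embed",  "meta": {}},
    {"name": "linear", "meta": {}},
]

# Solver pass: store the solved axis (None = replicated) per node.
# The "parallel_axis" key name is an assumption for illustration.
solution = {"embed": 0, "linear": None}
for node in nodes:
    node["meta"]["parallel_axis"] = solution[node["name"]]

# A downstream pass (a ParallelLayerAnnotatePass stand-in) reads the same key.
def partitioned_nodes(nodes):
    return [n["name"] for n in nodes if n["meta"]["parallel_axis"] is not None]
```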
Theoretical Basis
The solver performs constraint propagation on a dataflow graph: each ATen operator defines rules for how parallel axes propagate through it. The solver treats the resulting system as a constraint satisfaction problem, using forward propagation with backtracking when contradictions arise.
Key propagation rules include:
| Operation | Propagation Rule |
|---|---|
| `aten.mm(A, B)` | If A is split on dim 1 (columns), the output is not split. If A is split on dim 0 (rows), the output is split on dim 0. |
| `aten.t(A)` | The parallel axis flips: dim 0 becomes dim 1 and vice versa. |
| `aten.add(A, B)` | Both inputs must have the same parallel axis; the output inherits it. |
| `aten.reshape(A, shape)` | The axis mapping depends on how dimensions are combined or split. |
| `aten.embedding(W, idx)` | If W is split on dim 0 (vocabulary), special vocab-parallel handling applies. |
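The rules in the table can be sketched as a small registry mapping op names to propagation functions. The function names, return conventions (None = replicated), and the `"vocab_parallel"` marker are assumptions for illustration; this mirrors the structure of the rules, not Optimum's actual registry API.

```python
class AxisConflict(Exception):
    """Inputs carry parallel axes that no rule can reconcile."""

def mm_rule(a_axis, b_axis):
    if a_axis == 1:                  # A split on columns -> output replicated
        return None
    if a_axis == 0:                  # A split on rows -> output split on dim 0
        return 0
    return None

def t_rule(a_axis):
    return {0: 1, 1: 0}.get(a_axis)  # transpose flips dim 0 <-> dim 1

def add_rule(a_axis, b_axis):
    if a_axis != b_axis:
        raise AxisConflict("elementwise inputs must share a parallel axis")
    return a_axis

def embedding_rule(w_axis, idx_axis):
    # Weight split on dim 0 partitions the vocabulary across ranks; a marker
    # (hypothetical) tells later passes to emit vocab-parallel lookup logic.
    return "vocab_parallel" if w_axis == 0 else None

OP_REGISTRY = {
    "aten.mm": mm_rule,
    "aten.t": t_rule,
    "aten.add": add_rule,
    "aten.embedding": embedding_rule,
}
```

For example, `OP_REGISTRY["aten.t"](0)` evaluates to 1, matching the transpose row of the table.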
This approach is inspired by Megatron-LM tensor parallelism, which established the pattern of splitting attention heads and MLP hidden dimensions across GPUs. However, while Megatron-LM uses manual parallelization rules, this solver automates the process by analyzing the computation graph.
Metadata
| Key | Value |
|---|---|
| Source Repository | Huggingface Optimum |
| Source Paper | Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism |
| Domains | Distributed_Computing, Tensor_Parallelism |
Related
- Implemented by: Implementation:Huggingface_Optimum_ParallelAxisSolverPass_Run
- Depends on: Principle:Huggingface_Optimum_Parameter_Metadata_Initialization
- Used by: Principle:Huggingface_Optimum_Parallel_Layer_Annotation