Implementation: Huggingface Optimum ParallelAxisSolverPass Run
Overview
Analyzes the model's FX computation graph to determine which tensor dimensions can be partitioned across tensor-parallel ranks. Uses a 3-phase algorithm: decomposition, forward propagation, and backtracking.
Source
| Property | Value |
|---|---|
| Pass File | optimum/fx/parallelization/passes.py |
| Pass Lines | L144-230 |
| Decomposition File | optimum/fx/parallelization/decomp.py |
| Decomposition Lines | L197-225 |
Signature
```python
class ParallelAxisSolverPass:
    def run(
        self,
        graph_module: GraphModule,
        ctx: ParallelExecutionCtx,
        config: Config,
    ) -> GraphModule:
        """Solve parallel axis assignments for all nodes in the graph."""
        ...
```
Import
```python
from optimum.fx.parallelization.passes import ParallelAxisSolverPass
```
Supporting Function: decompose_and_functionalize
```python
# optimum/fx/parallelization/decomp.py L197-225
def decompose_and_functionalize(
    graph_module: GraphModule,
    decomposition_table: Dict = core_aten_decompositions(),
    leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention, F.cross_entropy],
) -> Callable:
    """Decompose high-level operations to ATen primitives and functionalize.

    Args:
        graph_module: The FX graph module to decompose.
        decomposition_table: Mapping from high-level ops to ATen decompositions.
        leaf_function_targets: Functions to preserve as leaf nodes (not decomposed).

    Returns:
        A callable that produces the decomposed and functionalized graph.
    """
    ...
```
3-Phase Algorithm
Phase 1: Decomposition
High-level PyTorch operations are decomposed into ATen primitives using decompose_and_functionalize. This provides a uniform, low-level representation where each operation has well-defined parallel axis propagation rules.
- Uses `core_aten_decompositions()` as the default decomposition table.
- Preserves `F.scaled_dot_product_attention` and `F.cross_entropy` as leaf functions (these have custom parallel handling).
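The decomposition-with-leaf-preservation idea can be illustrated with a minimal, torch-free sketch. The table contents, op names, and the `decompose` helper below are hypothetical stand-ins for the real ATen decomposition machinery; only the behavior (expand via a table, keep leaf targets intact) mirrors the pass.

```python
# Hypothetical sketch: high-level ops expand into primitive ops via a
# decomposition table, while designated leaf targets are left untouched.
# Op names and the table are illustrative, not the real ATen entries.
DECOMP_TABLE = {
    "softmax": ["exp", "sum", "div"],
    "layer_norm": ["mean", "sub", "var", "div", "mul", "add"],
}
LEAF_TARGETS = {"scaled_dot_product_attention"}

def decompose(ops, table=DECOMP_TABLE, leaves=LEAF_TARGETS):
    """Expand each op via the table, preserving leaf targets as-is."""
    out = []
    for op in ops:
        if op in leaves or op not in table:
            out.append(op)  # leaf or already-primitive op: keep unchanged
        else:
            out.extend(table[op])  # high-level op: replace with primitives
    return out

result = decompose(["softmax", "scaled_dot_product_attention"])
```

After decomposition, only the preserved leaf functions require custom parallel-axis rules; everything else reduces to a small set of primitives with known propagation behavior.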
Phase 2: Forward Propagation
The solver iterates through the decomposed graph in topological order. For each node:
- Look up the node's operation in the `op_registry`.
- Apply the registered propagation rule to determine the output's parallel axis based on the inputs' parallel axes.
- Annotate the node with its computed parallel axis.
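The forward-propagation loop can be sketched as follows. This is a toy model, not the library's code: the graph representation, the `PROPAGATION_RULES` registry, and the axis convention (`None` = replicated, `-1` = sharded on the last dim) are all assumptions made for illustration.

```python
# Toy "graph": nodes in topological order, each with an op and input names.
graph = [
    {"name": "x",    "op": "placeholder", "inputs": []},
    {"name": "w",    "op": "placeholder", "inputs": []},
    {"name": "mm",   "op": "matmul",      "inputs": ["x", "w"]},
    {"name": "relu", "op": "relu",        "inputs": ["mm"]},
]

# Hypothetical rule registry: given the inputs' parallel axes, return the
# output's parallel axis (None = replicated).
PROPAGATION_RULES = {
    "placeholder": lambda axes: None,
    # Elementwise ops keep their single input's axis.
    "relu": lambda axes: axes[0],
    # matmul with a column-sharded second operand yields an output sharded
    # on its last dim; otherwise it keeps the first operand's axis.
    "matmul": lambda axes: -1 if axes[1] == -1 else axes[0],
}

def propagate(graph, seed_axes):
    """Annotate every node with a parallel axis in topological order."""
    axes = dict(seed_axes)
    for node in graph:
        if node["name"] in axes:
            continue  # seeded nodes (e.g. sharded parameters) are fixed
        in_axes = [axes[name] for name in node["inputs"]]
        axes[node["name"]] = PROPAGATION_RULES[node["op"]](in_axes)
    return axes

axes = propagate(graph, {"x": None, "w": -1})  # w is column-sharded
```

Seeding `w` with axis `-1` models a column-parallel weight; the sharding then flows through `matmul` and the elementwise `relu` without communication.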
Phase 3: Backtracking
When the forward pass encounters a conflict (e.g., an addition receives one input split on dim 0 and another split on dim 1), the solver:
- Identifies the conflicting node.
- Backtracks to an earlier node where a different parallel axis choice is available.
- Re-propagates forward from that point.
- Repeats until a globally consistent solution is found or determines that no valid assignment exists.
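The backtracking loop above amounts to a depth-first search over per-node axis choices with conflict checks on partial assignments. The sketch below is a generic illustration of that search, not the pass's actual implementation; the node names, `choices`, and `conflicts` predicates are invented for the example.

```python
def solve(nodes, choices, conflicts, assignment=None, idx=0):
    """Depth-first search: try each axis choice per node, backtrack on conflict."""
    assignment = {} if assignment is None else assignment
    if idx == len(nodes):
        return assignment  # every node assigned consistently
    node = nodes[idx]
    for axis in choices[node]:
        assignment[node] = axis
        # A conflict is a predicate over the partial assignment, e.g.
        # "both inputs of an add must be split on the same dim".
        if not any(conflict(assignment) for conflict in conflicts):
            result = solve(nodes, choices, conflicts, assignment, idx + 1)
            if result is not None:
                return result
        del assignment[node]  # undo and try the next axis choice
    return None  # no globally consistent assignment exists

# Hypothetical example: 'add' consumes 'a' and 'b', so all three must agree.
nodes = ["a", "b", "add"]
choices = {"a": [0, 1], "b": [1], "add": [0, 1]}
conflicts = [
    lambda asg: "a" in asg and "b" in asg and asg["a"] != asg["b"],
    lambda asg: "add" in asg and "a" in asg and asg["add"] != asg["a"],
]
solution = solve(nodes, choices, conflicts)
```

Here the first attempt (`a=0`) dead-ends because `b` only admits axis 1, so the search backtracks to `a`, picks axis 1, and completes. A `None` return corresponds to the solver determining that no valid assignment exists.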
Input / Output
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | graph_module | GraphModule | The FX-traced model graph to analyze. |
| Input | ctx | ParallelExecutionCtx | Execution context containing device and process group information. |
| Input | config | Config | Parallelization configuration. |
| Output | (return) | GraphModule | The same graph module with parallel axis annotations on each node. |
Example Usage
```python
from optimum.fx.parallelization.passes import ParallelAxisSolverPass

solver_pass = ParallelAxisSolverPass()
graph_module = solver_pass.run(graph_module, ctx, config)

# After running, each node has parallel axis metadata
for node in graph_module.graph.nodes:
    if hasattr(node, 'meta') and 'parallel_axis' in node.meta:
        print(f"{node.name}: parallel_axis={node.meta['parallel_axis']}")
```