Implementation:Huggingface Optimum ParallelAxisSolverPass Run

From Leeroopedia

Overview

ParallelAxisSolverPass.run analyzes the model's FX computation graph to determine which tensor dimensions can be partitioned across tensor-parallel ranks. It uses a three-phase algorithm: decomposition, forward propagation, and backtracking.

Source

Property              Value
Pass File             optimum/fx/parallelization/passes.py
Pass Lines            L144-230
Decomposition File    optimum/fx/parallelization/decomp.py
Decomposition Lines   L197-225

Signature

class ParallelAxisSolverPass:
    def run(
        self,
        graph_module: GraphModule,
        ctx: ParallelExecutionCtx,
        config: Config,
    ) -> GraphModule:
        """Solve parallel axis assignments for all nodes in the graph."""
        ...

Import

from optimum.fx.parallelization.passes import ParallelAxisSolverPass

Supporting Function: decompose_and_functionalize

# optimum/fx/parallelization/decomp.py L197-225
def decompose_and_functionalize(
    graph_module: GraphModule,
    decomposition_table: Dict = core_aten_decompositions(),
    leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention, F.cross_entropy],
) -> Callable:
    """Decompose high-level operations to ATen primitives and functionalize.

    Args:
        graph_module: The FX graph module to decompose.
        decomposition_table: Mapping from high-level ops to ATen decompositions.
        leaf_function_targets: Functions to preserve as leaf nodes (not decomposed).

    Returns:
        A callable that produces the decomposed and functionalized graph.
    """
    ...

3-Phase Algorithm

Phase 1: Decomposition

High-level PyTorch operations are decomposed into ATen primitives using decompose_and_functionalize. This provides a uniform, low-level representation where each operation has well-defined parallel axis propagation rules.

  • Uses core_aten_decompositions() as the default decomposition table.
  • Preserves F.scaled_dot_product_attention and F.cross_entropy as leaf functions (these have custom parallel handling).
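The effect of a decomposition table combined with preserved leaf functions can be sketched in plain Python. This is a toy illustration only, not the library's implementation: the op names are strings, and `TOY_DECOMPOSITIONS`, `TOY_LEAVES`, and `decompose` are hypothetical names introduced here.

```python
from typing import Dict, List, Set

# Hypothetical table: high-level op name -> sequence of primitive op names.
TOY_DECOMPOSITIONS: Dict[str, List[str]] = {
    "linear": ["transpose", "matmul", "add"],
    "gelu": ["mul", "erf", "add", "mul"],
    "scaled_dot_product_attention": ["matmul", "softmax", "matmul"],
}

# Hypothetical leaf set, playing the role of leaf_function_targets.
TOY_LEAVES: Set[str] = {"scaled_dot_product_attention", "cross_entropy"}


def decompose(ops: List[str], table: Dict[str, List[str]], leaves: Set[str]) -> List[str]:
    """Expand each op through the table unless it is a preserved leaf."""
    out: List[str] = []
    for op in ops:
        if op in leaves or op not in table:
            out.append(op)  # leaf or already primitive: keep as-is
        else:
            out.extend(table[op])  # replace with its primitive sequence
    return out


program = ["linear", "scaled_dot_product_attention", "gelu"]
decomposed = decompose(program, TOY_DECOMPOSITIONS, TOY_LEAVES)
```

Note that the attention op stays intact even though the toy table has an entry for it, because the leaf check takes precedence over the table lookup.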

Phase 2: Forward Propagation

The solver iterates through the decomposed graph in topological order. For each node:

  1. Look up the node's operation in the op_registry.
  2. Apply the registered propagation rule to determine the output's parallel axis based on the inputs' parallel axes.
  3. Annotate the node with its computed parallel axis.
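The three steps above can be sketched as a small, self-contained loop. This is a simplified model under stated assumptions, not the code in passes.py: `Node`, `OP_REGISTRY`, `elementwise_rule`, `transpose_rule`, and `propagate` are hypothetical names, tensors are assumed to be 2-D, and `None` stands for "not partitioned".

```python
from typing import Callable, Dict, List, NamedTuple, Optional

Axis = Optional[int]  # None means the tensor is not partitioned


class Node(NamedTuple):
    name: str
    op: str
    inputs: List[str]


def elementwise_rule(in_axes: List[Axis]) -> Axis:
    """All partitioned inputs must agree; unpartitioned inputs are ignored."""
    sharded = {a for a in in_axes if a is not None}
    if len(sharded) > 1:
        raise ValueError(f"conflicting parallel axes: {sorted(sharded)}")
    return sharded.pop() if sharded else None


def transpose_rule(in_axes: List[Axis]) -> Axis:
    """Swapping dims 0 and 1 of a 2-D tensor moves the partitioned axis."""
    (axis,) = in_axes
    return None if axis is None else 1 - axis


# Hypothetical registry: op name -> propagation rule.
OP_REGISTRY: Dict[str, Callable[[List[Axis]], Axis]] = {
    "add": elementwise_rule,
    "relu": elementwise_rule,
    "transpose": transpose_rule,
}


def propagate(nodes: List[Node], input_axes: Dict[str, Axis]) -> Dict[str, Axis]:
    """Visit nodes in topological order and annotate each with its axis."""
    axes = dict(input_axes)
    for node in nodes:
        rule = OP_REGISTRY[node.op]                           # step 1: look up
        axes[node.name] = rule([axes[i] for i in node.inputs])  # steps 2 and 3
    return axes


graph = [
    Node("t", "transpose", ["x"]),
    Node("s", "add", ["t", "bias"]),
    Node("y", "relu", ["s"]),
]
result = propagate(graph, {"x": 0, "bias": None})
```

In this sketch a conflict surfaces as a `ValueError` from the elementwise rule, which is the kind of situation the backtracking phase is described as handling.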

Phase 3: Backtracking

When the forward pass encounters a conflict (e.g., an addition receives one input split on dim 0 and another split on dim 1), the solver:

  1. Identifies the conflicting node.
  2. Backtracks to an earlier node where a different parallel axis choice is available.
  3. Re-propagates forward from that point.
  4. Repeats until it finds a globally consistent assignment or determines that none exists.
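The backtracking idea can be illustrated with a generic depth-first search over candidate axes. This is a minimal sketch of the technique, not the solver's actual data structures: `solve`, `consistent`, the candidate lists, and the pairwise "must match" constraints are all hypothetical simplifications.

```python
from typing import Dict, List, Optional, Tuple

Assignment = Dict[str, int]


def consistent(assignment: Assignment, must_match: List[Tuple[str, str]]) -> bool:
    """Check every constraint whose two operands are both already assigned."""
    return all(
        assignment[a] == assignment[b]
        for a, b in must_match
        if a in assignment and b in assignment
    )


def solve(
    names: List[str],
    candidates: Dict[str, List[int]],
    must_match: List[Tuple[str, str]],
    assignment: Optional[Assignment] = None,
) -> Optional[Assignment]:
    """Return a consistent axis assignment, or None if no valid one exists."""
    assignment = {} if assignment is None else assignment
    if len(assignment) == len(names):
        return dict(assignment)
    name = names[len(assignment)]
    for axis in candidates[name]:
        assignment[name] = axis  # tentative choice
        if consistent(assignment, must_match):
            found = solve(names, candidates, must_match, assignment)
            if found is not None:
                return found
        del assignment[name]  # backtrack and try the next candidate
    return None


# "a" may be split on dim 0 or 1, "b" only on dim 1; an addition needs them equal.
solution = solve(["a", "b"], {"a": [0, 1], "b": [1]}, [("a", "b")])
# With no shared candidate, the search exhausts its options.
unsolvable = solve(["a", "b"], {"a": [0], "b": [1]}, [("a", "b")])
```

Here the first choice for "a" (dim 0) conflicts with the only option for "b", so the search returns to "a" and retries with dim 1.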

Input / Output

Direction  Name          Type                  Description
Input      graph_module  GraphModule           The FX-traced model graph to analyze.
Input      ctx           ParallelExecutionCtx  Execution context containing device and process group information.
Input      config        Config                Parallelization configuration.
Output     (return)      GraphModule           The same graph module with parallel axis annotations on each node.

Example Usage

from optimum.fx.parallelization.passes import ParallelAxisSolverPass

solver_pass = ParallelAxisSolverPass()
graph_module = solver_pass.run(graph_module, ctx, config)

# After running, each node has parallel axis metadata.
# torch.fx nodes always carry a `meta` dict, so only the key needs checking.
for node in graph_module.graph.nodes:
    if 'parallel_axis' in node.meta:
        print(f"{node.name}: parallel_axis={node.meta['parallel_axis']}")
