
Principle:Huggingface Optimum Parallel Axis Solving

From Leeroopedia

Overview

Algorithm for automatically determining which tensor dimensions can be partitioned across GPUs by analyzing data flow through the computation graph.

Description

The parallel axis solver analyzes the model's FX computation graph to determine, for each tensor operation, which axis (if any) can be split across tensor-parallel ranks. It operates in three phases:

  1. Decomposition: High-level operations are decomposed into ATen primitives via decompose_and_functionalize, providing a uniform representation for analysis. Operations like F.scaled_dot_product_attention and F.cross_entropy are preserved as leaf functions.
  2. Forward propagation: Parallel axis information is propagated forward through the graph using operation-specific rules from the op_registry. Each ATen operator defines how parallel axes transform through it (e.g., for matrix multiplication, the output parallel axis depends on which input dimension is partitioned).
  3. Backtracking: When conflicts arise (an operation receives inputs with incompatible parallel axis assignments), the solver backtracks to find a globally consistent assignment.
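Phases 2 and 3 can be illustrated with a toy example: one matrix multiply whose source tensors get trial axis assignments that are propagated forward, with inconsistent assignments rejected. The graph, the rule, and the enumeration-based solver below are hypothetical simplifications for illustration, not Optimum's actual code.

```python
from itertools import product

# Candidate parallel axes for a 2-D tensor: unsplit, row-split, column-split.
AXES = [None, 0, 1]

def propagate_mm(a_axis, b_axis):
    """Toy forward rule for aten.mm (a deliberate simplification).

    Splitting the contraction dimension (A's dim 1 / B's dim 0) is only
    consistent if both sides split it, in which case the output is unsplit
    (each rank holds a partial sum). Otherwise a row split of A or a
    column split of B passes through to the output.
    """
    if a_axis == 1 or b_axis == 0:
        if (a_axis == 1) != (b_axis == 0):
            return "conflict"
        return None
    if a_axis == 0:
        return 0
    if b_axis == 1:
        return 1
    return None

def solve():
    """Enumerate source assignments (a tiny stand-in for backtracking
    search) and return the first one whose mm output is actually split."""
    for x_axis, w_axis in product(AXES, AXES):
        out = propagate_mm(x_axis, w_axis)
        if out == "conflict" or out is None:
            continue  # reject and try the next assignment ("backtrack")
        return {"x": x_axis, "w": w_axis, "mm": out}
    return None
```

Running `solve()` here leaves `x` unsplit and splits `w` on dim 1, i.e. the column-parallel linear layout that Megatron-LM uses for the first matmul of an MLP block.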

Usage

Use as the first pass in the parallelization pipeline, after FX tracing. The solver annotates each node in the FX graph with its parallel axis assignment, which is consumed by the subsequent ParallelLayerAnnotatePass.
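A minimal sketch of what annotating a traced graph might look like. The meta key name `"parallel_axis"` and the trivial assignment rule are assumptions for illustration; `node.meta` itself is the dictionary torch.fx reserves for pass-to-pass metadata, which is how a downstream pass such as ParallelLayerAnnotatePass could consume the result.

```python
import torch
import torch.fx as fx

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 8, bias=False)

    def forward(self, x):
        return self.proj(x)

# Phase 0 of the pipeline: obtain an FX graph of the model.
gm = fx.symbolic_trace(Tiny())

# Annotate each node with a (hypothetical) parallel-axis assignment:
# here we split the linear projection on dim 1 and leave the rest unsplit.
for node in gm.graph.nodes:
    node.meta["parallel_axis"] = 1 if node.op == "call_module" else None

axes = {n.name: n.meta["parallel_axis"] for n in gm.graph.nodes}
```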

Theoretical Basis

Constraint propagation on a dataflow graph. Each ATen operator defines rules for how parallel axes propagate through it. The solver treats this as a constraint satisfaction problem, using forward propagation with backtracking when contradictions arise.

Key propagation rules include:

  * aten.mm(A, B): If A is split on dim 1 (columns), the output is not split. If A is split on dim 0 (rows), the output is split on dim 0.
  * aten.t(A): The parallel axis flips: dim 0 becomes dim 1 and vice versa.
  * aten.add(A, B): Both inputs must have the same parallel axis; the output inherits it.
  * aten.reshape(A, shape): The axis mapping depends on how dimensions are combined or split.
  * aten.embedding(W, idx): If W is split on dim 0 (the vocabulary dimension), special vocab-parallel handling applies.
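The rules above can be expressed as a small registry of pure functions, in the spirit of the op_registry the solver consults. The function names and signatures below are illustrative assumptions, not the actual Optimum API.

```python
def rule_t(axis):
    """aten.t swaps dims 0 and 1, so the parallel axis flips."""
    return {0: 1, 1: 0}.get(axis)  # None (unsplit) stays None

def rule_add(a_axis, b_axis):
    """aten.add requires matching parallel axes; the output inherits them."""
    if a_axis != b_axis:
        # Incompatible assignments: the solver would backtrack here.
        raise ValueError("conflict: inputs split on different axes")
    return a_axis

# Hypothetical registry mapping ATen op names to their propagation rules.
OP_REGISTRY = {"aten.t": rule_t, "aten.add": rule_add}
```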

This approach is inspired by Megatron-LM tensor parallelism, which established the pattern of splitting attention heads and MLP hidden dimensions across GPUs. However, while Megatron-LM uses manual parallelization rules, this solver automates the process by analyzing the computation graph.

Metadata

  * Source Repository: Huggingface Optimum
  * Source Paper: Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
  * Domains: Distributed_Computing, Tensor_Parallelism
