Implementation:Huggingface Optimum ParallelLayerAnnotatePass Run

From Leeroopedia

Overview

ParallelLayerAnnotatePass.run examines each Linear, Embedding, and CrossEntropy layer in the FX graph and classifies it into a parallelization category (column, row, or vocab), based on the parallel axis assignments computed by the solver pass.

Source

| Property | Value |
|---|---|
| File | optimum/fx/parallelization/passes.py |
| Lines | L233-288 |
| Module | optimum.fx.parallelization.passes |

Signature

class ParallelLayerAnnotatePass:
    def run(
        self,
        graph_module: GraphModule,
        ctx: ParallelExecutionCtx,
        config: Config,
    ) -> GraphModule:
        """Annotate each parallelizable layer with its parallelization strategy."""
        ...

Import

from optimum.fx.parallelization.passes import ParallelLayerAnnotatePass

Annotation Attributes

The pass annotates each relevant FX node with the following metadata attributes:

| Attribute | Type | Description |
|---|---|---|
| axis | str | Parallelization category: "column", "row", or "vocab". |
| gather_output | bool | Whether to all-gather the output across ranks after the forward pass. |
| input_is_parallel | bool | Whether the input to this layer is already partitioned across ranks. |
| sequence_parallel | bool | Whether sequence parallelism is enabled for this layer. |

Classification Logic

The pass inspects each node that represents a call to nn.Linear, nn.Embedding, or nn.CrossEntropyLoss and applies the rules below.

Linear Layers

| Condition | Classification | gather_output |
|---|---|---|
| Parallel axis on output dimension (dim 0 of weight) | column | True if downstream expects full tensor, False otherwise |
| Parallel axis on input dimension (dim 1 of weight) | row | Always True (all-reduce sums partial results) |
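
A small numeric sketch can show why a column split is followed by a gather while a row split is followed by a sum. It uses plain Python lists in place of tensors and process groups; the helper `linear`, the example weight, and the two-rank split are illustrative assumptions, not Optimum code.

```python
from typing import List

Matrix = List[List[float]]


def linear(x: List[float], weight: Matrix) -> List[float]:
    """Compute y = W x for a weight of shape (out_features, in_features), no bias."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


x = [1.0, 2.0, 3.0, 4.0]
weight = [
    [1.0, 0.0, 2.0, 1.0],
    [0.0, 1.0, 1.0, 3.0],
]
full = linear(x, weight)  # reference result on a single device

# Column parallel: shard dim 0 of the weight (output features).
# Each simulated rank computes a slice of the output; concatenating
# the slices plays the role of an all-gather.
col_shards = [weight[:1], weight[1:]]
col_outputs = [linear(x, shard) for shard in col_shards]
gathered = [value for out in col_outputs for value in out]

# Row parallel: shard dim 1 of the weight (input features).
# Each simulated rank sees a slice of the input and produces a partial
# result of full output size; summing the partials plays the role of an all-reduce.
row_shards = [[row[:2] for row in weight], [row[2:] for row in weight]]
x_shards = [x[:2], x[2:]]
partials = [linear(xs, shard) for xs, shard in zip(x_shards, row_shards)]
reduced = [sum(values) for values in zip(*partials)]

assert gathered == full
assert reduced == full
```

In this sketch the column-parallel output stays sharded until it is concatenated, which is the situation the gather_output flag describes, while the row-parallel result only becomes meaningful once the partial sums are combined.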

Embedding Layers

| Condition | Classification |
|---|---|
| Parallel axis on vocabulary dimension (dim 0) | vocab |
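
One common way to run an embedding whose vocabulary dimension is sharded is for each rank to look up only the token ids it owns, emit zeros for the rest, and sum the results across ranks. The sketch below simulates that idea with plain lists. It is an assumption about the general technique, not a description of how Optimum implements it, and `local_lookup` is a hypothetical helper.

```python
from typing import List

Matrix = List[List[float]]


def local_lookup(token_ids: List[int], shard: Matrix, vocab_start: int) -> Matrix:
    """Look up ids owned by this shard; return zero rows for ids owned elsewhere."""
    dim = len(shard[0])
    vocab_end = vocab_start + len(shard)
    out = []
    for token in token_ids:
        if vocab_start <= token < vocab_end:
            out.append(list(shard[token - vocab_start]))
        else:
            out.append([0.0] * dim)
    return out


# Full embedding table: vocabulary of 6 tokens, embedding dimension 2.
table = [[float(10 * i + j) for j in range(2)] for i in range(6)]
world_size = 2
per_rank = len(table) // world_size
token_ids = [0, 4, 2, 5]

# Each simulated rank holds a contiguous slice of rows (dim 0 of the table).
partials = [
    local_lookup(token_ids, table[r * per_rank:(r + 1) * per_rank], r * per_rank)
    for r in range(world_size)
]

# Element-wise sum over ranks stands in for an all-reduce.
combined = [[sum(values) for values in zip(*rows)] for rows in zip(*partials)]
expected = [table[token] for token in token_ids]
assert combined == expected
```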

CrossEntropy Layers

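Cross-entropy nodes are handled specially so that the loss is computed correctly when the vocabulary dimension is split across ranks. Each rank then sees only a slice of the logits, so the softmax normalization has to account for the slices held by the other ranks.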
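
To illustrate why a vocabulary split needs special care, the following sketch computes cross-entropy from sharded logits by combining a global maximum and a global sum of exponentials, then compares it with the unsharded result. The function names and the two-shard layout are hypothetical, and the reductions are simulated with ordinary Python instead of collective operations; Optimum's actual implementation may differ.

```python
import math
from typing import List


def full_cross_entropy(logits: List[float], target: int) -> float:
    """Reference loss: log-sum-exp over the whole vocabulary minus the target logit."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return lse - logits[target]


def sharded_cross_entropy(shards: List[List[float]], target: int) -> float:
    """Same loss computed from contiguous vocabulary shards, one per simulated rank."""
    # Stand-in for an all-reduce with MAX, used for numerical stability.
    global_max = max(max(shard) for shard in shards)
    # Each rank sums exponentials over its own slice; the outer sum
    # stands in for an all-reduce with SUM.
    global_sum = sum(
        sum(math.exp(v - global_max) for v in shard) for shard in shards
    )
    # Only the rank that owns the target id contributes the target logit.
    target_logit = 0.0
    offset = 0
    for shard in shards:
        if offset <= target < offset + len(shard):
            target_logit += shard[target - offset]
        offset += len(shard)
    return global_max + math.log(global_sum) - target_logit


logits = [2.0, -1.0, 0.5, 3.0, 0.0, 1.5]
shards = [logits[:3], logits[3:]]
assert math.isclose(
    full_cross_entropy(logits, 3), sharded_cross_entropy(shards, 3), rel_tol=1e-9
)
```

A rank that applied softmax to its own slice alone would normalize over only part of the vocabulary, which is the error the cross-rank reductions avoid.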

Input / Output

| Direction | Name | Type | Description |
|---|---|---|---|
| Input | graph_module | GraphModule | FX graph with parallel axis annotations from the solver pass. |
| Input | ctx | ParallelExecutionCtx | Execution context containing device and process group information. |
| Input | config | Config | Parallelization configuration. |
| Output | (return) | GraphModule | The same graph module with layer-level parallelization annotations. |

Example Usage

from optimum.fx.parallelization.passes import ParallelAxisSolverPass, ParallelLayerAnnotatePass

# Phase 1: Solve parallel axes
solver_pass = ParallelAxisSolverPass()
graph_module = solver_pass.run(graph_module, ctx, config)

# Phase 2: Annotate layers
annotate_pass = ParallelLayerAnnotatePass()
graph_module = annotate_pass.run(graph_module, ctx, config)

# Inspect annotations
for node in graph_module.graph.nodes:
    if hasattr(node, 'meta') and 'axis' in node.meta:
        print(f"{node.name}: axis={node.meta['axis']}, gather={node.meta.get('gather_output')}")
