Implementation:Huggingface Optimum ParallelLayerAnnotatePass Run
Overview
The pass examines each Linear, Embedding, and CrossEntropy layer in the FX graph and classifies it into a parallelization category (column, row, or vocab) based on the parallel axis assignments computed by the solver pass.
Source
| Property | Value |
|---|---|
| File | optimum/fx/parallelization/passes.py |
| Lines | L233-288 |
| Module | optimum.fx.parallelization.passes |
Signature
class ParallelLayerAnnotatePass:
    def run(
        self,
        graph_module: GraphModule,
        ctx: ParallelExecutionCtx,
        config: Config,
    ) -> GraphModule:
        """Annotate each parallelizable layer with its parallelization strategy."""
        ...
Import
from optimum.fx.parallelization.passes import ParallelLayerAnnotatePass
Annotation Attributes
The pass annotates each relevant FX node with the following metadata attributes:
| Attribute | Type | Description |
|---|---|---|
| axis | str | Parallelization category: "column", "row", or "vocab". |
| gather_output | bool | Whether to all-gather the output across ranks after the forward pass. |
| input_is_parallel | bool | Whether the input to this layer is already partitioned across ranks. |
| sequence_parallel | bool | Whether sequence parallelism is enabled for this layer. |
Classification Logic
The pass inspects each node that represents a call to nn.Linear, nn.Embedding, or CrossEntropyLoss:
Linear Layers
| Condition | Classification | gather_output |
|---|---|---|
| Parallel axis on output dimension (dim 0 of weight) | column | True if downstream expects full tensor, False otherwise |
| Parallel axis on input dimension (dim 1 of weight) | row | Always True (all-reduce sums partial results) |
Embedding Layers
| Condition | Classification |
|---|---|
| Parallel axis on vocabulary dimension (dim 0) | vocab |
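The two tables above can be read as a small decision function. The following is a minimal sketch under stated assumptions, not the code in `passes.py`: `classify_layer`, `LayerAnnotation`, and the `downstream_expects_full` flag are hypothetical names introduced here for illustration, and the function only encodes the rules exactly as the tables state them.

```python
from typing import Optional, TypedDict


class LayerAnnotation(TypedDict, total=False):
    """Hypothetical container mirroring the annotation attributes listed above."""
    axis: str
    gather_output: bool
    input_is_parallel: bool
    sequence_parallel: bool


def classify_layer(
    kind: str,
    weight_parallel_dim: Optional[int],
    downstream_expects_full: bool = True,
) -> Optional[LayerAnnotation]:
    """Map a layer kind and the parallel dim of its weight to an annotation.

    Returns None when the layer is not parallelized or the combination is
    not covered by the tables.
    """
    if weight_parallel_dim is None:
        return None  # the solver assigned no parallel axis to this weight
    if kind == "embedding" and weight_parallel_dim == 0:
        # Vocabulary dimension (dim 0 of the embedding weight) is split.
        return {"axis": "vocab"}
    if kind == "linear" and weight_parallel_dim == 0:
        # Output dimension is split: column-parallel linear.
        return {"axis": "column", "gather_output": downstream_expects_full}
    if kind == "linear" and weight_parallel_dim == 1:
        # Input dimension is split: row-parallel linear; the table lists
        # gather_output as always True for this case.
        return {"axis": "row", "gather_output": True}
    return None
```

For instance, `classify_layer("linear", 0, downstream_expects_full=False)` yields a column annotation with `gather_output` set to `False`, which corresponds to a column-parallel layer feeding a row-parallel one.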
CrossEntropy Layers
Cross-entropy nodes are handled specially so that the loss is computed correctly when the vocabulary dimension is split across ranks.
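To see why a vocabulary split needs special care, note that the softmax normalizer spans the whole vocabulary while each rank holds only a slice of the logits. The sketch below simulates one common way to combine per-rank pieces in plain Python. It is an illustration of the general technique, not a description of how Optimum implements it: `sharded_cross_entropy` is a hypothetical name, ranks are modeled as list slices, and the collective operations are modeled with ordinary `max` and `sum`.

```python
import math
from typing import List, Sequence


def full_cross_entropy(logits: Sequence[float], target: int) -> float:
    """Reference loss computed over the unsplit vocabulary."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def sharded_cross_entropy(shards: List[List[float]], target: int) -> float:
    """Same loss, assembled from per-rank vocabulary slices."""
    # Stands in for an all-reduce(max) over each rank's local maximum.
    global_max = max(max(shard) for shard in shards)
    # Stands in for an all-reduce(sum) over each rank's local sum of exps.
    global_sum = sum(
        sum(math.exp(z - global_max) for z in shard) for shard in shards
    )
    # Only the rank that owns the target index contributes its logit;
    # every other rank contributes zero to this all-reduce(sum).
    target_logit = 0.0
    offset = 0
    for shard in shards:
        if offset <= target < offset + len(shard):
            target_logit += shard[target - offset]
        offset += len(shard)
    return global_max + math.log(global_sum) - target_logit
```

Splitting a six-entry logit vector into two slices of three and calling `sharded_cross_entropy` gives the same value as `full_cross_entropy` up to floating-point rounding, whereas applying the loss independently on each slice would not.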
Input / Output
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | graph_module | GraphModule | FX graph with parallel axis annotations from the solver pass. |
| Input | ctx | ParallelExecutionCtx | Execution context containing device and process group information. |
| Input | config | Config | Parallelization configuration. |
| Output | (return) | GraphModule | The same graph module with layer-level parallelization annotations. |
Example Usage
from optimum.fx.parallelization.passes import ParallelAxisSolverPass, ParallelLayerAnnotatePass

# Phase 1: Solve parallel axes
solver_pass = ParallelAxisSolverPass()
graph_module = solver_pass.run(graph_module, ctx, config)

# Phase 2: Annotate layers
annotate_pass = ParallelLayerAnnotatePass()
graph_module = annotate_pass.run(graph_module, ctx, config)

# Inspect annotations
for node in graph_module.graph.nodes:
    if hasattr(node, 'meta') and 'axis' in node.meta:
        print(f"{node.name}: axis={node.meta['axis']}, gather={node.meta.get('gather_output')}")
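When checking the result of the pass on a large graph, a per-category count is often easier to read than one line per node. The helper below is a small sketch along those lines. It assumes, as the inspection loop above does, that the annotations are readable as plain keys of `node.meta`; `summarize_axes` and the demo node names are hypothetical, and `SimpleNamespace` objects stand in for FX nodes so the snippet runs without a traced model.

```python
from collections import Counter
from types import SimpleNamespace
from typing import Iterable


def summarize_axes(nodes: Iterable) -> Counter:
    """Count annotated nodes per parallelization category."""
    counts: Counter = Counter()
    for node in nodes:
        meta = getattr(node, "meta", None) or {}
        if "axis" in meta:
            counts[meta["axis"]] += 1
    return counts


# Stand-in nodes with made-up names; with a real graph, pass
# graph_module.graph.nodes instead.
demo_nodes = [
    SimpleNamespace(name="q_proj", meta={"axis": "column", "gather_output": False}),
    SimpleNamespace(name="o_proj", meta={"axis": "row", "gather_output": True}),
    SimpleNamespace(name="embed_tokens", meta={"axis": "vocab"}),
    SimpleNamespace(name="add", meta={}),  # not a parallelizable layer
]
print(summarize_axes(demo_nodes))
```

Nodes without an `axis` entry, such as elementwise operations, are skipped rather than counted under a default category.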