Implementation: Hugging Face Optimum ParallelLayerReplacePass Run
Overview
Replaces annotated standard PyTorch layers (nn.Linear, nn.Embedding, CrossEntropyLoss) with their distributed-aware parallel counterparts (ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding). Also handles hard-coded axis parameters that need adjustment for tensor parallelism.
Source
| Property | Value |
|---|---|
| Pass File | optimum/fx/parallelization/passes.py |
| Pass Lines | L291-433 |
| Parallel Layers Directory | optimum/fx/parallelization/parallel_layers/ |
APIs
ParallelLayerReplacePass
```python
# optimum/fx/parallelization/passes.py L291-433
class ParallelLayerReplacePass:
    def run(
        self,
        graph_module: GraphModule,
        ctx: ParallelExecutionCtx,
        config: Config,
    ) -> GraphModule:
        """Replace annotated layers with parallel implementations."""
        ...
```
ColumnParallelLinear
```python
# optimum/fx/parallelization/parallel_layers/linear.py L33-101
class ColumnParallelLinear(nn.Module):
    def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, gather_output: bool = True):
        """Column-parallel linear layer.

        Splits the weight matrix along the output dimension.
        Each rank computes a subset of output features.

        Args:
            ctx: Parallel execution context.
            linear: Original nn.Linear module to replace.
            gather_output: If True, all-gather the output across ranks.
        """
        ...
```
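The column-parallel arithmetic can be checked with a minimal pure-Python sketch (not the Optimum implementation, which operates on sharded torch tensors): splitting the weight matrix along the output-feature dimension and concatenating the per-rank slices reproduces the full result, which is what `gather_output=True` corresponds to.

```python
# Illustrative sketch: column parallelism splits W along the output
# dimension, so each "rank" computes a disjoint slice of output features.

def matvec(weight, x):
    """Multiply a weight matrix (rows = output features) by input vector x."""
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) for row in weight]

# Full weight: 4 output features, 2 input features.
W = [[1, 2],
     [3, 4],
     [5, 6],
     [7, 8]]
x = [1, 1]

full = matvec(W, x)  # [3, 7, 11, 15]

# World size 2: each rank holds a contiguous slice of output rows.
world_size = 2
shard = len(W) // world_size
partials = [matvec(W[r * shard:(r + 1) * shard], x) for r in range(world_size)]

# gather_output=True corresponds to concatenating the per-rank slices.
gathered = [v for p in partials for v in p]
assert gathered == full
```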
RowParallelLinear
```python
# optimum/fx/parallelization/parallel_layers/linear.py L104-178
class RowParallelLinear(nn.Module):
    def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, input_is_parallel: bool = False):
        """Row-parallel linear layer.

        Splits the weight matrix along the input dimension.
        Each rank computes a partial sum; all-reduce produces the final output.

        Args:
            ctx: Parallel execution context.
            linear: Original nn.Linear module to replace.
            input_is_parallel: If True, the input is already partitioned across ranks.
        """
        ...
```
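The row-parallel case can be sketched the same way (again pure Python, not Optimum's tensor code): each rank sees only its slice of the input features and produces a partial result, and summing partials element-wise stands in for the all-reduce. The input slicing mirrors what `input_is_parallel=True` assumes has already happened upstream.

```python
# Illustrative sketch: row parallelism splits W along the input
# dimension; each rank produces a partial sum over its input slice.

def matvec(weight, x):
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) for row in weight]

W = [[1, 2, 3, 4],
     [5, 6, 7, 8]]   # 2 output features, 4 input features
x = [1, 1, 1, 1]

full = matvec(W, x)  # [10, 26]

world_size = 2
shard = len(x) // world_size
partials = []
for r in range(world_size):
    w_slice = [row[r * shard:(r + 1) * shard] for row in W]  # input-dim slice
    x_slice = x[r * shard:(r + 1) * shard]                   # this rank's input
    partials.append(matvec(w_slice, x_slice))

# The all-reduce corresponds to summing the partial results element-wise.
reduced = [sum(vals) for vals in zip(*partials)]
assert reduced == full
```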
Import
```python
from optimum.fx.parallelization.passes import ParallelLayerReplacePass
```
Static Handler Methods
The pass dispatches to static handler methods based on the layer type:
| Handler | Lines | Description |
|---|---|---|
| handle_linear | L298-330 | Replaces nn.Linear with ColumnParallelLinear or RowParallelLinear based on the annotation's axis attribute. |
| handle_embedding | L332-356 | Replaces nn.Embedding with VocabParallelEmbedding for vocab-parallel annotated embeddings. |
| handle_cross_entropy | L358-386 | Replaces CrossEntropyLoss with a vocab-parallel-aware variant that handles partitioned logits. |
| handle_hard_coded_axis_param | L387-420 | Adjusts hard-coded dimension parameters (e.g., num_heads, hidden_size) that appear as constants in the graph and must be divided by the tensor-parallel world size. |
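The hard-coded-parameter adjustment amounts to dividing graph constants by the tensor-parallel world size, with divisibility as a precondition. A hedged sketch (the function name and error handling below are illustrative, not Optimum's API):

```python
# Hypothetical helper: shrink a hard-coded dimension constant (e.g.
# num_heads or hidden_size) to its per-rank value under tensor parallelism.

def adjust_axis_param(value: int, tp_world_size: int) -> int:
    """Divide a dimension constant by the TP world size, checking divisibility."""
    if value % tp_world_size != 0:
        raise ValueError(
            f"parameter {value} is not divisible by world size {tp_world_size}"
        )
    return value // tp_world_size

# 32 attention heads sharded across 4 ranks -> 8 heads per rank.
num_heads_per_rank = adjust_axis_param(32, 4)
assert num_heads_per_rank == 8
```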
Behavior
The pass iterates through all nodes in the FX graph and processes those with parallelization annotations:
- For each annotated Linear node:
  - If axis is "column": replace with ColumnParallelLinear, passing gather_output from the annotation.
  - If axis is "row": replace with RowParallelLinear, passing input_is_parallel from the annotation.
- For each annotated Embedding node:
  - If axis is "vocab": replace with VocabParallelEmbedding.
- For each annotated CrossEntropy node:
  - Replace with a parallel-aware cross-entropy that handles partitioned vocabulary logits.
- For hard-coded axis parameters:
  - Divide dimension-related constants (e.g., number of attention heads) by the tensor-parallel world size.
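The dispatch described above can be sketched as a simple mapping from layer type and axis annotation to a replacement (names and the string return values are illustrative; the real pass inspects FX node metadata and swaps in actual modules):

```python
# Hedged sketch of the annotation-driven dispatch, not Optimum's code.

def dispatch(layer_type: str, annotation: dict) -> str:
    """Return the parallel replacement chosen for an annotated layer."""
    if layer_type == "Linear":
        axis = annotation["axis"]
        if axis == "column":
            return f"ColumnParallelLinear(gather_output={annotation['gather_output']})"
        if axis == "row":
            return f"RowParallelLinear(input_is_parallel={annotation['input_is_parallel']})"
    if layer_type == "Embedding" and annotation.get("axis") == "vocab":
        return "VocabParallelEmbedding()"
    if layer_type == "CrossEntropyLoss":
        return "VocabParallelCrossEntropyLoss()"
    raise ValueError(f"no parallel handler for {layer_type}")

print(dispatch("Linear", {"axis": "column", "gather_output": True}))
# → ColumnParallelLinear(gather_output=True)
```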
Input / Output
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | graph_module | GraphModule | FX graph with layer-level parallelization annotations. |
| Input | ctx | ParallelExecutionCtx | Execution context containing device, process group, and TP rank information. |
| Input | config | Config | Parallelization configuration. |
| Output | (return) | GraphModule | Graph module with parallel layer replacements applied. |
Example Usage
```python
from optimum.fx.parallelization.passes import (
    ParallelAxisSolverPass,
    ParallelLayerAnnotatePass,
    ParallelLayerReplacePass,
)

# Phase 1: Solve parallel axes
graph_module = ParallelAxisSolverPass().run(graph_module, ctx, config)
# Phase 2: Annotate layers
graph_module = ParallelLayerAnnotatePass().run(graph_module, ctx, config)
# Phase 3: Replace layers with parallel implementations
graph_module = ParallelLayerReplacePass().run(graph_module, ctx, config)
```