
Implementation:Huggingface Optimum ParallelLayerReplacePass Run

From Leeroopedia

Overview

This pass replaces annotated standard PyTorch layers (nn.Linear, nn.Embedding, CrossEntropyLoss) with their distributed-aware parallel counterparts (ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding). It also adjusts hard-coded axis parameters (e.g., attention head counts) that must be divided by the tensor-parallel world size.

Source

Property Value
Pass File optimum/fx/parallelization/passes.py
Pass Lines L291-433
Parallel Layers Directory optimum/fx/parallelization/parallel_layers/

APIs

ParallelLayerReplacePass

# optimum/fx/parallelization/passes.py L291-433
class ParallelLayerReplacePass:
    def run(
        self,
        graph_module: GraphModule,
        ctx: ParallelExecutionCtx,
        config: Config,
    ) -> GraphModule:
        """Replace annotated layers with parallel implementations."""
        ...

ColumnParallelLinear

# optimum/fx/parallelization/parallel_layers/linear.py L33-101
class ColumnParallelLinear(nn.Module):
    def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, gather_output: bool = True):
        """Column-parallel linear layer.

        Splits the weight matrix along the output dimension.
        Each rank computes a subset of output features.

        Args:
            ctx: Parallel execution context.
            linear: Original nn.Linear module to replace.
            gather_output: If True, all-gather the output across ranks.
        """
        ...
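
The column split can be illustrated with plain NumPy (a toy single-process sketch, not the Optimum implementation; shapes and shard count are illustrative):

```python
import numpy as np

# Full linear layer: y = x @ W.T, with W of shape (out_features, in_features).
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))   # batch of 2, in_features = 8
W = rng.standard_normal((6, 8))   # out_features = 6

world_size = 2
# Column parallelism: split W along the OUTPUT dimension (rows of W).
shards = np.split(W, world_size, axis=0)            # each shard: (3, 8)

# Each rank computes a disjoint slice of the output features...
partial_outputs = [x @ s.T for s in shards]         # each: (2, 3)

# ...and gather_output=True concatenates the slices along the feature axis.
y_parallel = np.concatenate(partial_outputs, axis=-1)  # (2, 6)
y_reference = x @ W.T
assert np.allclose(y_parallel, y_reference)
```

With gather_output=False, each rank keeps only its slice, which is exactly the layout a following row-parallel layer expects as input.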

RowParallelLinear

# optimum/fx/parallelization/parallel_layers/linear.py L104-178
class RowParallelLinear(nn.Module):
    def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, input_is_parallel: bool = False):
        """Row-parallel linear layer.

        Splits the weight matrix along the input dimension.
        Each rank computes a partial sum; all-reduce produces the final output.

        Args:
            ctx: Parallel execution context.
            linear: Original nn.Linear module to replace.
            input_is_parallel: If True, the input is already partitioned across ranks.
        """
        ...
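
The row split is the complementary sketch (again a toy NumPy illustration, with the all-reduce simulated by a plain sum):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))   # in_features = 8
W = rng.standard_normal((6, 8))   # out_features = 6

world_size = 2
# Row parallelism: split W along the INPUT dimension (columns of W),
# and split the input activations to match.
w_shards = np.split(W, world_size, axis=1)   # each: (6, 4)
x_shards = np.split(x, world_size, axis=1)   # each: (2, 4)

# Each rank computes a PARTIAL sum over its slice of the input features...
partials = [xs @ ws.T for xs, ws in zip(x_shards, w_shards)]  # each: (2, 6)

# ...and an all-reduce (here: a plain Python sum) yields the full output.
y_parallel = sum(partials)
y_reference = x @ W.T
assert np.allclose(y_parallel, y_reference)
```

This pairing is why input_is_parallel=True is natural after a column-parallel layer with gather_output=False: the activations are already partitioned in the layout the row-parallel layer consumes.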

Import

from optimum.fx.parallelization.passes import ParallelLayerReplacePass

Static Handler Methods

The pass dispatches to static handler methods based on the layer type:

Handler Lines Description
handle_linear L298-330 Replaces nn.Linear with ColumnParallelLinear or RowParallelLinear based on the annotation's axis attribute.
handle_embedding L332-356 Replaces nn.Embedding with VocabParallelEmbedding for vocab-parallel annotated embeddings.
handle_cross_entropy L358-386 Replaces CrossEntropyLoss with a vocab-parallel-aware variant that handles partitioned logits.
handle_hard_coded_axis_param L387-420 Adjusts hard-coded dimension parameters (e.g., num_heads, hidden_size) that appear as constants in the graph and must be divided by the tensor-parallel world size.
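
The hard-coded parameter adjustment amounts to dividing dimension constants by the tensor-parallel world size. A hypothetical sketch (the real handler rewrites constants inside the FX graph; the helper name below is invented for illustration):

```python
# Hypothetical illustration of the adjustment handle_hard_coded_axis_param
# performs on dimension constants; this helper is not part of Optimum's API.
def shard_dim(value: int, world_size: int) -> int:
    """Divide a dimension-related constant by the tensor-parallel world size."""
    if value % world_size != 0:
        raise ValueError(f"{value} is not divisible by world size {world_size}")
    return value // world_size

num_heads, hidden_size, world_size = 16, 4096, 4
print(shard_dim(num_heads, world_size))    # per-rank attention heads -> 4
print(shard_dim(hidden_size, world_size))  # per-rank hidden size -> 1024
```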

Behavior

The pass iterates through all nodes in the FX graph and processes those with parallelization annotations:

  1. For each annotated Linear node:
    • If axis is "column": Replace with ColumnParallelLinear, passing gather_output from the annotation.
    • If axis is "row": Replace with RowParallelLinear, passing input_is_parallel from the annotation.
  2. For each annotated Embedding node:
    • If axis is "vocab": Replace with VocabParallelEmbedding.
  3. For each annotated CrossEntropy node:
    • Replace with a parallel-aware cross-entropy that handles partitioned vocabulary logits.
  4. For hard-coded axis parameters:
    • Divide dimension-related constants (e.g., number of attention heads) by the tensor-parallel world size.
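
The dispatch in steps 1-4 can be sketched schematically (a plain-Python outline, not the actual pass; the real implementation walks FX graph nodes, and the node/annotation dictionaries below are invented for illustration):

```python
# Schematic dispatch over annotated nodes; node and annotation structures
# here are invented stand-ins for the FX graph metadata the pass inspects.
def replace_parallel_layers(nodes, world_size):
    replacements = []
    for node in nodes:
        ann = node.get("annotation")
        if ann is None:
            continue  # unannotated nodes are left untouched
        kind, axis = node["op"], ann["axis"]
        if kind == "linear" and axis == "column":
            replacements.append(("ColumnParallelLinear", ann.get("gather_output", True)))
        elif kind == "linear" and axis == "row":
            replacements.append(("RowParallelLinear", ann.get("input_is_parallel", False)))
        elif kind == "embedding" and axis == "vocab":
            replacements.append(("VocabParallelEmbedding", None))
        elif kind == "constant":
            # Hard-coded axis parameter: divide by the TP world size.
            replacements.append((node["value"] // world_size, None))
    return replacements

nodes = [
    {"op": "linear", "annotation": {"axis": "column", "gather_output": False}},
    {"op": "linear", "annotation": {"axis": "row", "input_is_parallel": True}},
    {"op": "constant", "value": 32, "annotation": {"axis": "head"}},
]
print(replace_parallel_layers(nodes, world_size=4))
# -> [('ColumnParallelLinear', False), ('RowParallelLinear', True), (8, None)]
```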

Input / Output

Direction Name Type Description
Input graph_module GraphModule FX graph with layer-level parallelization annotations.
Input ctx ParallelExecutionCtx Execution context containing device, process group, and TP rank information.
Input config Config Parallelization configuration.
Output (return) GraphModule Graph module with parallel layer replacements applied.

Example Usage

from optimum.fx.parallelization.passes import (
    ParallelAxisSolverPass,
    ParallelLayerAnnotatePass,
    ParallelLayerReplacePass,
)

# Phase 1: Solve parallel axes
graph_module = ParallelAxisSolverPass().run(graph_module, ctx, config)

# Phase 2: Annotate layers
graph_module = ParallelLayerAnnotatePass().run(graph_module, ctx, config)

# Phase 3: Replace layers with parallel implementations
graph_module = ParallelLayerReplacePass().run(graph_module, ctx, config)
