
Implementation:Huggingface Optimum InitializeOrLoadWeightsPass Run

From Leeroopedia

Overview

Loads pre-trained weights from safetensors files, slices each parameter according to its ParameterMeta for the current tensor-parallel rank, and places the sliced weight on the correct GPU. Parameters not found in the weight files are initialized with their registered init_fn.

Source

  * Pass file: optimum/fx/parallelization/passes.py, L436-534
  * Weight map utility file: optimum/fx/parallelization/utils.py, L462-494

APIs

InitializeOrLoadWeightsPass

# optimum/fx/parallelization/passes.py L436-534
class InitializeOrLoadWeightsPass:
    def run(
        self,
        graph_module: GraphModule,
        ctx: ParallelExecutionCtx,
        config: Config,
    ) -> GraphModule:
        """Load or initialize all model weights for the current TP rank."""
        ...

try_collect_weight_map

# optimum/fx/parallelization/utils.py L462-494
def try_collect_weight_map(
    model_name_or_path: str,
    cache_dir: Optional[str],
    folder_path: str,
) -> Dict[str, str]:
    """Collect a mapping from parameter names to safetensors shard file paths.

    Parses model.safetensors.index.json to build the weight map.
    Falls back to a single model.safetensors file if no index exists.

    Args:
        model_name_or_path: Hub model ID or local path.
        cache_dir: Local cache directory.
        folder_path: Path to the model snapshot folder.

    Returns:
        Dictionary mapping parameter names to their source shard file paths.
    """
    ...

Import

from optimum.fx.parallelization.passes import InitializeOrLoadWeightsPass
from optimum.fx.parallelization.utils import try_collect_weight_map

Behavior

The pass executes the following steps:

Step 1: Build Weight Map

Call try_collect_weight_map to parse the safetensors index file and create a dictionary mapping parameter names to their source shard files.

  * Index file exists (model.safetensors.index.json): parse its weight_map field to build the mapping.
  * Single file (model.safetensors): map all parameters to the single shard file.
  * No safetensors found: raise an error.
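The lookup logic above can be sketched with the standard library alone. This is a simplified stand-in, not the actual try_collect_weight_map: the real utility also handles Hub downloads and caching, and enumerates the keys of a single-file checkpoint with safetensors, whereas here the caller supplies the parameter names.

```python
import json
import os
from typing import Dict, Iterable


def collect_weight_map(folder_path: str, param_names: Iterable[str] = ()) -> Dict[str, str]:
    """Map each parameter name to the safetensors file that holds it (sketch)."""
    index_path = os.path.join(folder_path, "model.safetensors.index.json")
    single_path = os.path.join(folder_path, "model.safetensors")

    if os.path.isfile(index_path):
        # Sharded checkpoint: the index file maps each parameter to its shard.
        with open(index_path) as fh:
            index = json.load(fh)
        return {
            name: os.path.join(folder_path, shard)
            for name, shard in index["weight_map"].items()
        }

    if os.path.isfile(single_path):
        # Un-sharded checkpoint: every parameter lives in the one file.
        return {name: single_path for name in param_names}

    raise FileNotFoundError(f"no safetensors weights found under {folder_path}")
```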

Step 2: Iterate Over Parameters

For each named parameter in the model:

  1. Look up source file in the weight map.
  2. Read ParameterMeta from the parameter to determine the parallel dimension and slice mapping.
  3. Compute rank-specific slice: Based on the TP rank and world size, determine which slice of the original tensor this rank should load.
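The rank-specific slice in step 3 can be sketched as follows, assuming an even split along the parallel dimension; the actual pass derives offsets from each parameter's ParameterMeta rather than recomputing them.

```python
def rank_slice(dim_size: int, rank: int, world_size: int) -> slice:
    """Half-open slice of the parallel dimension owned by `rank`,
    assuming dim_size divides evenly across ranks (a simplification;
    the real pass reads the mapping from ParameterMeta)."""
    chunk = dim_size // world_size
    return slice(rank * chunk, (rank + 1) * chunk)


# A 4096-row weight split over 4 ranks: rank 1 owns rows 1024:2048.
s = rank_slice(4096, rank=1, world_size=4)
```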

Step 3: Load and Assign

  1. Open safetensors file for the source shard.
  2. Read the relevant slice using the safetensors partial read API.
  3. Move to device: Transfer the tensor slice to the correct GPU.
  4. Assign to parameter: Replace the meta tensor with the real tensor data.

Step 4: Handle Special Cases

  * Tied parameters: load once; share the tensor reference across all tied parameter locations.
  * Parameters with need_initialize=True: skip weight file loading; call the parameter's init_fn instead.
  * Non-parallel parameters: load the full parameter without slicing (identical copy on each rank).
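The dispatch between these cases can be sketched as below. Meta is a hypothetical stand-in for optimum's ParameterMeta (its real field layout differs), and load_fn stands for the shard-loading routine of the previous step:

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional


@dataclass
class Meta:
    """Illustrative stand-in for ParameterMeta; real field names differ."""
    need_initialize: bool = False
    init_fn: Optional[Callable[[], object]] = None
    tied_to: Optional[str] = None


def materialize_param(name, meta, weight_map, load_fn, cache: Dict[str, object]):
    """Resolve one parameter: tied reuse, init_fn fallback, or shard load."""
    # Tied parameters: reuse the tensor already loaded for the primary name.
    if meta.tied_to is not None and meta.tied_to in cache:
        return cache[meta.tied_to]
    # Marked for initialization, or absent from every shard: use init_fn.
    if meta.need_initialize or name not in weight_map:
        tensor = meta.init_fn()
    else:
        tensor = load_fn(weight_map[name], name)
    cache[name] = tensor
    return tensor
```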

Input / Output

  * Input graph_module (GraphModule): FX graph with parallel layer replacements (parameters on meta device).
  * Input ctx (ParallelExecutionCtx): execution context containing device, TP rank, world size, and model path.
  * Input config (Config): parallelization configuration.
  * Output, return value (GraphModule): graph module with all parameters materialized on the correct GPU devices.

Example Usage

from optimum.fx.parallelization.passes import (
    ParallelAxisSolverPass,
    ParallelLayerAnnotatePass,
    ParallelLayerReplacePass,
    InitializeOrLoadWeightsPass,
)

# Run the full parallelization pipeline
graph_module = ParallelAxisSolverPass().run(graph_module, ctx, config)
graph_module = ParallelLayerAnnotatePass().run(graph_module, ctx, config)
graph_module = ParallelLayerReplacePass().run(graph_module, ctx, config)

# Final step: load weights
graph_module = InitializeOrLoadWeightsPass().run(graph_module, ctx, config)

# Model is now ready for inference/training
# All parameters are on their respective GPU devices
for name, param in graph_module.named_parameters():
    assert param.device.type == "cuda"
