Implementation: Hugging Face Optimum InitializeOrLoadWeightsPass.run
Overview
Loads pre-trained weights from safetensors files, slicing each parameter according to its ParameterMeta for the current tensor-parallel (TP) rank, and assigns the sliced weights to the correct GPU. Parameters not found in the weight files are initialized using their registered init_fn.
Source
| Property | Value |
|---|---|
| Pass File | optimum/fx/parallelization/passes.py |
| Pass Lines | L436-534 |
| Weight Map Utility File | optimum/fx/parallelization/utils.py |
| Weight Map Utility Lines | L462-494 |
APIs
InitializeOrLoadWeightsPass
```python
# optimum/fx/parallelization/passes.py L436-534
class InitializeOrLoadWeightsPass:
    def run(
        self,
        graph_module: GraphModule,
        ctx: ParallelExecutionCtx,
        config: Config,
    ) -> GraphModule:
        """Load or initialize all model weights for the current TP rank."""
        ...
```
try_collect_weight_map
```python
# optimum/fx/parallelization/utils.py L462-494
def try_collect_weight_map(
    model_name_or_path: str,
    cache_dir: Optional[str],
    folder_path: str,
) -> Dict[str, str]:
    """Collect a mapping from parameter names to safetensors shard file paths.

    Parses model.safetensors.index.json to build the weight map.
    Falls back to a single model.safetensors file if no index exists.

    Args:
        model_name_or_path: Hub model ID or local path.
        cache_dir: Local cache directory.
        folder_path: Path to the model snapshot folder.

    Returns:
        Dictionary mapping parameter names to their source shard file paths.
    """
    ...
```
Import
```python
from optimum.fx.parallelization.passes import InitializeOrLoadWeightsPass
from optimum.fx.parallelization.utils import try_collect_weight_map
```
Behavior
The pass executes the following steps:
Step 1: Build Weight Map
Call try_collect_weight_map to parse the safetensors index file and create a dictionary mapping parameter names to their source shard files.
| Scenario | Behavior |
|---|---|
| Index file exists (model.safetensors.index.json) | Parse the weight_map field to build the mapping. |
| Single file (model.safetensors) | Map all parameters to the single shard file. |
| No safetensors found | Raise an error. |
Step 2: Iterate Over Parameters
For each named parameter in the model:
- Look up source file in the weight map.
- Read ParameterMeta from the parameter to determine the parallel dimension and slice mapping.
- Compute rank-specific slice: Based on the TP rank and world size, determine which slice of the original tensor this rank should load.
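The rank-specific slice computation can be illustrated with a small helper (the name and signature are hypothetical, not from the Optimum source), assuming the sharded dimension divides evenly by the world size:

```python
from typing import Tuple


def rank_slice(full_size: int, dim: int, rank: int, world_size: int,
               ndim: int) -> Tuple[slice, ...]:
    """Index tuple selecting this TP rank's shard along dimension `dim`."""
    assert full_size % world_size == 0, "sharded dim must divide evenly"
    shard = full_size // world_size
    start, stop = rank * shard, (rank + 1) * shard
    # Full slice on every other dimension, this rank's window on `dim`.
    return tuple(
        slice(start, stop) if d == dim else slice(None) for d in range(ndim)
    )
```

For example, a (4096, 1024) weight split along dim 0 with world size 4 gives rank 1 the window of rows 1024 to 2047.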
Step 3: Load and Assign
- Open safetensors file for the source shard.
- Read the relevant slice using the safetensors partial read API.
- Move to device: Transfer the tensor slice to the correct GPU.
- Assign to parameter: Replace the meta tensor with the real tensor data.
Step 4: Handle Special Cases
| Case | Handling |
|---|---|
| Tied parameters | Load once; share the tensor reference across all tied parameter locations. |
| Parameters with need_initialize=True | Skip weight file loading; call the parameter's init_fn instead. |
| Non-parallel parameters | Load the full parameter without slicing (identical copy on each rank). |
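The dispatch over these cases could look roughly like the sketch below (FakeMeta and the loader callback are stand-ins for ParameterMeta and the shard-reading code, not the real implementation):

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional

import torch


@dataclass
class FakeMeta:
    """Illustrative stand-in for ParameterMeta."""
    need_initialize: bool = False
    init_fn: Optional[Callable[[torch.Tensor], None]] = None
    tied_to: Optional[str] = None  # canonical parameter name, if tied


def materialize(name: str, shape, meta: FakeMeta,
                weight_map: Dict[str, str],
                loaded: Dict[str, torch.Tensor],
                loader) -> torch.Tensor:
    # Tied parameters: share the tensor already loaded for the target.
    if meta.tied_to is not None and meta.tied_to in loaded:
        return loaded[meta.tied_to]
    # need_initialize (or missing from the checkpoint): run init_fn.
    if meta.need_initialize or name not in weight_map:
        tensor = torch.empty(shape)
        if meta.init_fn is not None:
            meta.init_fn(tensor)
    else:
        # Otherwise read from the checkpoint; non-parallel parameters
        # simply get the full, unsliced tensor here.
        tensor = loader(weight_map[name], name)
    loaded[name] = tensor
    return tensor
```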
Input / Output
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | graph_module | GraphModule | FX graph with parallel layer replacements (parameters on meta device). |
| Input | ctx | ParallelExecutionCtx | Execution context containing device, TP rank, world size, and model path. |
| Input | config | Config | Parallelization configuration. |
| Output | (return) | GraphModule | Graph module with all parameters materialized on the correct GPU devices. |
Example Usage
```python
from optimum.fx.parallelization.passes import (
    ParallelAxisSolverPass,
    ParallelLayerAnnotatePass,
    ParallelLayerReplacePass,
    InitializeOrLoadWeightsPass,
)

# Run the full parallelization pipeline
graph_module = ParallelAxisSolverPass().run(graph_module, ctx, config)
graph_module = ParallelLayerAnnotatePass().run(graph_module, ctx, config)
graph_module = ParallelLayerReplacePass().run(graph_module, ctx, config)

# Final step: load weights
graph_module = InitializeOrLoadWeightsPass().run(graph_module, ctx, config)

# The model is now ready for inference/training:
# all parameters are on their respective GPU devices.
for name, param in graph_module.named_parameters():
    assert param.device.type == "cuda"
```