
Principle:Huggingface Optimum Sharded Weight Loading

From Leeroopedia

Overview

The process of loading pre-trained weights from safetensors files and distributing them across tensor-parallel ranks, slicing each tensor so that every rank materializes only its own shard.

Description

After layer replacement, the model exists on the meta device with parallel layers but no actual weight data. The weight loading pass reads safetensors files, slices each parameter according to its ParameterMeta (parallel dimension, TP rank), and loads only the relevant slice onto each GPU.

The process involves several key steps:

  1. Build weight map: Parse the model.safetensors.index.json file to create a dictionary mapping each parameter name to the shard file that contains it.
  2. Iterate over parameters: For each parameter in the model, look up its source file and original name in the weight map.
  3. Slice for current rank: Using the parameter's ParameterMeta (particularly the dim and mapping fields), compute the slice of the weight tensor that belongs to the current tensor-parallel rank.
  4. Load and assign: Read only the relevant slice from the safetensors file and assign it to the parameter on the correct GPU device.
  5. Handle special cases:
    • Tied parameters: Load the weight once and share it across all references.
    • Uninitialized parameters: For parameters not found in weight files (e.g., newly added parallel communication buffers), call the parameter's init_fn to initialize them.

Usage

This is the final pass in the parallelization pipeline, executed after layer replacement. After this pass completes, the model is fully materialized on the GPUs and ready for inference or training.

Theoretical Basis

Sharded model loading. For a model split across N GPUs with tensor parallelism, each parameter is sliced along its parallel dimension into N equal chunks of size chunk = dim_size / N. Rank i loads the slice [i*chunk : (i+1)*chunk).
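The equal-chunk rule reduces to simple integer arithmetic. A minimal sketch (`shard_bounds` is an illustrative helper, not an Optimum API):

```python
def shard_bounds(dim_size, tp_rank, tp_size):
    """Return the [start, stop) range of the parallel dimension owned by tp_rank."""
    assert dim_size % tp_size == 0, "parallel dim must divide evenly across ranks"
    chunk = dim_size // tp_size
    return tp_rank * chunk, (tp_rank + 1) * chunk


# A 4096-wide dimension split across 8 tensor-parallel ranks:
print(shard_bounds(4096, 0, 8))  # (0, 512)
print(shard_bounds(4096, 7, 8))  # (3584, 4096)
```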

The safetensors format supports efficient partial reads without loading full tensors into memory. Key properties:

  • Header-based metadata: tensor shapes, dtypes, and byte offsets are stored in the file header, enabling targeted reads.
  • Memory-mapped access: tensors can be read via memory mapping without loading the entire file.
  • Partial tensor reads: specific byte ranges can be read to extract individual tensor slices.
  • Zero-copy deserialization: tensor data can be used directly without copying into a separate buffer.

This means that for a 70B parameter model split across 8 GPUs, each GPU reads approximately 1/8 of the total weight data, and the safetensors format ensures that only the needed bytes are actually read from disk.
