Implementation:FMInference FlexLLMGen DeepSpeed State Dict Factory
| Field | Value |
|---|---|
| Sources | Repo: FlexLLMGen, Upstream: DeepSpeed |
| Domains | Checkpointing, Model_Loading |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Vendored DeepSpeed module that provides a factory pattern for loading model checkpoints from different formats (Megatron, BLOOM, etc.), with support for merging, splitting, and quantizing state dictionaries across model-parallel ranks.
Description
The state_dict_factory.py file (474 lines) is a vendored copy of DeepSpeed's checkpoint loading abstraction. It handles the complexity of loading pre-trained model weights that may have been saved with different parallelism configurations than the current deployment.
Key components include:
- SDLoaderFactory -- A static factory class with two entry points:
  - get_sd_loader_json -- Loads a JSON manifest describing the checkpoint type, file list, version, parallelization strategy, and model-parallel size. Supports Megatron and ds_model/BLOOM types.
  - get_sd_loader -- Creates the appropriate loader instance based on checkpoint type.
- SDLoaderBase (abstract) -- The base class for all state dict loaders, providing:
  - load() -- The main entry point that handles three model-parallel configurations: (1) matching checkpoint and runtime MP size (direct load), (2) more checkpoint shards than runtime ranks (merge), and (3) fewer checkpoint shards than runtime ranks (split).
  - merge_state_dict -- Merges multiple checkpoint shards into one for a given MP rank, handling key alignment and optional quantization.
  - split_state_dict -- Splits a single checkpoint shard across multiple MP ranks.
  - get_module / set_module -- Abstract methods for accessing the model module within the checkpoint's state dict structure.
- MegatronSDLoader -- Concrete loader for Megatron-LM checkpoints, handling the Megatron-specific state dict structure (model key with possible language_model sub-key) and auto-detecting the module path.
- WeightQuantization integration -- Optional post-load quantization of weights using group quantization for inference optimization.
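The three model-parallel cases handled by load() can be pictured with a small, dependency-free sketch. The helper names plan_load, merge_rows, and split_rows are hypothetical and do not appear in the vendored file. The sketch also assumes every tensor is partitioned by rows, whereas a real merge or split has to choose the partitioning axis per parameter.

```python
def plan_load(num_ckpt, mp_world_size, mp_rank):
    """Decide how one runtime rank obtains its weights (hypothetical helper).

    Returns (action, shard_indices, sub_rank):
      action        -- "direct", "merge", or "split"
      shard_indices -- checkpoint shards this rank has to read
      sub_rank      -- for "split", which slice of the shard the rank keeps
    """
    if num_ckpt == mp_world_size:
        # Case 1: one shard per rank, load it unchanged.
        return "direct", [mp_rank], None
    if num_ckpt > mp_world_size:
        # Case 2: more shards than ranks, each rank merges a contiguous run.
        if num_ckpt % mp_world_size != 0:
            raise ValueError("shard count must be a multiple of the MP size")
        per_rank = num_ckpt // mp_world_size
        start = mp_rank * per_rank
        return "merge", list(range(start, start + per_rank)), None
    # Case 3: fewer shards than ranks, several ranks share one shard.
    if mp_world_size % num_ckpt != 0:
        raise ValueError("MP size must be a multiple of the shard count")
    ranks_per_shard = mp_world_size // num_ckpt
    return "split", [mp_rank // ranks_per_shard], mp_rank % ranks_per_shard


def merge_rows(shards):
    """Concatenate row-partitioned shards (each a list of rows)."""
    return [row for shard in shards for row in shard]


def split_rows(matrix, num_splits, offset):
    """Return the offset-th of num_splits equal row blocks of matrix."""
    if len(matrix) % num_splits != 0:
        raise ValueError("row count must be divisible by num_splits")
    size = len(matrix) // num_splits
    return matrix[offset * size:(offset + 1) * size]


# Four shards on two ranks: rank 1 merges shards 2 and 3.
action, shards, sub_rank = plan_load(num_ckpt=4, mp_world_size=2, mp_rank=1)
```

With two shards on four ranks, plan_load(2, 4, 3) selects shard 1 and sub-rank 1, so that rank would keep the second half of the shard via split_rows(matrix, 2, 1).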
The module handles multiple pipeline/tensor parallelism cases including PipeModule with mp_rank_*.pt files, PipeModule with layer_*.pt files, and non-PipeModule standard checkpoints.
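As an illustration of the manifest consumed by get_sd_loader_json, the sketch below parses a dictionary with the five fields listed above. The key names ("type", "checkpoints", "version", "parallelization", "mp_size"), the fallback values, and the file names are assumptions made for this example. The Manifest dataclass and read_manifest helper are hypothetical and not part of the vendored module.

```python
import json
from dataclasses import dataclass
from typing import Any, List, Union


@dataclass
class Manifest:
    sd_type: str            # checkpoint family, e.g. "Megatron"
    ckpt_list: List[str]    # shard files, one entry per checkpoint shard
    version: Any            # checkpoint format version, if recorded
    parallelization: str    # parallelization strategy label
    mp_size: int            # model-parallel size the checkpoint was saved with


def read_manifest(json_file: Union[str, dict]) -> Manifest:
    """Accept either a path to a JSON file or an already-parsed dict."""
    if isinstance(json_file, dict):
        data = json_file
    else:
        with open(json_file) as f:
            data = json.load(f)
    return Manifest(
        sd_type=data["type"],
        ckpt_list=data["checkpoints"],
        version=data.get("version"),
        # Fallbacks below are assumptions for this sketch.
        parallelization=data.get("parallelization", "pp"),
        mp_size=data.get("mp_size", 0),
    )


# Hypothetical manifest for a two-shard Megatron checkpoint.
example = {
    "type": "Megatron",
    "checkpoints": ["mp_rank_00_model_states.pt", "mp_rank_01_model_states.pt"],
    "version": 1.0,
}
manifest = read_manifest(example)
```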
Usage
This module is invoked internally during DeepSpeed inference initialization when loading pre-trained checkpoints. It is part of the vendored benchmark dependencies in FlexLLMGen.
Code Reference
| Field | Value |
|---|---|
| Repository | FlexLLMGen |
| File | benchmark/third_party/DeepSpeed/deepspeed/runtime/state_dict_factory.py |
| Lines | 1-474 |
| Type | AUTO_KEEP (vendored dependency) |
Key class signatures:
    class SDLoaderFactory:
        @staticmethod
        def get_sd_loader_json(json_file, checkpoint_engine):
            ...

        @staticmethod
        def get_sd_loader(ckpt_list, checkpoint_engine, sd_type='Megatron', version=None):
            ...

    class SDLoaderBase(ABC):
        def load(self, mp_world_size, mp_rank, module_key=AUTO_MODULE_KEY,
                 is_pipe_parallel=False, quantize=False, quantize_bits=8,
                 quantize_groups=64, mlp_extra_grouping=True):
            ...
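The relationship between the factory, the abstract base, and a concrete loader can be sketched with toy classes. Everything prefixed with Toy is hypothetical. The sketch covers only the direct-load case and fabricates the state dict in memory instead of reading a file. The (None, 1) bookkeeping value for an unquantized, unmerged load is an assumption.

```python
from abc import ABC, abstractmethod


class ToyLoaderBase(ABC):
    """Stand-in for an abstract state dict loader (direct-load case only)."""

    def __init__(self, ckpt_list, version=None):
        self.ckpt_list = ckpt_list
        self.version = version

    @abstractmethod
    def get_module(self, sd):
        """Return the part of the state dict that holds the model weights."""

    def load(self, mp_world_size, mp_rank):
        if len(self.ckpt_list) != mp_world_size:
            raise NotImplementedError("merge/split are omitted in this sketch")
        load_path = self.ckpt_list[mp_rank]
        # A real loader would read load_path from disk; this fakes the result.
        sd = {"model": {"language_model": {"weight": [mp_rank]}}}
        # Assumed bookkeeping: no quantization scales, merge count of 1.
        return load_path, sd, (None, 1)


class ToyMegatronLoader(ToyLoaderBase):
    def get_module(self, sd):
        # Weights sit under "model", optionally nested under "language_model".
        module = sd["model"]
        return module.get("language_model", module)


class ToyLoaderFactory:
    @staticmethod
    def get_sd_loader(ckpt_list, sd_type="Megatron", version=None):
        if sd_type == "Megatron":
            return ToyMegatronLoader(ckpt_list, version)
        raise ValueError(f"{sd_type} checkpoint type is not supported")


loader = ToyLoaderFactory.get_sd_loader(["shard0.pt", "shard1.pt"])
load_path, sd, (all_scales, merge_count) = loader.load(mp_world_size=2, mp_rank=1)
weights = loader.get_module(sd)
```

Keeping get_module abstract lets the shared load() logic stay unaware of where each checkpoint family stores its weights.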
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| json_file | str or dict | Yes | Path to the checkpoint manifest JSON file, or an already-parsed manifest dictionary |
| checkpoint_engine | CheckpointEngine | Yes | Engine used to read checkpoint files; the argument must be supplied, but passing None falls back to TorchCheckpointEngine |
| mp_world_size | int | Yes | Current model-parallel world size |
| mp_rank | int | Yes | Current model-parallel rank |
| quantize | bool | No | Enable post-load weight quantization (default: False) |
| quantize_bits | int | No | Quantization bit width (default: 8) |
Outputs
| Output | Type | Description |
|---|---|---|
| load_path | str | Path of the checkpoint file that was loaded |
| sd | dict | State dictionary with model weights |
| (all_scales, merge_count) | tuple | Quantization scales and merge count for bookkeeping |
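To make the all_scales output more concrete, here is a minimal sketch of symmetric group quantization over a flat list of weights. It is not the vendored WeightQuantization code, whose scale convention and grouping rules may differ. The function name quantize_groups and the choice of one scale per equal-sized group are assumptions for illustration.

```python
def quantize_groups(values, num_groups, bits=8):
    """Quantize a flat list of floats in equal groups, one scale per group."""
    if len(values) % num_groups != 0:
        raise ValueError("value count must be divisible by num_groups")
    group_size = len(values) // num_groups
    q_max = 2 ** (bits - 1) - 1  # 127 for 8-bit signed integers
    quantized, scales = [], []
    for g in range(num_groups):
        group = values[g * group_size:(g + 1) * group_size]
        max_abs = max(abs(v) for v in group)
        # Map the largest magnitude in the group onto q_max.
        scale = max_abs / q_max if max_abs > 0 else 1.0
        quantized.extend(round(v / scale) for v in group)
        scales.append(scale)
    return quantized, scales


def dequantize_groups(quantized, scales):
    """Invert quantize_groups up to rounding error."""
    group_size = len(quantized) // len(scales)
    return [q * scales[i // group_size] for i, q in enumerate(quantized)]


q, s = quantize_groups([1.0, -2.0, 0.5, 0.1], num_groups=2)
restored = dequantize_groups(q, s)
```

Each group carries its own scale, so a group of small weights is not forced onto the coarse grid implied by a large outlier in another group.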