Overview
FSDP2 distributed sharding engine that enables multi-GPU sharded data-parallel training with memory-efficient parameter sharding, mixed precision training, activation checkpointing, and flexible weight loading from both HuggingFace checkpoints and Distributed Checkpoints (DCP).
Description
fsdp2.py implements the core distributed training engine for PyTorch FSDP2 (Fully Sharded Data Parallel v2) in LLaMA-Factory's v1 plugin system. The module contains:
- get_transformer_layer_cls: A utility function that identifies the transformer layer class to use as the FSDP wrapping boundary. It checks the model's _no_split_modules attribute first, then falls back to inspecting model.model.layers or model.layers.
- FSDP2Engine: The main engine class, which manages the full FSDP2 lifecycle:
  - __init__: Initializes the distributed state (rank, local_rank, world_size) and device mesh from the DistributedInterface, and configures the mixed precision mode (bf16/fp16/fp32), reshard policy, and CPU offloading from dist_config.
  - get_mp_policy: Creates a MixedPrecisionPolicy whose param_dtype and reduce_dtype follow the configured precision mode.
  - prepare_model: The core FSDP wrapping method, which:
    - Identifies the transformer layer classes for per-layer sharding.
    - Applies fully_shard to each transformer layer and to the embedding layer (if untied).
    - Enables gradient (activation) checkpointing in non-reentrant mode.
    - Applies fully_shard to the top-level model.
    - Supports CPU parameter offloading, optionally with pinned host memory (pin_memory).
  - materialize_and_load: Materializes meta-device parameters and loads weights, supporting both DCP and HuggingFace checkpoint formats.
  - shard_model: The high-level entry point. It detects whether the model is on the meta device (requiring materialization) or already loaded, and orchestrates preparation accordingly.
  - _load_from_dcp: Loads weights from a PyTorch Distributed Checkpoint using get_model_state_dict and set_model_state_dict.
  - _load_weights_from_hf_checkpoint: Loads weights from HuggingFace checkpoint files (safetensors or .bin), handling sharded index files and distributing sliced tensors to DTensor parameters.
  - _resolve_hf_checkpoint_dir: Resolves HuggingFace model identifiers (local paths or Hub repo IDs) to local directories, downloading in a rank-aware way to avoid concurrent downloads.
  - _copy_weights: Copies a loaded tensor into either a regular tensor or a DTensor (distributed tensor) parameter, slicing along the shard dimension as needed.
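To make the slicing in _copy_weights concrete, the sketch below computes which rows of a full checkpoint tensor a given rank owns. It assumes a single-dimension device mesh and torch.chunk-style splitting along one dimension (the convention of DTensor's Shard placement); shard_bounds is a hypothetical helper, and fsdp2.py may compute the slice differently.

```python
import math


def shard_bounds(dim_size: int, world_size: int, rank: int) -> tuple[int, int]:
    """Return the half-open range [start, end) of a dimension owned by `rank`.

    Hypothetical helper. Assumes torch.chunk-style splitting: each rank gets
    ceil(dim_size / world_size) entries, so trailing ranks may get fewer or none.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must satisfy 0 <= rank < world_size")
    chunk = math.ceil(dim_size / world_size)
    start = min(rank * chunk, dim_size)
    end = min(start + chunk, dim_size)
    return start, end


# With torch, the copy into a DTensor parameter sharded on `dim` would roughly be:
#   start, end = shard_bounds(full.shape[dim], world_size, rank)
#   param.to_local().copy_(full.narrow(dim, start, end - start))
# A 10-row weight split across 4 ranks gives rows 0-2, 3-5, 6-8, and 9.
print([shard_bounds(10, 4, r) for r in range(4)])  # [(0, 3), (3, 6), (6, 9), (9, 10)]
```

Uneven splits leave trailing ranks with short or empty slices, which is why the copy has to slice per rank rather than assume equal shard sizes.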
Usage
Use FSDP2Engine when training large models that exceed the memory of a single GPU. It is instantiated by the v1 distributed training plugin system, which passes it a dist_config dictionary. The engine wraps the model for sharded training, handles weight loading, and manages the distributed state lifecycle.
Code Reference
Source Location
Signature
def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None
class FSDP2Engine:
    def __init__(self, dist_config: dict)
    def get_mp_policy(self) -> MixedPrecisionPolicy
    def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel
    def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None) -> PreTrainedModel
    def shard_model(self, model: PreTrainedModel) -> PreTrainedModel
    def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str) -> None
    def _load_weights_from_hf_checkpoint(self, model, hf_model_path) -> None
    def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str
    def _copy_weights(self, param, loaded_tensor) -> None
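get_mp_policy turns the mixed_precision string into the param_dtype and reduce_dtype of a MixedPrecisionPolicy. A minimal sketch of that lookup follows. The specific dtype pairs (for example, reducing gradients in float32 under bf16) are an assumption, not confirmed from fsdp2.py, and mp_dtype_names is a hypothetical helper. It returns dtype names so the sketch runs without torch.

```python
# Hypothetical mapping: mixed_precision -> (param_dtype, reduce_dtype) names.
# Reducing gradients in float32 is a common choice, assumed here for illustration.
_MP_DTYPES = {
    "bf16": ("bfloat16", "float32"),
    "fp16": ("float16", "float32"),
    "fp32": ("float32", "float32"),
}


def mp_dtype_names(mixed_precision: str) -> tuple[str, str]:
    """Return (param_dtype, reduce_dtype) names for a precision mode."""
    try:
        return _MP_DTYPES[mixed_precision]
    except KeyError:
        raise ValueError(
            f"mixed_precision must be one of {sorted(_MP_DTYPES)}, got {mixed_precision!r}"
        ) from None


# With torch installed, the names map onto real dtypes, e.g.:
#   param, reduce = mp_dtype_names("bf16")
#   policy = MixedPrecisionPolicy(param_dtype=getattr(torch, param),
#                                 reduce_dtype=getattr(torch, reduce))
print(mp_dtype_names("bf16"))  # ('bfloat16', 'float32')
```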
Import
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import get_transformer_layer_cls
I/O Contract
Inputs
FSDP2Engine.__init__
| Name | Type | Required | Description |
|---|---|---|---|
| dist_config | dict | Yes | Configuration dictionary with keys: mixed_precision ("bf16", "fp16", or "fp32"), reshard_after_forward (bool), offload_params (bool), pin_memory (bool), dcp_path (str or None) |
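The table lists the keys FSDP2Engine reads from dist_config. A hypothetical check_dist_config helper can validate a user-supplied dictionary against those documented types before the engine is built. It treats every key as optional, because the engine's defaults are not documented here.

```python
VALID_PRECISIONS = ("bf16", "fp16", "fp32")
BOOL_KEYS = ("reshard_after_forward", "offload_params", "pin_memory")


def check_dist_config(cfg: dict) -> list[str]:
    """Return a list of problems found in `cfg` (empty if it looks valid).

    Hypothetical helper. Keys are treated as optional since defaults are not
    documented; only the types listed in the I/O contract are checked.
    """
    problems = []
    if "mixed_precision" in cfg and cfg["mixed_precision"] not in VALID_PRECISIONS:
        problems.append(f"mixed_precision must be one of {VALID_PRECISIONS}")
    for key in BOOL_KEYS:
        if key in cfg and not isinstance(cfg[key], bool):
            problems.append(f"{key} must be a bool")
    dcp_path = cfg.get("dcp_path")
    if dcp_path is not None and not isinstance(dcp_path, str):
        problems.append("dcp_path must be a str or None")
    return problems


config = {
    "mixed_precision": "bf16",
    "reshard_after_forward": True,
    "offload_params": False,
    "pin_memory": True,
    "dcp_path": None,
}
print(check_dist_config(config))  # []
```

Returning a list of problems rather than raising on the first one lets a launcher report every misconfigured key at once.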
FSDP2Engine.shard_model
| Name | Type | Required | Description |
|---|---|---|---|
| model | PreTrainedModel | Yes | The model to shard; may be on the meta device (requiring materialization) or already loaded |
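shard_model has to distinguish a meta-device model (parameters with shapes but no storage) from one whose weights are already loaded. One plausible test is sketched below. has_meta_params is a hypothetical helper, and whether fsdp2.py uses exactly this check is an assumption. It relies only on the is_meta attribute that torch tensors expose, so stand-in objects work as well.

```python
from collections.abc import Iterable
from types import SimpleNamespace
from typing import Any


def has_meta_params(params: Iterable[Any]) -> bool:
    """True if any parameter lives on the meta device and must be materialized.

    torch.Tensor exposes `is_meta`; getattr() with a default keeps the sketch
    usable with plain stand-in objects when torch is not installed.
    """
    return any(getattr(p, "is_meta", False) for p in params)


# With torch: has_meta_params(model.parameters()) is True for a model built under
# `with torch.device("meta"):`, and False once real weights are loaded.
loaded = [SimpleNamespace(is_meta=False), SimpleNamespace(is_meta=False)]
partially_meta = [SimpleNamespace(is_meta=False), SimpleNamespace(is_meta=True)]
print(has_meta_params(loaded), has_meta_params(partially_meta))  # False True
```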
FSDP2Engine.materialize_and_load
| Name | Type | Required | Description |
|---|---|---|---|
| model | PreTrainedModel | Yes | The model with meta-device parameters to materialize |
| hf_model_path | str | Yes | Path to a HuggingFace checkpoint directory or a Hub model ID |
| dcp_path | str or None | No | Path to a Distributed Checkpoint directory; if provided and valid, it is used for efficient sharded loading |
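Once hf_model_path resolves to a local directory, _load_weights_from_hf_checkpoint has to find the weight files. For sharded checkpoints this means reading an index JSON whose weight_map maps each tensor name to a shard file. The sketch below uses the standard HuggingFace file names. The preference order (safetensors before .bin) and the helper list_checkpoint_files are assumptions and are not taken from fsdp2.py.

```python
import json
import os
import tempfile

# Standard HuggingFace weight file names, in the order this sketch prefers them.
_CANDIDATES = (
    ("model.safetensors.index.json", True),
    ("model.safetensors", False),
    ("pytorch_model.bin.index.json", True),
    ("pytorch_model.bin", False),
)


def list_checkpoint_files(ckpt_dir: str) -> list[str]:
    """Return the weight files to load from a local HuggingFace model directory."""
    for name, is_index in _CANDIDATES:
        path = os.path.join(ckpt_dir, name)
        if not os.path.isfile(path):
            continue
        if not is_index:
            return [path]
        with open(path, encoding="utf-8") as f:
            weight_map = json.load(f)["weight_map"]  # tensor name -> shard file
        # Many tensors share a shard, so deduplicate and sort for a stable order.
        return [os.path.join(ckpt_dir, s) for s in sorted(set(weight_map.values()))]
    raise FileNotFoundError(f"no HuggingFace weight files found in {ckpt_dir}")


def demo() -> list[str]:
    """Build a throwaway two-shard index in a temp directory and list its shards."""
    with tempfile.TemporaryDirectory() as tmp:
        index = {"weight_map": {
            "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
            "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
            "lm_head.weight": "model-00002-of-00002.safetensors",
        }}
        index_path = os.path.join(tmp, "model.safetensors.index.json")
        with open(index_path, "w", encoding="utf-8") as f:
            json.dump(index, f)
        return [os.path.basename(p) for p in list_checkpoint_files(tmp)]


print(demo())  # ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']
```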
Outputs
FSDP2Engine.shard_model
| Name | Type | Description |
|---|---|---|
| model | PreTrainedModel | The FSDP2-wrapped model with sharded parameters, ready for distributed training |
get_transformer_layer_cls
| Name | Type | Description |
|---|---|---|
| layer_cls | type[nn.Module] or None | The transformer layer class to use as the FSDP wrapping boundary, or None if not found |
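The lookup order described for get_transformer_layer_cls checks the model's _no_split_modules first, then model.model.layers or model.layers. That order can be sketched as follows. find_layer_cls is a hypothetical stand-in, not the module's code. It matches the class names in _no_split_modules, which HuggingFace stores as strings, against model.modules(). SimpleNamespace objects stand in for real models so the sketch runs without torch.

```python
from types import SimpleNamespace
from typing import Any, Optional


def find_layer_cls(model: Any) -> Optional[type]:
    """Return the class to use as the FSDP wrapping boundary, or None."""
    # 1) HuggingFace models list their non-splittable block classes by name.
    names = set(getattr(model, "_no_split_modules", None) or [])
    if names and hasattr(model, "modules"):
        for module in model.modules():
            if type(module).__name__ in names:
                return type(module)
    # 2) Fall back to the first block in model.model.layers, then model.layers.
    for owner in (getattr(model, "model", None), model):
        layers = getattr(owner, "layers", None)
        if layers:  # non-empty sequence, e.g. an nn.ModuleList
            return type(layers[0])
    return None


class ToyDecoderLayer:  # stand-in for a real block such as LlamaDecoderLayer
    pass


nested = SimpleNamespace(model=SimpleNamespace(layers=[ToyDecoderLayer()]))
print(find_layer_cls(nested).__name__)  # ToyDecoderLayer
```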
Usage Examples
# Run under a distributed launcher such as torchrun (one process per GPU).
# `model` is a PreTrainedModel created beforehand, on the meta device or fully loaded.
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import (
    FSDP2Engine,
    get_transformer_layer_cls,
)

# Identify the transformer layer class used as the FSDP wrapping boundary
layer_cls = get_transformer_layer_cls(model)  # e.g., LlamaDecoderLayer

# Initialize the FSDP2 engine with configuration
engine = FSDP2Engine(dist_config={
    "mixed_precision": "bf16",
    "reshard_after_forward": True,
    "offload_params": False,
    "pin_memory": True,
    "dcp_path": None,
})

# Shard the model for distributed training
model = engine.shard_model(model)
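_resolve_hf_checkpoint_dir is described as downloading in a rank-aware way, so that ranks do not fetch the same repository concurrently. A common way to do this is "rank 0 first": rank 0 downloads while the other ranks wait at a barrier, and the others then resolve the now-cached files. The sketch below shows that pattern with injected download and barrier callables. resolve_rank0_first is hypothetical, and whether fsdp2.py works exactly this way is an assumption. In a real job the callables would be something like huggingface_hub.snapshot_download and torch.distributed.barrier. Here, threads simulate four ranks.

```python
import os
import threading
from typing import Callable


def resolve_rank0_first(
    model_path: str,
    rank: int,
    download: Callable[[str], str],
    barrier: Callable[[], object],
) -> str:
    """Return a local directory for `model_path`, downloading on rank 0 first.

    Hypothetical helper: every rank must call it so the barrier is matched.
    """
    if os.path.isdir(model_path):  # local checkpoint: nothing to download
        local = model_path
    elif rank == 0:
        local = download(model_path)  # only rank 0 fetches from the network
    else:
        local = None
    barrier()  # non-zero ranks wait here until rank 0 has finished downloading
    if local is None:
        local = download(model_path)  # files now come from the populated cache
    return local


# Simulate 4 ranks with threads and record the order of download() calls.
calls: list[int] = []
results: dict[int, str] = {}
lock = threading.Lock()
sync = threading.Barrier(4)


def make_download(rank: int) -> Callable[[str], str]:
    def download(repo_id: str) -> str:
        with lock:
            calls.append(rank)
        return f"/cache/{repo_id.replace('/', '--')}"
    return download


def worker(rank: int) -> None:
    results[rank] = resolve_rank0_first("org/model", rank, make_download(rank), sync.wait)


threads = [threading.Thread(target=worker, args=(r,)) for r in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(calls[0])  # 0: rank 0 always downloads before any other rank
```

Every rank calls the barrier exactly once, including the local-path case. This keeps the collective matched across ranks, which torch.distributed.barrier requires.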