Implementation:Hiyouga LLaMA Factory FSDP2 Plugin

From Leeroopedia
Revision as of 15:06, 16 February 2026 by Admin (auto-imported from implementations/Hiyouga_LLaMA_Factory_FSDP2_Plugin.md)


Knowledge Sources
Domains: Machine Learning, Distributed Training, Model Parallelism
Last Updated: 2026-02-06 19:00 GMT

Overview

FSDP2 distributed sharding engine that enables multi-GPU training with memory-efficient parameter sharding. It also provides mixed precision training, activation checkpointing, optional CPU offloading, and flexible weight loading from both HuggingFace checkpoints and PyTorch Distributed Checkpoints (DCP).

Description

fsdp2.py implements the core distributed training engine for PyTorch FSDP2 (Fully Sharded Data Parallel v2) in LLaMA-Factory's v1 plugin system. The module contains:

  • get_transformer_layer_cls: A utility function that identifies the transformer layer class to use as the FSDP wrapping boundary. It checks the model's _no_split_modules attribute first, then falls back to inspecting model.model.layers or model.layers.
  • FSDP2Engine: The main engine class that manages the full FSDP2 lifecycle:
    • __init__: Initializes distributed state (rank, local_rank, world_size), configures mixed precision mode (bf16/fp16/fp32), reshard policy, CPU offloading, and device mesh from the DistributedInterface.
    • get_mp_policy: Creates a MixedPrecisionPolicy with appropriate param_dtype and reduce_dtype based on the configured precision mode.
    • prepare_model: The core FSDP wrapping function that:
      • Identifies transformer layer classes for per-layer sharding.
      • Applies fully_shard to each transformer layer, and to the embedding layer when the input embeddings are not tied to the output head.
      • Enables gradient (activation) checkpointing in non-reentrant mode.
      • Applies fully_shard to the top-level model.
      • Optionally offloads parameters to CPU, using pinned host memory when pin_memory is set.
    • materialize_and_load: Handles meta-device parameter materialization and weight loading, supporting both DCP (Distributed Checkpoint) and HuggingFace checkpoint formats.
    • shard_model: The high-level entry point that detects meta-device models (requiring materialization) vs. pre-loaded models and orchestrates preparation accordingly.
    • _load_from_dcp: Loads weights from a PyTorch Distributed Checkpoint using get_model_state_dict and set_model_state_dict.
    • _load_weights_from_hf_checkpoint: Loads weights from HuggingFace checkpoint files (safetensors or .bin), handling sharded index files, and distributing sliced tensors to DTensor parameters.
    • _resolve_hf_checkpoint_dir: Resolves HuggingFace model identifiers (local paths or Hub repo IDs) to local directories, with rank-aware downloading to avoid concurrent downloads.
    • _copy_weights: Handles weight copying to both regular tensors and DTensor (distributed tensor) parameters, slicing along shard dimensions as needed.
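prepare_model applies fully_shard bottom-up: transformer layers first, then the embedding when it is untied, and the top-level model last. That order can be sketched without PyTorch. The helper plan_fully_shard_order and the module-name strings are hypothetical. Placing the embedding after the layers is also an assumption. In the real engine, each entry would be a fully_shard(...) call on an nn.Module.

```python
from typing import List


def plan_fully_shard_order(
    layer_names: List[str],
    embedding_name: str,
    tie_word_embeddings: bool,
    root_name: str = "model",
) -> List[str]:
    """Return module names in the order fully_shard would be applied (sketch).

    FSDP2 wraps bottom-up: every submodule that gets its own shard group must
    be wrapped before its parent, and the root is wrapped last.
    """
    order = list(layer_names)  # one FSDP unit per transformer layer
    if not tie_word_embeddings:
        # Assumption: an untied input embedding gets its own FSDP unit; a tied
        # embedding shares its weight with the LM head, so it is left to the root.
        order.append(embedding_name)
    order.append(root_name)  # root last: it collects all remaining parameters
    return order


plan = plan_fully_shard_order(
    ["model.layers.0", "model.layers.1"], "model.embed_tokens", tie_word_embeddings=False
)
print(plan)
```

Wrapping children before the root matters because the root's fully_shard call manages only the parameters not already claimed by an inner FSDP unit. Each per-layer unit can then be all-gathered and freed on its own, so only a few unsharded layers are resident at once.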

Usage

Use FSDP2Engine when training large models that exceed single-GPU memory capacity. It is instantiated by the v1 distributed training plugin system with configuration from dist_config. The engine wraps the model for sharded training, handles weight loading, and manages the distributed state lifecycle.
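A back-of-the-envelope estimate shows why sharding helps. When parameters are split evenly across world_size ranks, each GPU stores about 1/world_size of them. The helper below is a hypothetical illustration that counts parameter storage only.

```python
def param_bytes_per_rank(num_params: int, bytes_per_param: int, world_size: int) -> float:
    """Parameter memory per GPU when parameters are evenly sharded (sketch).

    Counts parameter storage only: gradients, optimizer states, activations,
    and the temporarily all-gathered layers during forward/backward are ignored.
    """
    return num_params * bytes_per_param / world_size


GIB = 1024 ** 3
n = 8_000_000_000  # an 8B-parameter model (illustrative size)
print(param_bytes_per_rank(n, 2, 1) / GIB)  # bf16 (2 bytes/param), unsharded
print(param_bytes_per_rank(n, 2, 8) / GIB)  # bf16, sharded over 8 GPUs
```

Under these assumptions, an 8B-parameter model in bf16 needs about 14.9 GiB for its parameters alone on one GPU. Sharded over eight GPUs, that drops to about 1.9 GiB per GPU. FSDP shards gradients and optimizer states by the same factor.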

Code Reference

Source Location

Signature

def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None

class FSDP2Engine:
    def __init__(self, dist_config: dict)
    def get_mp_policy(self) -> MixedPrecisionPolicy
    def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel
    def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None) -> PreTrainedModel
    def shard_model(self, model: PreTrainedModel) -> PreTrainedModel
    def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str) -> None
    def _load_weights_from_hf_checkpoint(self, model, hf_model_path) -> None
    def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str
    def _copy_weights(self, param, loaded_tensor) -> None
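get_mp_policy turns the mixed_precision setting into a MixedPrecisionPolicy with a param_dtype and a reduce_dtype. The table below is a hypothetical sketch of that mapping using dtype names. Reducing gradients in float32 for the bf16 and fp16 modes is an assumption, not something confirmed from fsdp2.py.

```python
from typing import Tuple

# Hypothetical mapping: mixed_precision mode -> (param_dtype, reduce_dtype).
# float32 gradient reduction for the low-precision modes is an assumption.
_MP_TABLE = {
    "bf16": ("bfloat16", "float32"),
    "fp16": ("float16", "float32"),
    "fp32": ("float32", "float32"),
}


def mp_dtypes(mode: str) -> Tuple[str, str]:
    """Return (param_dtype, reduce_dtype) names for a precision mode."""
    try:
        return _MP_TABLE[mode]
    except KeyError:
        raise ValueError(f"unsupported mixed_precision: {mode!r}") from None


print(mp_dtypes("bf16"))
```

With PyTorch, these names map directly onto torch dtypes, for example MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32).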

Import

from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import get_transformer_layer_cls

I/O Contract

Inputs

FSDP2Engine.__init__

Name | Type | Required | Description
dist_config | dict | Yes | Configuration dictionary with keys: mixed_precision ("bf16"/"fp16"/"fp32"), reshard_after_forward (bool), offload_params (bool), pin_memory (bool), dcp_path (str or None)
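The dist_config keys above can be read into a small typed view with validation. DistConfigView is a hypothetical name. Its fallback defaults are illustrative assumptions, not the defaults used by FSDP2Engine.

```python
from dataclasses import dataclass
from typing import Optional

_VALID_MODES = ("bf16", "fp16", "fp32")


@dataclass(frozen=True)
class DistConfigView:
    """Hypothetical typed view over the documented dist_config keys."""

    mixed_precision: str
    reshard_after_forward: bool
    offload_params: bool
    pin_memory: bool
    dcp_path: Optional[str]

    @classmethod
    def from_dict(cls, cfg: dict) -> "DistConfigView":
        # Fallback defaults here are illustrative, not FSDP2Engine's own.
        mode = cfg.get("mixed_precision", "bf16")
        if mode not in _VALID_MODES:
            raise ValueError(f"mixed_precision must be one of {_VALID_MODES}, got {mode!r}")
        return cls(
            mixed_precision=mode,
            reshard_after_forward=bool(cfg.get("reshard_after_forward", True)),
            offload_params=bool(cfg.get("offload_params", False)),
            pin_memory=bool(cfg.get("pin_memory", True)),
            dcp_path=cfg.get("dcp_path"),
        )


print(DistConfigView.from_dict({"mixed_precision": "bf16", "offload_params": False}))
```

reshard_after_forward=True frees each layer's gathered parameters after its forward pass and all-gathers them again for backward. This trades extra communication for lower memory. pin_memory presumably matters only when offload_params is enabled, since pinned host memory is part of FSDP2's CPU offload policy.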

FSDP2Engine.shard_model

Name | Type | Required | Description
model | PreTrainedModel | Yes | The model to shard; it may be on the meta device (requiring materialization) or already loaded
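shard_model must decide whether the incoming model still has meta-device parameters. The sketch below relies on the is_meta attribute that PyTorch tensors expose. The function names and returned labels are hypothetical, and SimpleNamespace test doubles stand in for real parameters.

```python
from types import SimpleNamespace
from typing import Iterable


def needs_materialization(params: Iterable) -> bool:
    """True if any parameter lives on the meta device (sketch).

    torch tensors expose a boolean ``is_meta`` attribute; objects without it
    are treated as already materialized.
    """
    return any(getattr(p, "is_meta", False) for p in params)


def choose_path(params: Iterable) -> str:
    """Hypothetical dispatch mirroring the shard_model description."""
    # Meta tensors have shapes but no storage, so weights must be materialized
    # and loaded. Assumption: a pre-loaded model goes straight to prepare_model.
    if needs_materialization(params):
        return "materialize_and_load"
    return "prepare_model"


# Test doubles: real code would pass model.parameters() instead.
params = [SimpleNamespace(is_meta=True), SimpleNamespace(is_meta=False)]
print(choose_path(params))
```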

FSDP2Engine.materialize_and_load

Name | Type | Required | Description
model | PreTrainedModel | Yes | The model with meta-device parameters to materialize
hf_model_path | str | Yes | Path to a HuggingFace checkpoint directory or a Hub model ID
dcp_path | str | No | Path to a Distributed Checkpoint directory; if valid, it is used for efficient sharded loading
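These inputs lead to two loading decisions: whether dcp_path points at a usable Distributed Checkpoint, and which HuggingFace weight files to read. Both helpers below are hypothetical sketches. Treating a directory that contains a .metadata file as a valid DCP checkpoint is an assumption, and so is preferring safetensors over .bin files. The sharded-index layout follows the standard HuggingFace convention: a JSON file whose weight_map maps parameter names to shard files.

```python
import json
import os
import tempfile
from typing import List, Optional


def pick_loader(dcp_path: Optional[str]) -> str:
    """Return "dcp" if dcp_path looks like a Distributed Checkpoint, else "hf".

    Assumption: a directory holding the ``.metadata`` file that
    torch.distributed.checkpoint's FileSystemWriter writes counts as valid.
    """
    if dcp_path and os.path.isfile(os.path.join(dcp_path, ".metadata")):
        return "dcp"
    return "hf"


def hf_weight_files(ckpt_dir: str) -> List[str]:
    """List the weight files to read from a HuggingFace checkpoint directory."""
    # Sharded checkpoints: the index's weight_map maps parameter name -> shard file.
    for index_name in ("model.safetensors.index.json", "pytorch_model.bin.index.json"):
        index_path = os.path.join(ckpt_dir, index_name)
        if os.path.isfile(index_path):
            with open(index_path) as f:
                weight_map = json.load(f)["weight_map"]
            return sorted(set(weight_map.values()))
    # Single-file checkpoints (safetensors preferred, by assumption).
    for single in ("model.safetensors", "pytorch_model.bin"):
        if os.path.isfile(os.path.join(ckpt_dir, single)):
            return [single]
    return []


# Demo on a throwaway directory holding a two-shard safetensors index.
with tempfile.TemporaryDirectory() as demo_dir:
    index = {"weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    }}
    with open(os.path.join(demo_dir, "model.safetensors.index.json"), "w") as f:
        json.dump(index, f)
    demo_files = hf_weight_files(demo_dir)
    demo_loader = pick_loader(demo_dir)
print(demo_files, demo_loader)
```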

Outputs

FSDP2Engine.shard_model

Name | Type | Description
model | PreTrainedModel | The FSDP2-wrapped model with sharded parameters, ready for distributed training
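Each rank of the sharded model holds only its slice of every parameter, so _copy_weights has to cut the same slice out of a full checkpoint tensor. The framework-free sketch below assumes torch.chunk-style sharding along one dimension. That is the split rule DTensor's Shard placement follows: each chunk has ceil(size / world_size) rows, and trailing chunks may be shorter or empty. shard_range and local_rows are hypothetical names, and nested lists stand in for tensors.

```python
import math
from typing import List, Tuple


def shard_range(dim_size: int, world_size: int, rank: int) -> Tuple[int, int]:
    """(start, length) of ``rank``'s slice under torch.chunk-style sharding.

    Each chunk has ceil(dim_size / world_size) rows; trailing ranks may get
    a shorter chunk or an empty one.
    """
    chunk = math.ceil(dim_size / world_size)
    start = min(rank * chunk, dim_size)
    end = min(start + chunk, dim_size)
    return start, end - start


def local_rows(full_rows: List[list], world_size: int, rank: int) -> List[list]:
    """Slice a full (row-major) tensor down to the rows this rank owns."""
    start, length = shard_range(len(full_rows), world_size, rank)
    return full_rows[start:start + length]


# A 10-row weight split across 4 ranks gets chunks of 3, 3, 3, 1 rows.
print([shard_range(10, 4, r) for r in range(4)])
```

With a real DTensor parameter, one plausible way to apply the same range is loaded_tensor.narrow(dim, start, length) followed by an in-place copy into param.to_local().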

get_transformer_layer_cls

Name | Type | Description
layer_cls | type[nn.Module] or None | The transformer layer class to use as the FSDP wrapping boundary, or None if none is found
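get_transformer_layer_cls is described as checking _no_split_modules first, then model.model.layers or model.layers. That lookup order can be sketched on plain Python objects. find_layer_cls is a hypothetical name, and matching _no_split_modules entries against class names is an assumption about how those strings are resolved. The toy classes stand in for a HuggingFace model.

```python
from typing import Optional


def find_layer_cls(model) -> Optional[type]:
    """Sketch of the lookup order described for get_transformer_layer_cls.

    Assumes ``model`` has a ``modules()`` iterator (as torch.nn.Module does)
    and, optionally, ``_no_split_modules``: a list of class-name strings, as
    declared by HuggingFace PreTrainedModel subclasses.
    """
    # 1) Prefer a module whose class name the model marks as "do not split".
    no_split = set(getattr(model, "_no_split_modules", None) or [])
    if no_split:
        for module in model.modules():
            if type(module).__name__ in no_split:
                return type(module)
    # 2) Fall back to the first entry of model.model.layers, then model.layers.
    for owner in (getattr(model, "model", None), model):
        layers = getattr(owner, "layers", None)
        if layers:
            return type(layers[0])
    return None


class DecoderLayer:
    pass


class Backbone:
    def __init__(self):
        self.layers = [DecoderLayer(), DecoderLayer()]


class ToyCausalLM:
    _no_split_modules = ["DecoderLayer"]

    def __init__(self):
        self.model = Backbone()

    def modules(self):
        # Mimic nn.Module.modules(): yield self, then every submodule.
        yield self
        yield self.model
        yield from self.model.layers


print(find_layer_cls(ToyCausalLM()).__name__)
```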

Usage Examples

# Assumes a multi-process launch (e.g. torchrun) so torch.distributed is available
from transformers import AutoModelForCausalLM
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import (
    FSDP2Engine,
    get_transformer_layer_cls,
)

# Load a HuggingFace causal LM (the model ID is only an example)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")

# Identify the transformer layer class used as the FSDP wrapping boundary
layer_cls = get_transformer_layer_cls(model)
# e.g., LlamaDecoderLayer

# Initialize the FSDP2 engine with configuration
engine = FSDP2Engine(dist_config={
    "mixed_precision": "bf16",
    "reshard_after_forward": True,
    "offload_params": False,
    "pin_memory": True,
    "dcp_path": None,
})

# Shard the already-loaded model for distributed training
model = engine.shard_model(model)
