
Principle: lm-sys/FastChat FSDP Safe Model Saving

From Leeroopedia


Field Value
Page Type Principle
Title FSDP Safe Model Saving
Repository lm-sys/FastChat
Workflow Vicuna SFT Finetuning
Domains Distributed Training, FSDP, Model Checkpointing, State Dict Management
Knowledge Sources fastchat/train/train.py, PyTorch FSDP documentation, PyTorch distributed training guides
Last Updated 2026-02-07 14:00 GMT

Overview

This principle covers the theory and necessity of safely saving models that have been trained with Fully Sharded Data Parallel (FSDP). Because FSDP distributes model parameters across multiple GPUs, saving the model requires special handling to reconstruct a complete, usable state dictionary. Without proper coordination, saving from FSDP can produce incomplete checkpoints, corrupt state, or out-of-memory errors.

Description

The FSDP Sharding Problem

During FSDP training, the model's parameters are sharded (split) across all participating GPU ranks. Each rank holds only a fraction of the total parameters. This creates a fundamental challenge at save time:

  • No single rank has the complete model. If a rank attempts to call model.state_dict() directly, it will only retrieve its local shard, not the full parameter set.
  • Naive gathering can cause OOM. Simply gathering all shards onto every rank (an all-gather operation) would momentarily require each GPU to hold the full model in memory, potentially exceeding GPU memory limits.
  • File system coordination is needed. Only one rank should write the checkpoint to disk to avoid file corruption from concurrent writes.
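The incompleteness problem can be illustrated with a pure-Python sketch (no torch, no GPUs): parameters are split round-robin across hypothetical ranks, so any single rank's local view is a fraction of the model, and only an explicit gather reconstructs the whole. The sharding scheme here is a stand-in; real FSDP shards flattened parameter tensors, not whole named entries.

```python
# Illustrative sketch: 8 parameters sharded across 4 hypothetical ranks.
FULL_PARAMS = {f"layer{i}.weight": [float(i)] * 8 for i in range(8)}

def shard(params, num_ranks):
    """Split the parameter dict round-robin across ranks (FSDP actually
    shards flattened tensors, but the incompleteness problem is the same)."""
    shards = [{} for _ in range(num_ranks)]
    for idx, (name, value) in enumerate(sorted(params.items())):
        shards[idx % num_ranks][name] = value
    return shards

def gather(shards):
    """All-gather: merge every rank's shard into one complete dict."""
    full = {}
    for s in shards:
        full.update(s)
    return full

shards = shard(FULL_PARAMS, num_ranks=4)
# Each rank alone sees only a fraction of the parameters...
assert len(shards[0]) == 2 and len(FULL_PARAMS) == 8
# ...but gathering all shards reconstructs the complete state dict.
assert gather(shards) == FULL_PARAMS
```

This is why a bare `model.state_dict()` call on one rank, without the coordination described below, yields an unusable checkpoint.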

State Dict Types

PyTorch FSDP provides three state dict types that control how parameters are materialized when state_dict() is called:

FULL_STATE_DICT

The most common type for saving. It gathers all parameter shards into a complete state dictionary:

  • When combined with FullStateDictConfig(offload_to_cpu=True, rank0_only=True), the full state dict is:
    • Gathered to CPU memory (offload_to_cpu=True): Parameters are moved from GPU to CPU RAM during gathering, avoiding GPU OOM.
    • Materialized only on rank 0 (rank0_only=True): Only the process with rank 0 ends up with the full state dict; other ranks receive empty state dicts. This prevents redundant memory usage and ensures only one process writes to disk.

SHARDED_STATE_DICT

Each rank saves its own shard independently. This is faster for saving and loading in distributed settings but requires all ranks to participate in loading and produces multiple files.

LOCAL_STATE_DICT

Each rank saves its local flattened parameters. This is the fastest option but produces non-standard checkpoints that can only be loaded with the exact same FSDP configuration.
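The practical difference between the three types shows up in the checkpoint artifacts they produce. The sketch below is hypothetical (the file names are illustrative, not the exact files PyTorch writes), but it captures the trade-off: one portable file versus N rank-specific files.

```python
# Hypothetical sketch of checkpoint artifacts per state dict type.
NUM_RANKS = 4

def checkpoint_files(state_dict_type, num_ranks=NUM_RANKS):
    if state_dict_type == "FULL_STATE_DICT":
        # Gathered on rank 0, written once: a single portable file.
        return ["pytorch_model.bin"]
    if state_dict_type == "SHARDED_STATE_DICT":
        # Every rank writes its own shard: faster, but N files that a
        # distributed job must load cooperatively.
        return [f"shard_rank{r}.bin" for r in range(num_ranks)]
    if state_dict_type == "LOCAL_STATE_DICT":
        # Flattened local parameters: fastest, but tied to this exact
        # FSDP wrapping and world size.
        return [f"local_rank{r}.bin" for r in range(num_ranks)]
    raise ValueError(state_dict_type)

assert len(checkpoint_files("FULL_STATE_DICT")) == 1
assert len(checkpoint_files("SHARDED_STATE_DICT")) == NUM_RANKS
assert len(checkpoint_files("LOCAL_STATE_DICT")) == NUM_RANKS
```

For publishing a finished model (e.g. after Vicuna SFT), the single-file FULL_STATE_DICT output is what downstream tooling expects.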

The offload_to_cpu Strategy

The offload_to_cpu=True flag in FullStateDictConfig is critical for large models:

  • During the all-gather operation, parameter shards are collected from all ranks.
  • Instead of placing the gathered parameters in GPU memory (which may be full from the model itself, optimizer states, and activations), they are immediately offloaded to CPU RAM.
  • CPU RAM is typically much larger than GPU memory (hundreds of GB vs. tens of GB), making this approach feasible even for models that barely fit in distributed GPU memory during training.
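Back-of-envelope arithmetic makes the point concrete. The figures below are assumptions for illustration: a 13B-parameter model in fp16 (Vicuna-13B scale), a 40 GB GPU, and 512 GB of host RAM.

```python
# Why offload_to_cpu matters: the gathered full state dict does not fit
# in leftover GPU memory, but fits easily in host RAM.
params = 13e9                # assumed: 13B-parameter model
bytes_per_param = 2          # fp16
full_state_dict_gb = params * bytes_per_param / 1024**3   # ~24 GB

gpu_memory_gb = 40           # assumed: one A100-40GB per rank
cpu_memory_gb = 512          # assumed: host RAM on a multi-GPU node

# During training the GPU already holds its parameter shard, gradients,
# optimizer state, and activations, so an extra ~24 GB gather would OOM:
assert full_state_dict_gb > gpu_memory_gb / 2
# Host RAM absorbs the gathered copy comfortably:
assert full_state_dict_gb < cpu_memory_gb
print(f"full fp16 state dict: {full_state_dict_gb:.1f} GB")
```

The same arithmetic scales: any model that saturates GPU memory during training will, by construction, not have room for a second full copy on-device at save time.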

The rank0_only Strategy

The rank0_only=True flag ensures that:

  • Only rank 0 materializes the full state dictionary in CPU memory.
  • Other ranks skip the materialization, receiving empty dicts and contributing their shards to rank 0 via collective communication.
  • This prevents N-fold memory duplication (where N is the number of ranks) that would occur if every rank materialized the full state dict.
  • Only rank 0 then writes the checkpoint to the file system, avoiding write conflicts.
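The rank0_only contract can be sketched in pure Python (names and the shard-passing mechanism are illustrative stand-ins for FSDP's collective communication): every rank joins the gather, but only rank 0 materializes the result.

```python
# Sketch of the rank0_only contract. In real FSDP the shards move via
# collective ops; here a shared list stands in for that communication.
def gather_full_state_dict(rank, all_shards, rank0_only=True):
    """Every rank calls this (it is collective), but with rank0_only
    non-zero ranks come back with an empty dict."""
    if rank0_only and rank != 0:
        return {}                      # non-zero ranks hold nothing
    full = {}
    for shard in all_shards:           # rank 0 merges every shard
        full.update(shard)
    return full

shards = [{"w0": 0.0}, {"w1": 1.0}, {"w2": 2.0}, {"w3": 3.0}]
results = [gather_full_state_dict(r, shards) for r in range(4)]

assert len(results[0]) == 4                    # rank 0: complete model
assert all(res == {} for res in results[1:])   # other ranks: empty dicts
# Only rank 0 then writes to disk, so there are no concurrent writes.
```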

Context Manager Pattern

FSDP state dict configuration is applied via a context manager:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    # Inside this context, model.state_dict() returns the full, gathered state dict
    # on rank 0 (offloaded to CPU), and an empty dict on other ranks
    state_dict = model.state_dict()

The context manager temporarily changes how state_dict() behaves for the given model, ensuring that any saving operation within the context (including trainer.save_model()) correctly gathers and saves the full model.

Usage

When saving a model trained with FSDP:

  1. Wrap the save operation in the FSDP.state_dict_type context manager.
  2. Use StateDictType.FULL_STATE_DICT with FullStateDictConfig(offload_to_cpu=True, rank0_only=True).
  3. Call the actual save method (e.g., trainer.save_model()) inside the context.
  4. Ensure all ranks enter the context manager (it involves collective communication).
  5. Only rank 0 will produce the saved checkpoint on disk.
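The five steps can be simulated structurally in pure Python. The context manager and save function below are hypothetical stand-ins (the real calls are FSDP.state_dict_type(...) and trainer.save_model()), but the call order and the one-writer outcome mirror the checklist above.

```python
# Minimal simulation of the save flow across 4 simulated ranks.
from contextlib import contextmanager

written_checkpoints = []

@contextmanager
def full_state_dict_context(rank, participants):
    # Steps 1-2 and 4: every rank must enter, because gathering shards
    # is a collective operation that blocks until all ranks arrive.
    participants.add(rank)
    yield

def save_model(rank):
    # Steps 3 and 5: the save call runs on all ranks, but only rank 0
    # ends up with a non-empty state dict and writes to disk.
    if rank == 0:
        written_checkpoints.append("output_dir/pytorch_model.bin")

participants = set()
for rank in range(4):
    with full_state_dict_context(rank, participants):
        save_model(rank)

assert participants == {0, 1, 2, 3}   # all ranks joined the collective
assert written_checkpoints == ["output_dir/pytorch_model.bin"]  # one writer
```

A common failure mode this guards against: gating the *context entry* on rank (not just the write) deadlocks the job, because the collective gather inside waits forever for the missing ranks.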

Theoretical Basis

The FSDP safe saving pattern addresses a fundamental tension in distributed systems: the memory efficiency gained by sharding parameters across devices must be reconciled with the need for a complete artifact at save time.

The solution employs two key distributed systems principles:

  • Asymmetric computation: By designating rank 0 as the sole writer (rank0_only=True), the system avoids redundant work and storage. This is a common pattern in distributed systems where a single coordinator handles I/O while workers contribute data.
  • Memory hierarchy exploitation: By offloading to CPU (offload_to_cpu=True), the system leverages the CPU-GPU memory hierarchy. GPU memory is a scarce resource needed for training computation, while CPU memory is abundant and suitable for transient I/O operations.

These patterns ensure that model saving is both correct (the full model is gathered) and feasible (it does not exceed GPU memory constraints).
