Principle: lm-sys/FastChat FSDP Safe Model Saving
| Field | Value |
|---|---|
| Page Type | Principle |
| Title | FSDP Safe Model Saving |
| Repository | lm-sys/FastChat |
| Workflow | Vicuna SFT Finetuning |
| Domains | Distributed Training, FSDP, Model Checkpointing, State Dict Management |
| Knowledge Sources | fastchat/train/train.py, PyTorch FSDP documentation, PyTorch distributed training guides |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
This principle covers the theory and necessity of safely saving models that have been trained with Fully Sharded Data Parallel (FSDP). Because FSDP distributes model parameters across multiple GPUs, saving the model requires special handling to reconstruct a complete, usable state dictionary. Without proper coordination, saving from FSDP can produce incomplete checkpoints, corrupt state, or out-of-memory errors.
Description
The FSDP Sharding Problem
During FSDP training, the model's parameters are sharded (split) across all participating GPU ranks. Each rank holds only a fraction of the total parameters. This creates a fundamental challenge at save time:
- No single rank has the complete model. If a rank calls `model.state_dict()` directly, it will retrieve only its local shard, not the full parameter set.
- Naive gathering can cause OOM. Simply gathering all shards onto every rank (an all-gather operation) would momentarily require each GPU to hold the full model in memory, potentially exceeding GPU memory limits.
- File system coordination is needed. Only one rank should write the checkpoint to disk to avoid file corruption from concurrent writes.
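The sharding challenge can be illustrated with a toy example (plain Python, no torch; the sizes and split scheme are invented for demonstration): no single rank's shard is the complete model, and only gathering every shard reconstructs it.

```python
# Toy illustration of parameter sharding across ranks (no torch required).
params = list(range(10))  # pretend these are the model's flattened parameters
world_size = 4            # number of GPU ranks

def shard(params, rank, world_size):
    # contiguous even split; the last rank may hold fewer elements
    per_rank = (len(params) + world_size - 1) // world_size
    return params[rank * per_rank:(rank + 1) * per_rank]

shards = [shard(params, r, world_size) for r in range(world_size)]

# No single rank's shard is the complete model...
assert all(s != params for s in shards)
# ...and only concatenating every shard (an all-gather) reconstructs it.
gathered = [p for s in shards for p in s]
assert gathered == params
print(len(shards[0]), len(params))  # 3 10
```

Real FSDP shards flattened parameter tensors rather than Python lists, but the save-time constraint is the same: a full checkpoint requires a collective gather.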
State Dict Types
PyTorch FSDP provides several state dict types that control how parameters are materialized when `state_dict()` is called:
FULL_STATE_DICT
The most common type for saving. It gathers all parameter shards into a complete state dictionary:
- When combined with `FullStateDictConfig(offload_to_cpu=True, rank0_only=True)`, the full state dict is:
  - Gathered to CPU memory (`offload_to_cpu=True`): Parameters are moved from GPU to CPU RAM during gathering, avoiding GPU OOM.
  - Materialized only on rank 0 (`rank0_only=True`): Only the process with rank 0 ends up with the full state dict; other ranks receive empty state dicts. This prevents redundant memory usage and ensures only one process writes to disk.
SHARDED_STATE_DICT
Each rank saves its own shard independently. This is faster for saving and loading in distributed settings but requires all ranks to participate in loading and produces multiple files.
LOCAL_STATE_DICT
Each rank saves its local flattened parameters. This is the fastest option but produces non-standard checkpoints that can only be loaded with the exact same FSDP configuration.
The offload_to_cpu Strategy
The `offload_to_cpu=True` flag in `FullStateDictConfig` is critical for large models:
- During the all-gather operation, parameter shards are collected from all ranks.
- Instead of placing the gathered parameters in GPU memory (which may be full from the model itself, optimizer states, and activations), they are immediately offloaded to CPU RAM.
- CPU RAM is typically much larger than GPU memory (hundreds of GB vs. tens of GB), making this approach feasible even for models that barely fit in distributed GPU memory during training.
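The memory effect can be mimicked with a toy gather loop (plain Python, no torch; shard names and sizes are invented): shards transit one at a time while the complete state dict accumulates only in CPU memory, so peak "GPU-resident" size never exceeds a single shard.

```python
# Toy gather (no torch): shards land in a CPU-side dict one at a time, so the
# peak "GPU-resident" element count never exceeds one shard's size.
shards = {rank: {f"w{rank}": [float(rank)] * 4} for rank in range(4)}

cpu_state_dict = {}
peak_gpu_elems = 0
for rank, shard in shards.items():
    in_flight = sum(len(v) for v in shard.values())  # one shard on the "GPU"
    peak_gpu_elems = max(peak_gpu_elems, in_flight)
    for name, values in shard.items():
        cpu_state_dict[name] = list(values)          # "offload" to CPU RAM

total_elems = sum(len(v) for v in cpu_state_dict.values())
print(peak_gpu_elems, total_elems)  # 4 16: the full model only exists on CPU
```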
The rank0_only Strategy
The `rank0_only=True` flag ensures that:
- Only rank 0 materializes the full state dictionary in CPU memory.
- Other ranks skip the materialization, receiving empty dicts and contributing their shards to rank 0 via collective communication.
- This prevents N-fold memory duplication (where N is the number of ranks) that would occur if every rank materialized the full state dict.
- Only rank 0 then writes the checkpoint to the file system, avoiding write conflicts.
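The single-writer discipline can be sketched in isolation (plain Python; `write_checkpoint` is a hypothetical helper, not a PyTorch API): every rank calls the save routine, but only rank 0 touches the file system, so exactly one checkpoint file appears.

```python
import os
import pickle
import tempfile

def write_checkpoint(state_dict, path, rank):
    # Hypothetical helper mimicking rank0_only: the designated writer persists
    # the gathered dict; other ranks hold empty dicts and write nothing.
    if rank == 0:
        with open(path, "wb") as f:
            pickle.dump(state_dict, f)

# Simulate 4 ranks: rank 0 holds the gathered dict, the rest hold empty dicts.
tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, "ckpt.bin")
for rank in range(4):
    gathered = {"weight": [0.1, 0.2]} if rank == 0 else {}
    write_checkpoint(gathered, path, rank)

print(sorted(os.listdir(tmpdir)))  # ['ckpt.bin'] — one file, one writer
```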
Context Manager Pattern
FSDP state dict configuration is applied via a context manager:
```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    # Inside this context, model.state_dict() returns the full, gathered state
    # dict on rank 0 (offloaded to CPU), and an empty dict on other ranks
    state_dict = model.state_dict()
```
The context manager temporarily changes how `state_dict()` behaves for the given model, ensuring that any saving operation within the context (including `trainer.save_model()`) correctly gathers and saves the full model.
Usage
When saving a model trained with FSDP:
- Wrap the save operation in the `FSDP.state_dict_type` context manager.
- Use `StateDictType.FULL_STATE_DICT` with `FullStateDictConfig(offload_to_cpu=True, rank0_only=True)`.
- Call the actual save method (e.g., `trainer.save_model()`) inside the context.
- Ensure all ranks enter the context manager (it involves collective communication).
- Only rank 0 will produce the saved checkpoint on disk.
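The steps above compose into a single helper, sketched here after the pattern in fastchat/train/train.py. It is a sketch, not the exact FastChat code: `trainer` is assumed to be a Hugging Face `transformers.Trainer` wrapping an FSDP model, and `trainer._save` is the private Trainer method that FastChat's own helper invokes.

```python
def safe_save_fsdp_model(trainer, output_dir: str) -> None:
    """Gather the full state dict on rank 0, then save through the HF Trainer.

    Sketch only: assumes `trainer` is a transformers.Trainer whose model is
    FSDP-wrapped. Imports are deferred so the module loads without torch.
    """
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType, FullStateDictConfig

    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    # Every rank must enter the context: state_dict() performs collective
    # all-gathers, and a rank that skips it would deadlock the others.
    with FSDP.state_dict_type(
        trainer.model, StateDictType.FULL_STATE_DICT, save_policy
    ):
        state_dict = trainer.model.state_dict()  # full on rank 0, {} elsewhere
    # should_save is True only on the main process, so exactly one rank writes.
    if trainer.args.should_save:
        trainer._save(output_dir, state_dict=state_dict)
```

Calling `trainer._save` with an explicit `state_dict` avoids a second (unguarded) `state_dict()` call inside the Trainer's own save path.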
Theoretical Basis
The FSDP safe saving pattern addresses a fundamental tension in distributed systems: the memory efficiency gained by sharding parameters across devices must be reconciled with the need for a complete artifact at save time.
The solution employs two key distributed systems principles:
- Asymmetric computation: By designating rank 0 as the sole writer (`rank0_only=True`), the system avoids redundant work and storage. This is a common pattern in distributed systems where a single coordinator handles I/O while workers contribute data.
- Memory hierarchy exploitation: By offloading to CPU (`offload_to_cpu=True`), the system leverages the CPU-GPU memory hierarchy. GPU memory is a scarce resource needed for training computation, while CPU memory is abundant and suitable for transient I/O operations.
These patterns ensure that model saving is both correct (the full model is gathered) and feasible (it does not exceed GPU memory constraints).