Implementation: lm-sys/FastChat Trainer Save Model Safe
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Title | Trainer Save Model Safe |
| Repository | lm-sys/FastChat |
| Workflow | Vicuna SFT Finetuning |
| Domains | FSDP, Model Checkpointing, Distributed Training |
| Knowledge Sources | fastchat/train/train.py |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
This implementation documents the trainer_save_model_safe function, which provides a safe mechanism for saving models trained with Fully Sharded Data Parallel (FSDP). The function wraps the standard trainer.save_model() call inside an FSDP state dict context manager that gathers sharded parameters to CPU memory on rank 0 only, ensuring correct and memory-efficient model saving.
Description
The trainer_save_model_safe function solves the problem of saving a complete model checkpoint when the model's parameters are sharded across multiple GPUs by FSDP. It performs the following:
- Imports FSDP utilities: The function lazily imports `FullyShardedDataParallel`, `StateDictType`, and `FullStateDictConfig` from `torch.distributed.fsdp`. The lazy import avoids requiring torch.distributed at module load time.
- Configures the save policy: Creates a `FullStateDictConfig` with `offload_to_cpu=True` (gathered parameters are placed in CPU RAM, not GPU memory) and `rank0_only=True` (only rank 0 materializes the full state dict).
- Enters the state dict context: Uses `FSDP.state_dict_type()` as a context manager on the trainer's model, setting the state dict type to `FULL_STATE_DICT` with the configured save policy.
- Saves the model: Calls `trainer.save_model()` inside the context, which internally calls `model.state_dict()` and writes the checkpoint in Hugging Face format.
The complete function is concise (9 lines including the signature):

```python
def trainer_save_model_safe(trainer: transformers.Trainer):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType, FullStateDictConfig

    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(
        trainer.model, StateDictType.FULL_STATE_DICT, save_policy
    ):
        trainer.save_model()
```
Usage
Code Reference
Source Location
fastchat/train/train.py:L81-89
Signature
```python
def trainer_save_model_safe(trainer: transformers.Trainer):
```
Import
```python
from fastchat.train.train import trainer_save_model_safe
```
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `trainer` | `transformers.Trainer` | Yes | A trained `Trainer` instance whose model has been trained with FSDP. The `Trainer` must have a valid `output_dir` set in its `args` for saving the checkpoint, and `trainer.model` must be wrapped by FSDP. |
Outputs
| Output | Type | Description |
|---|---|---|
| Saved model checkpoint | Files on disk | A complete model checkpoint in Hugging Face format (model weights, config, tokenizer files) at `trainer.args.output_dir`. The checkpoint is written only by rank 0. |
Side Effects
- Collective communication: All FSDP ranks must call this function simultaneously. The `FSDP.state_dict_type` context manager triggers an all-gather operation across ranks.
- CPU memory usage: Rank 0 temporarily holds the full model state dict in CPU memory. For a 7B-parameter model in fp16, this requires approximately 14 GB of CPU RAM.
- Disk I/O: Only rank 0 writes the checkpoint files to disk.
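The CPU-memory figure above is simple arithmetic: each fp16 parameter occupies 2 bytes, so a gathered parameter-only state dict scales linearly with parameter count. The helper below is a rough sketch (real usage is somewhat higher due to buffers and serialization overhead):

```python
def full_state_dict_bytes(num_params: int, bytes_per_param: int = 2) -> int:
    """Rough size of a gathered fp16 state dict (parameters only)."""
    return num_params * bytes_per_param

# 7B parameters in fp16 -> ~14 GB materialized on rank 0
print(full_state_dict_bytes(7_000_000_000) / 1e9)  # 14.0
```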
Usage Examples
Standard usage within the training pipeline:
```python
from fastchat.train.train import trainer_save_model_safe

# After training completes
model.config.use_cache = True
trainer.save_state()
if trainer.is_deepspeed_enabled:
    # DeepSpeed has its own saving mechanism
    trainer.save_model()
else:
    # FSDP-safe saving
    trainer_save_model_safe(trainer)
```
Understanding the call context:
The function is called at the end of the train() function in train.py (lines 308-314). The decision between trainer.save_model() and trainer_save_model_safe() depends on whether DeepSpeed is being used:
```python
# Save model
model.config.use_cache = True
trainer.save_state()
if trainer.is_deepspeed_enabled:
    trainer.save_model()
else:
    trainer_save_model_safe(trainer)
```
This branching exists because DeepSpeed has its own state dict gathering mechanism that is incompatible with the FSDP context manager, while non-DeepSpeed distributed training (including FSDP) requires the explicit gathering step.
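This dispatch logic can be exercised in isolation. The sketch below is purely illustrative: `save_checkpoint` and `FakeTrainer` are hypothetical names, not FastChat or Hugging Face API, and the FSDP path is stubbed with a callable so the branch itself can be tested without a distributed setup:

```python
class FakeTrainer:
    """Minimal stand-in for transformers.Trainer, for illustration only."""
    def __init__(self, is_deepspeed_enabled: bool):
        self.is_deepspeed_enabled = is_deepspeed_enabled
        self.calls = []

    def save_model(self):
        self.calls.append("save_model")


def save_checkpoint(trainer, fsdp_safe_save):
    # Mirrors the branch at the end of train(): DeepSpeed gathers shards
    # internally, while FSDP needs the explicit context-managed gather.
    if trainer.is_deepspeed_enabled:
        trainer.save_model()
    else:
        fsdp_safe_save(trainer)


ds = FakeTrainer(is_deepspeed_enabled=True)
save_checkpoint(ds, lambda t: t.calls.append("fsdp_safe"))
print(ds.calls)  # ['save_model']

fsdp = FakeTrainer(is_deepspeed_enabled=False)
save_checkpoint(fsdp, lambda t: t.calls.append("fsdp_safe"))
print(fsdp.calls)  # ['fsdp_safe']
```

In the real pipeline the second argument would be `trainer_save_model_safe` itself; the lambda stands in only so the control flow is observable.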
Verifying the saved checkpoint:
```python
import transformers

# After training, the saved checkpoint can be loaded as a standard
# Hugging Face model
model = transformers.AutoModelForCausalLM.from_pretrained(
    "output/vicuna-7b-sft"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "output/vicuna-7b-sft"
)
```