
Implementation:Lm sys FastChat Trainer Save Model Safe

From Leeroopedia


Page Type: Implementation (API Doc)
Title: Trainer Save Model Safe
Repository: lm-sys/FastChat
Workflow: Vicuna SFT Finetuning
Domains: FSDP, Model Checkpointing, Distributed Training
Knowledge Sources: fastchat/train/train.py
Last Updated: 2026-02-07 14:00 GMT

Overview

This implementation documents the trainer_save_model_safe function, which provides a safe mechanism for saving models trained with Fully Sharded Data Parallel (FSDP). The function wraps the standard trainer.save_model() call inside an FSDP state dict context manager that gathers sharded parameters to CPU memory on rank 0 only, ensuring correct and memory-efficient model saving.

Description

The trainer_save_model_safe function solves the problem of saving a complete model checkpoint when the model's parameters are sharded across multiple GPUs by FSDP. It performs the following:

  1. Imports FSDP utilities: The function lazily imports FullyShardedDataParallel, StateDictType, and FullStateDictConfig from torch.distributed.fsdp. The lazy import avoids requiring torch.distributed at module load time.
  2. Configures the save policy: Creates a FullStateDictConfig with offload_to_cpu=True (gathered parameters are placed in CPU RAM, not GPU memory) and rank0_only=True (only rank 0 materializes the full state dict).
  3. Enters the state dict context: Uses FSDP.state_dict_type() as a context manager on the trainer's model, setting the state dict type to FULL_STATE_DICT with the configured save policy.
  4. Saves the model: Calls trainer.save_model() inside the context, which internally calls model.state_dict() and writes the checkpoint in Hugging Face format.

The complete function is concise (9 lines including the signature):

def trainer_save_model_safe(trainer: transformers.Trainer):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType, FullStateDictConfig

    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(
        trainer.model, StateDictType.FULL_STATE_DICT, save_policy
    ):
        trainer.save_model()

Usage

Code Reference

Source Location

fastchat/train/train.py:L81-89

Signature

def trainer_save_model_safe(trainer: transformers.Trainer):

Import

from fastchat.train.train import trainer_save_model_safe

I/O Contract

Inputs

trainer (transformers.Trainer, required): A Trainer instance whose model has been trained with FSDP. The Trainer must have a valid output_dir set in its args for saving the checkpoint, and trainer.model must be wrapped by FSDP.

Outputs

Saved model checkpoint (files on disk): A complete model checkpoint in Hugging Face format (model weights, config, tokenizer files), written to trainer.args.output_dir. The checkpoint is written only by rank 0.

Side Effects

  • Collective communication: All FSDP ranks must call this function; gathering the full state dict is a collective operation, so calling it on only some ranks will hang the job. The all-gather runs when trainer.save_model() invokes model.state_dict() inside the FSDP.state_dict_type context.
  • CPU memory usage: Rank 0 will temporarily hold the full model state dict in CPU memory. For a 7B parameter model in fp16, this requires approximately 14 GB of CPU RAM.
  • Disk I/O: Only rank 0 writes the checkpoint files to disk.
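The 14 GB figure above is simple arithmetic (parameter count times bytes per element). A minimal illustrative helper, not part of FastChat, makes the estimate explicit:

```python
# Back-of-the-envelope CPU RAM estimate for the gathered full state dict.
# Illustrative helper only; not part of FastChat.

def full_state_dict_bytes(n_params: int, bytes_per_param: int = 2) -> int:
    """Estimate the size of a full (unsharded) state dict in bytes.

    bytes_per_param: 2 for fp16/bf16, 4 for fp32.
    """
    return n_params * bytes_per_param

# A 7B-parameter model in fp16:
gb = full_state_dict_bytes(7_000_000_000) / 1e9
# -> 14.0 GB, matching the estimate in the Side Effects section
```

Note this is a lower bound: it counts only parameters, not optimizer state or buffers, and rank 0 needs this much free CPU RAM on top of whatever the training process already holds.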

Usage Examples

Standard usage within the training pipeline:

from fastchat.train.train import trainer_save_model_safe

# After training completes
model.config.use_cache = True
trainer.save_state()

if trainer.is_deepspeed_enabled:
    # DeepSpeed has its own saving mechanism
    trainer.save_model()
else:
    # FSDP-safe saving
    trainer_save_model_safe(trainer)

Understanding the call context:

The function is called at the end of the train() function in train.py (lines 308-314). The decision between trainer.save_model() and trainer_save_model_safe() depends on whether DeepSpeed is being used:

# Save model
model.config.use_cache = True
trainer.save_state()
if trainer.is_deepspeed_enabled:
    trainer.save_model()
else:
    trainer_save_model_safe(trainer)

This branching exists because DeepSpeed has its own state dict gathering mechanism that is incompatible with the FSDP context manager, while non-DeepSpeed distributed training (including FSDP) requires the explicit gathering step.
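The dispatch itself reduces to a plain conditional. The sketch below is an illustration with a hypothetical helper name, not FastChat code; it only restates the decision logic described above:

```python
# Illustrative sketch of the save-path dispatch described above.
# choose_save_path is a hypothetical helper, not part of FastChat.

def choose_save_path(is_deepspeed_enabled: bool) -> str:
    """Return which saving mechanism the training pipeline would use."""
    if is_deepspeed_enabled:
        # DeepSpeed gathers its own state dict internally, so the
        # plain save_model() call is sufficient (and the FSDP context
        # manager would be incompatible).
        return "trainer.save_model"
    # FSDP (and other non-DeepSpeed distributed setups) need the
    # explicit FULL_STATE_DICT gathering step.
    return "trainer_save_model_safe"
```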

Verifying the saved checkpoint:

import transformers

# After training, the saved checkpoint can be loaded as a standard Hugging Face model
model = transformers.AutoModelForCausalLM.from_pretrained(
    "output/vicuna-7b-sft"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "output/vicuna-7b-sft"
)
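As a lighter-weight check than loading the model, one can inspect the checkpoint directory for the expected files. The file names below are typical Hugging Face outputs and vary by transformers version (e.g. model.safetensors vs pytorch_model.bin); this helper is an illustrative sketch, not part of FastChat:

```python
import os

# Typical files a Hugging Face checkpoint directory contains; exact
# names vary by transformers version (an assumption, not guaranteed).
WEIGHT_FILES = ("model.safetensors", "pytorch_model.bin")
REQUIRED_FILES = ("config.json",)

def checkpoint_looks_complete(output_dir: str) -> bool:
    """Return True if the directory has a config and at least one weights file."""
    present = set(os.listdir(output_dir))
    has_config = all(f in present for f in REQUIRED_FILES)
    has_weights = any(f in present for f in WEIGHT_FILES)
    return has_config and has_weights
```

On a multi-file (sharded) checkpoint the weights may instead be split across files listed in an index JSON, so a production check would also accept model.safetensors.index.json.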
