Implementation: lm-sys/FastChat HF Trainer Train FSDP
| Field | Value |
|---|---|
| Page Type | Implementation (Wrapper Doc) |
| Title | HF Trainer Train FSDP |
| Repository | lm-sys/FastChat |
| Workflow | Vicuna SFT Finetuning |
| Domains | Distributed Training, FSDP, Hugging Face Trainer |
| Knowledge Sources | fastchat/train/train.py, Hugging Face Transformers Trainer documentation |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
This implementation documents how the Hugging Face Trainer class is instantiated and invoked within the Vicuna SFT training script to execute distributed training via FSDP. The training loop, including automatic checkpoint resumption, is handled by the Trainer.train() method, with FSDP configuration provided through TrainingArguments.
Description
The core training execution in fastchat/train/train.py consists of three steps:
Step 1: Trainer Instantiation
The Trainer is constructed with the model, tokenizer, training arguments, and the data module (training and evaluation datasets):
```python
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    **data_module,
)
```
The data_module dictionary is unpacked to provide train_dataset and eval_dataset keyword arguments.
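As an illustration of the pattern (not the exact FastChat code), the dict-unpacking can be reduced to a minimal, self-contained sketch. The helper and `fake_trainer` names here are stand-ins for the real builder in `train.py` and `transformers.Trainer`:

```python
# Minimal sketch: a helper returns a dict whose keys match Trainer keyword
# arguments, and the caller unpacks it with **. Plain lists stand in for
# the real Dataset objects.
def make_data_module(train_data, eval_data):
    return dict(train_dataset=train_data, eval_dataset=eval_data)

def fake_trainer(model=None, tokenizer=None, args=None,
                 train_dataset=None, eval_dataset=None):
    # Stand-in for transformers.Trainer; just echoes what it received.
    return {"train": train_dataset, "eval": eval_dataset}

data_module = make_data_module([1, 2, 3], [4])
received = fake_trainer(model="m", tokenizer="t", args="a", **data_module)
# received == {"train": [1, 2, 3], "eval": [4]}
```

Because the dict keys line up with the `Trainer.__init__` parameter names, adding a new dataset-related argument only requires changing the helper, not the call site.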
Step 2: Training with Checkpoint Resumption
Training is launched with automatic checkpoint detection:
```python
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
    trainer.train(resume_from_checkpoint=True)
else:
    trainer.train()
```
If any checkpoint-* directories exist in the output directory, training resumes from the most recent checkpoint. Otherwise, training starts from scratch.
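The detection logic can be exercised standalone. This sketch (with a temporary directory and a hypothetical checkpoint name, not FastChat code) shows how the glob behaves with and without checkpoints present:

```python
import pathlib
import tempfile

def should_resume(output_dir: str) -> bool:
    # Mirrors the script's check: resume iff any checkpoint-* dir exists.
    return bool(list(pathlib.Path(output_dir).glob("checkpoint-*")))

with tempfile.TemporaryDirectory() as out:
    assert not should_resume(out)              # fresh run: no checkpoints yet
    (pathlib.Path(out) / "checkpoint-500").mkdir()
    assert should_resume(out)                  # existing run: resume
```

Note that the script only decides *whether* to resume; with `resume_from_checkpoint=True`, the Trainer itself locates the most recent checkpoint inside `output_dir`.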
Step 3: Post-Training
After training completes, the KV cache is re-enabled (it is typically disabled during training for compatibility with gradient checkpointing), the trainer state is saved, and the model weights are written out:

```python
model.config.use_cache = True
trainer.save_state()
if trainer.is_deepspeed_enabled:
    trainer.save_model()
else:
    trainer_save_model_safe(trainer)
```
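The `trainer_save_model_safe` helper is defined in `fastchat/train/train.py`. A sketch of the pattern it uses, gathering a full, CPU-offloaded state dict on rank 0 before saving so the FSDP-sharded weights are written as a single Hugging Face checkpoint, looks like this (treat the details as indicative rather than authoritative):

```python
def trainer_save_model_safe(trainer):
    """Gather the full (unsharded) FSDP state dict on rank 0, then save.

    Sketch of the helper in fastchat/train/train.py; assumes trainer.model
    is wrapped in FSDP.
    """
    # Local imports so the helper can be defined without torch installed.
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType, FullStateDictConfig

    # Offload the gathered weights to CPU and materialize them only on rank 0
    # to avoid OOM when the full model does not fit on one GPU.
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(
        trainer.model, StateDictType.FULL_STATE_DICT, save_policy
    ):
        trainer.save_model()
```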
Custom TrainingArguments
The FastChat training script extends transformers.TrainingArguments with a custom TrainingArguments dataclass:
```python
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
```
Key defaults:
- `optim`: `"adamw_torch"` (PyTorch native AdamW)
- `model_max_length`: `512` (can be overridden, e.g., 2048 for Vicuna)
- `cache_dir`: `None` (optional cache for downloaded artifacts)
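The extension pattern can be demonstrated without `transformers` installed. In this self-contained sketch, a plain `BaseArguments` dataclass stands in for `transformers.TrainingArguments`:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class BaseArguments:
    # Stand-in for transformers.TrainingArguments.
    output_dir: str = "output"

@dataclass
class TrainingArguments(BaseArguments):
    # Mirrors the FastChat extension fields and their defaults.
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(default=512)

args = TrainingArguments(model_max_length=2048)
# args.optim == "adamw_torch"; args.model_max_length == 2048
```

Because every added field has a default, the subclass stays compatible with `HfArgumentParser`-style construction, where only the overridden values need to appear on the command line.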
FSDP is configured via the standard fsdp and fsdp_config fields inherited from transformers.TrainingArguments.
Training Variants
The FastChat repository provides multiple training script variants optimized for different attention implementations:
| Script | Description | Attention Mechanism |
|---|---|---|
| `fastchat/train/train.py` | Standard training script | Default attention |
| `fastchat/train/train_mem.py` | Flash Attention variant | Flash Attention 2 for memory-efficient training |
| `fastchat/train/train_xformers.py` | xFormers variant | xFormers memory-efficient attention |
All variants share the same Trainer-based training loop; they differ in how attention is computed during the forward pass.
Usage
Code Reference
Source Location
fastchat/train/train.py:L300-306 (Trainer instantiation and training loop)
fastchat/train/train.py:L62-70 (TrainingArguments dataclass)
Signature
```python
# Trainer construction
trainer = Trainer(
    model: transformers.PreTrainedModel,
    tokenizer: transformers.PreTrainedTokenizer,
    args: TrainingArguments,
    train_dataset: Dataset,
    eval_dataset: Optional[Dataset],
)

# Training invocation
trainer.train(resume_from_checkpoint: Optional[bool] = None)
```
Import
```python
from transformers import Trainer
```
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| `model` | `transformers.PreTrainedModel` | The loaded causal LM with `use_cache=False`. |
| `tokenizer` | `transformers.PreTrainedTokenizer` | The configured tokenizer (used for data collation and saving). |
| `args` | `TrainingArguments` | Extended training arguments including FSDP config, optimizer, batch size, learning rate, etc. |
| `train_dataset` | `SupervisedDataset` or `LazySupervisedDataset` | Training dataset returning dicts with `input_ids`, `labels`, `attention_mask`. |
| `eval_dataset` | `SupervisedDataset`, `LazySupervisedDataset`, or `None` | Optional evaluation dataset. |
Outputs
| Output | Description |
|---|---|
| Trained model parameters | Model weights updated via SFT. |
| Checkpoints | Periodic checkpoints saved to `output_dir/checkpoint-{step}`. |
| Training state | Optimizer state, scheduler state, and training metadata saved via `trainer.save_state()`. |
| Final model | Complete model saved in Hugging Face format via `trainer_save_model_safe()` or `trainer.save_model()`. |
Usage Examples
Launching Vicuna SFT training with FSDP:
```shell
torchrun --nproc_per_node=8 --nnodes=1 \
    -m fastchat.train.train \
    --model_name_or_path lmsys/vicuna-7b-v1.5 \
    --data_path data/sharegpt_clean.json \
    --output_dir output/vicuna-7b-sft \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --learning_rate 2e-5 \
    --model_max_length 2048 \
    --bf16 True \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --lazy_preprocess True \
    --save_strategy steps \
    --save_steps 500
```
Using the Flash Attention variant for reduced memory:
```shell
torchrun --nproc_per_node=8 \
    -m fastchat.train.train_mem \
    --model_name_or_path lmsys/vicuna-7b-v1.5 \
    --data_path data/sharegpt_clean.json \
    --output_dir output/vicuna-7b-sft-flash \
    --bf16 True \
    --fsdp "full_shard auto_wrap"
```
External References
- Hugging Face Trainer documentation: https://huggingface.co/docs/transformers/main_classes/trainer