Implementation: LMSYS FastChat HF Trainer.train() with DeepSpeed
| Knowledge Sources | |
|---|---|
| Domains | NLP, Training, Distributed Systems |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Wrapper around the HuggingFace Trainer.train() method as used in FastChat's train_lora.py, with DeepSpeed ZeRO integration for distributed LoRA training across multiple GPUs.
Description
The training execution in fastchat/train/train_lora.py uses the standard HuggingFace Trainer class, which automatically initializes a DeepSpeed engine when the --deepspeed argument points to a configuration JSON file. The script constructs the Trainer with the LoRA-wrapped model, tokenizer, training arguments, and the supervised data module. Before training, it disables the model's KV cache (incompatible with training). The script also checks for existing checkpoint directories and resumes training if checkpoints are found. After training completes, trainer.save_state() persists the final trainer state.
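The checkpoint-resume rule described above can be sketched in isolation. Here `should_resume` is a hypothetical helper name introduced for illustration; only the `checkpoint-*` glob mirrors the script's actual behavior:

```python
# Sketch of the checkpoint-detection rule: resume training only when
# output_dir already contains a checkpoint-* directory.
# should_resume is a hypothetical helper, not part of FastChat.
import pathlib
import tempfile

def should_resume(output_dir: str) -> bool:
    # Same glob the script applies before calling trainer.train(...)
    return bool(list(pathlib.Path(output_dir).glob("checkpoint-*")))

with tempfile.TemporaryDirectory() as d:
    fresh_run = should_resume(d)                  # no checkpoints yet
    (pathlib.Path(d) / "checkpoint-100").mkdir()  # simulate a saved checkpoint
    resumed_run = should_resume(d)
print(fresh_run, resumed_run)
```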
FastChat provides two DeepSpeed configurations:
- playground/deepspeed_config_s2.json -- ZeRO Stage 2 with optimizer CPU offloading, contiguous gradients, and communication overlap.
- playground/deepspeed_config_s3.json -- ZeRO Stage 3 with full parameter, optimizer, and gradient partitioning plus CPU offloading.
The TrainingArguments class in train_lora.py extends transformers.TrainingArguments with an additional flash_attn: bool field. When flash_attn=True, the script calls replace_llama_attn_with_flash_attn() before model loading to monkey-patch the attention implementation.
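The ordering constraint matters: the monkey patch must run before the model is loaded so the patched attention implementation is the one instantiated. A minimal stand-alone sketch of that gating (both functions here are stand-ins, not FastChat's real implementations; only the call ordering mirrors the script):

```python
# Stand-alone sketch of the flash_attn gating order described above.
# Both functions are stand-ins; only the ordering mirrors train_lora.py.
from dataclasses import dataclass

call_order = []

def replace_llama_attn_with_flash_attn():
    call_order.append("patch_attention")  # stand-in for the monkey patch

def load_model():
    call_order.append("load_model")  # stand-in for from_pretrained()

@dataclass
class TrainingArguments:
    flash_attn: bool = False  # the extra field train_lora.py adds

def train(args: TrainingArguments):
    if args.flash_attn:
        # Patch before loading so the model picks up the new attention
        replace_llama_attn_with_flash_attn()
    load_model()

train(TrainingArguments(flash_attn=True))
print(call_order)
```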
Usage
Use this pattern when running distributed LoRA fine-tuning with DeepSpeed via FastChat's training script.
Code Reference
Source Location
- Repository: FastChat
- File: fastchat/train/train_lora.py (lines 188-199, Trainer construction and training)
- File: fastchat/train/train_lora.py (lines 42-52, TrainingArguments dataclass)
- Config: playground/deepspeed_config_s2.json (ZeRO Stage 2)
- Config: playground/deepspeed_config_s3.json (ZeRO Stage 3)
Signature
# TrainingArguments dataclass (lines 42-52)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: typing.Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    flash_attn: bool = False
# Training execution (lines 188-199)
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
trainer = Trainer(
    model=model, tokenizer=tokenizer, args=training_args, **data_module
)
model.config.use_cache = False
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
    trainer.train(resume_from_checkpoint=True)
else:
    trainer.train()
trainer.save_state()
Import
import pathlib
import typing
from dataclasses import dataclass, field

import transformers
from transformers import Trainer

from fastchat.train.train import (
    DataArguments,
    ModelArguments,
    make_supervised_data_module,
)
from fastchat.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | PeftModel | Yes | LoRA-wrapped causal language model with trainable adapter parameters |
| tokenizer | PreTrainedTokenizer | Yes | Tokenizer with pad_token set to unk_token and padding_side="right" |
| training_args | TrainingArguments | Yes | Extended training arguments including flash_attn, the DeepSpeed config path, and standard HF training params |
| data_module | dict | Yes | Dictionary with keys train_dataset, eval_dataset, and data_collator from make_supervised_data_module() |
| training_args.deepspeed | str or None | No | Path to DeepSpeed JSON config file (e.g., playground/deepspeed_config_s2.json) |
| training_args.output_dir | str | Yes | Directory for saving checkpoints and the final model |
| training_args.flash_attn | bool | No | Enable the FlashAttention monkey patch for LLaMA models; default: False |
Outputs
| Name | Type | Description |
|---|---|---|
| Trained model | PeftModel | Model with updated LoRA adapter weights after training |
| Checkpoints | directory | checkpoint-*/ directories in output_dir containing model and optimizer states |
| Trainer state | file | trainer_state.json in output_dir with training metrics and history |
Usage Examples
Basic LoRA Training with DeepSpeed ZeRO-2
deepspeed fastchat/train/train_lora.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--data_path data/dummy_conversation.json \
--bf16 True \
--output_dir output_lora \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 8 \
--learning_rate 2e-5 \
--model_max_length 2048 \
--deepspeed playground/deepspeed_config_s2.json
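For sizing runs, the effective global batch size of the command above is per-device batch size × gradient accumulation steps × world size. A quick check (the 8-GPU world size is an assumption, not part of the command; deepspeed uses all visible GPUs by default):

```python
# Effective global batch size for the ZeRO-2 example above.
per_device_train_batch_size = 4
gradient_accumulation_steps = 8
num_gpus = 8  # assumption: one 8-GPU node

global_batch_size = (per_device_train_batch_size
                     * gradient_accumulation_steps
                     * num_gpus)
print(global_batch_size)  # 256
```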
QLoRA Training with DeepSpeed ZeRO-2
deepspeed fastchat/train/train_lora.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--data_path data/dummy_conversation.json \
--bf16 True \
--output_dir output_qlora \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 8 \
--learning_rate 2e-5 \
--model_max_length 512 \
--q_lora True \
--deepspeed playground/deepspeed_config_s2.json
LoRA Training with ZeRO-3 and FlashAttention
deepspeed fastchat/train/train_lora.py \
--model_name_or_path meta-llama/Llama-2-13b-hf \
--data_path data/dummy_conversation.json \
--bf16 True \
--output_dir output_lora_z3 \
--num_train_epochs 3 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 16 \
--learning_rate 2e-5 \
--model_max_length 2048 \
--flash_attn True \
--deepspeed playground/deepspeed_config_s3.json
DeepSpeed ZeRO-2 Configuration (deepspeed_config_s2.json)
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
},
"fp16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto"
}
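The "auto" entries in the config above are placeholders that the HuggingFace Trainer fills in from TrainingArguments before handing the config to DeepSpeed. A simplified stand-in for that substitution (illustrative only, not the Trainer's actual resolution code):

```python
# Simplified sketch of the "auto" placeholder convention: HF Trainer
# replaces "auto" entries with the matching TrainingArguments values.
# This dict-based substitution is illustrative, not the Trainer's real logic.
import json

ds_config = json.loads("""{
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto"
}""")

# Values that would come from the command line / TrainingArguments
from_trainer = {
    "train_micro_batch_size_per_gpu": 4,  # per_device_train_batch_size
    "gradient_accumulation_steps": 8,
}

resolved = {k: (from_trainer[k] if v == "auto" and k in from_trainer else v)
            for k, v in ds_config.items()}
print(resolved)
```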
DeepSpeed ZeRO-3 Configuration (deepspeed_config_s3.json)
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6,
"sub_group_size": 1e12,
"stage3_gather_16bit_weights_on_model_save": true
},
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto"
}