Implementation:Lm_sys_FastChat_Train_LoRA_T5
| Knowledge Sources | |
|---|---|
| Domains | Training, NLP |
| Last Updated | 2026-02-07 06:00 GMT |
Overview
LoRA and QLoRA fine-tuning pipeline for Flan-T5 seq2seq models with DeepSpeed ZeRO-3 support.
Description
Train LoRA T5 implements parameter-efficient fine-tuning of Flan-T5 models using Low-Rank Adaptation (LoRA) and optionally quantized LoRA (QLoRA) with 4-bit NormalFloat quantization. This module builds on top of train_flant5 for data preprocessing and train_lora for DeepSpeed-compatible state dict extraction, combining them into a T5-specific LoRA training script.
The LoraArguments dataclass defines the LoRA configuration: lora_r=8 (rank of the low-rank decomposition), lora_alpha=16 (scaling factor), lora_dropout=0.05 (dropout probability applied to LoRA layers), lora_target_modules=["q", "v"] (which attention projection matrices receive LoRA adapters), and a q_lora boolean flag that enables 4-bit quantization via BitsAndBytesConfig when set to True.
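The fields described above can be sketched as a dataclass (a minimal reconstruction from the description; the exact layout in train_lora_t5.py may differ):

```python
from dataclasses import dataclass, field
from typing import List

# Minimal sketch of the LoraArguments dataclass described above;
# defaults mirror the text, exact field order is illustrative.
@dataclass
class LoraArguments:
    lora_r: int = 8                  # rank of the low-rank decomposition
    lora_alpha: int = 16             # scaling factor (effective scale = alpha / r)
    lora_dropout: float = 0.05       # dropout probability on LoRA layers
    lora_target_modules: List[str] = field(default_factory=lambda: ["q", "v"])
    q_lora: bool = False             # enable 4-bit NF4 quantization (QLoRA)

args = LoraArguments()
print(args.lora_alpha / args.lora_r)  # → 2.0, the effective LoRA scale
```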
When q_lora is enabled, the model is loaded with 4-bit NormalFloat (nf4) quantization and double quantization for memory efficiency, and prepare_model_for_kbit_training from the peft library is called to freeze the base model and prepare it for quantized training. The LoraConfig from peft is then applied via get_peft_model to inject trainable low-rank adapters into the specified target modules.
For model saving, the script handles DeepSpeed ZeRO-3 state dict consolidation through get_peft_state_maybe_zero_3 (imported from train_lora), which correctly gathers sharded parameters across processes before saving. When not using ZeRO-3, a standard model.save_pretrained call is used instead. The script also saves a non-LoRA state dict for compatibility, extracting only the base model weights when needed.
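Stripped of the ZeRO-3 gathering step (which requires a live DeepSpeed engine), the selection logic amounts to filtering parameter names; a simplified sketch, assuming the bias="none" and bias="all" cases only:

```python
# Simplified sketch of get_peft_state_maybe_zero_3's selection logic.
# The real function also gathers parameters sharded by DeepSpeed ZeRO-3
# across processes before returning them; that step is omitted here.
def select_lora_state(named_params, bias="none"):
    if bias == "none":
        return {k: v for k, v in named_params if "lora_" in k}
    if bias == "all":
        return {k: v for k, v in named_params if "lora_" in k or "bias" in k}
    raise NotImplementedError(bias)

# Dummy parameter names shaped like a T5 attention block with LoRA adapters.
params = [
    ("encoder.block.0.layer.0.SelfAttention.q.lora_A.weight", "A"),
    ("encoder.block.0.layer.0.SelfAttention.q.lora_B.weight", "B"),
    ("encoder.block.0.layer.0.SelfAttention.q.base_layer.weight", "W"),
]
print(sorted(select_lora_state(params)))  # only the two lora_* entries survive
```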
The train() function parses ModelArguments, DataArguments, TrainingArguments, and LoraArguments, loads the Flan-T5 model (optionally quantized), applies LoRA adapters, calls make_supervised_data_module (from train_flant5) for data preparation, and launches the HuggingFace Trainer with DeepSpeed integration.
Usage
Use this when you need parameter-efficient fine-tuning of Flan-T5 models with limited GPU memory. LoRA reduces the number of trainable parameters dramatically while maintaining model quality. Enable q_lora for further memory savings through 4-bit quantization, allowing fine-tuning of larger T5 variants on consumer-grade GPUs.
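To make "dramatically" concrete, a back-of-the-envelope count for the default configuration (r=8, adapters on q and v only). The model dimensions are an assumption for illustration, roughly matching Flan-T5-XL (d_model=2048, 24 encoder and 24 decoder layers), not values read from a checkpoint:

```python
# Back-of-the-envelope LoRA parameter count under assumed Flan-T5-XL-like
# dimensions (d_model=2048, 24 encoder + 24 decoder layers).
d_model, r = 2048, 8
enc_layers = dec_layers = 24

# Each adapted matrix gets lora_A (r x d_model) plus lora_B (d_model x r).
per_matrix = r * (d_model + d_model)
# q and v in every attention block; decoder layers have self- and cross-attention.
n_matrices = enc_layers * 2 + dec_layers * 2 * 2
trainable = n_matrices * per_matrix
print(f"{trainable:,} trainable LoRA parameters")  # ~4.7M vs ~2.85B in the base model
```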
Code Reference
Source Location
- Repository: Lm_sys_FastChat
- File: fastchat/train/train_lora_t5.py
- Lines: 1-226
Key Functions
| Function | Description |
|---|---|
| train() | Main entry point: loads T5 with optional 4-bit quantization, applies LoRA adapters, trains with DeepSpeed |
Imported Functions
| Function | Source Module | Description |
|---|---|---|
| smart_tokenizer_and_embedding_resize | train_flant5 | Safely adds special tokens and resizes model embeddings |
| make_supervised_data_module | train_flant5 | Builds supervised dataset and data collator for T5 seq2seq training |
| get_peft_state_maybe_zero_3 | train_lora | Extracts LoRA state dict with DeepSpeed ZeRO-3 parameter gathering |
LoRA Configuration (LoraArguments)
| Parameter | Default | Description |
|---|---|---|
| lora_r | 8 | Rank of the low-rank decomposition matrices |
| lora_alpha | 16 | Scaling factor for LoRA updates (effective scale = lora_alpha / lora_r) |
| lora_dropout | 0.05 | Dropout probability applied to LoRA layer outputs |
| lora_target_modules | ["q", "v"] | Attention projection matrices that receive LoRA adapters |
| q_lora | False | Enable 4-bit NormalFloat quantization (QLoRA) for reduced memory usage |
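As a toy illustration of the scaling in the table, the adapter's contribution to a forward pass is (lora_alpha / lora_r) * B(Ax). A pure-Python sketch with a made-up rank-1 adapter on a 3-dimensional input:

```python
# Toy LoRA contribution: delta = (alpha / r) * B @ A @ x.
# r=1 and d=3 keep the arithmetic readable; real adapters use r=8 on
# much larger projection matrices.
alpha, r, d = 16, 1, 3
A = [[1.0, 0.0, 1.0]]            # (r x d); random-initialized in practice
B = [[0.5], [0.0], [-0.5]]       # (d x r); zero-initialized in practice, nonzero here to show effect
x = [1.0, 2.0, 3.0]

Ax = [sum(A[i][j] * x[j] for j in range(d)) for i in range(r)]    # project down to rank r
BAx = [sum(B[i][k] * Ax[k] for k in range(r)) for i in range(d)]  # project back up to d
delta = [(alpha / r) * v for v in BAx]  # scaled low-rank contribution added to Wx
print(delta)  # → [32.0, 0.0, -32.0]
```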
PEFT Library Integration
```python
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# LoRA configuration
lora_config = LoraConfig(
    r=lora_args.lora_r,
    lora_alpha=lora_args.lora_alpha,
    target_modules=lora_args.lora_target_modules,
    lora_dropout=lora_args.lora_dropout,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)

# Apply LoRA adapters to model
model = get_peft_model(model, lora_config)
```
Signature
```python
def train():
    ...
```
Import
```python
from fastchat.train.train_lora_t5 import train
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| --model_name_or_path | str | Yes | HuggingFace model path for a Flan-T5 checkpoint (e.g., google/flan-t5-xl) |
| --data_path | str | Yes | Path to training JSON data in ShareGPT conversation format |
| --output_dir | str | Yes | Directory for saving LoRA adapter weights and checkpoints |
| --lora_r | int | No | LoRA rank (default: 8) |
| --lora_alpha | int | No | LoRA scaling factor (default: 16) |
| --lora_dropout | float | No | LoRA dropout rate (default: 0.05) |
| --lora_target_modules | list[str] | No | Target attention modules for LoRA (default: ["q", "v"]) |
| --q_lora | bool | No | Enable 4-bit QLoRA quantization (default: False) |
| --num_train_epochs | int | No | Number of training epochs |
| --per_device_train_batch_size | int | No | Batch size per GPU device during training |
| --learning_rate | float | No | Peak learning rate for the optimizer |
Outputs
| Name | Type | Description |
|---|---|---|
| lora_adapters | Files | LoRA adapter weights saved in output_dir (adapter_model.bin, adapter_config.json) |
| non_lora_state | Files | Base model state dict without LoRA parameters for compatibility |
| checkpoints | Files | Training checkpoints at configured intervals |
| trainer_state | JSON | Training state including loss curves, learning rate schedule, and metrics |
Usage Examples
```shell
# LoRA fine-tune Flan-T5-XL with DeepSpeed ZeRO-3 on 4 GPUs
torchrun --nproc_per_node=4 -m fastchat.train.train_lora_t5 \
    --model_name_or_path google/flan-t5-xl \
    --data_path data/dummy_conversation.json \
    --output_dir ./output_lora_t5 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --learning_rate 2e-4 \
    --bf16 True \
    --lora_r 8 \
    --lora_alpha 16 \
    --deepspeed deepspeed_config.json
```
```shell
# QLoRA fine-tune Flan-T5-XXL with 4-bit quantization for reduced memory
torchrun --nproc_per_node=4 -m fastchat.train.train_lora_t5 \
    --model_name_or_path google/flan-t5-xxl \
    --data_path data/dummy_conversation.json \
    --output_dir ./output_qlora_t5 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 2 \
    --learning_rate 2e-4 \
    --bf16 True \
    --q_lora True
```
Related Pages
- Implements: Principle:Lm_sys_FastChat_DeepSpeed_LoRA_Training
- Environment:Lm_sys_FastChat_SFT_Training_Environment
- Heuristic:Lm_sys_FastChat_Vicuna_SFT_Training_Hyperparameters
- Implementation:Lm_sys_FastChat_Train_FlanT5
- Implementation:Lm_sys_FastChat_Peft_Get_Peft_Model
- Implementation:Lm_sys_FastChat_Get_Peft_State_Maybe_Zero_3