Overview
Supervised fine-tuning pipeline for Baichuan causal language models on multi-turn conversation data.
Description
`fastchat.train.train_baichuan` implements an end-to-end supervised fine-tuning (SFT) pipeline tailored to Baichuan-series models. The module defines three argument dataclasses (`ModelArguments`, `DataArguments`, `TrainingArguments`) that are parsed from the command line via HuggingFace's `HfArgumentParser`.

Data preprocessing is handled by `preprocess(sources, tokenizer)`, which internally calls `apply_prompt_template(sources, systems)` to wrap each conversation turn in Baichuan-specific prompt formatting, `tokenize_conversations` to convert the formatted text into token IDs, and `mask_targets` to apply loss masking so the model only learns to predict assistant responses. Two dataset classes are provided: `SupervisedDataset`, which eagerly tokenizes the entire dataset into memory at initialization, and `LazySupervisedDataset`, which defers tokenization to `__getitem__` time for memory efficiency. For large datasets, the eager path uses Python's `multiprocessing.Pool` to parallelize preprocessing across CPU cores.

The `make_supervised_data_module(tokenizer, data_args, train_ratio=0.98)` factory function loads the JSON data, selects the appropriate dataset class based on the `lazy_preprocess` flag, splits it into train/eval sets according to `train_ratio`, and returns a dict containing the datasets and data collator. The `train()` entry point orchestrates the full workflow: it parses arguments, loads the Baichuan model and tokenizer from a pretrained checkpoint, invokes the data module factory, instantiates the HuggingFace `Trainer`, runs training, saves the final model state, and persists the tokenizer.
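The core of the preprocessing step is the loss masking applied by `mask_targets`. The sketch below illustrates the idea rather than the actual implementation: label positions belonging to non-assistant turns are overwritten with the conventional ignore index `-100`, which PyTorch's `CrossEntropyLoss` skips, so gradients flow only through assistant tokens. The `turn_spans` representation here is a simplification for illustration.

```python
# Simplified sketch of the loss-masking idea behind mask_targets.
# -100 is the standard HuggingFace/PyTorch ignore index for labels.
IGNORE_TOKEN_ID = -100

def mask_targets_sketch(input_ids, turn_spans):
    """turn_spans: list of (start, end, role) half-open index ranges
    marking which token positions belong to which speaker."""
    labels = list(input_ids)
    for start, end, role in turn_spans:
        if role != "assistant":
            # Mask user/system tokens so no loss is computed on them.
            for i in range(start, end):
                labels[i] = IGNORE_TOKEN_ID
    return labels

ids = [1, 2, 3, 4, 5, 6]
spans = [(0, 3, "user"), (3, 6, "assistant")]
print(mask_targets_sketch(ids, spans))  # [-100, -100, -100, 4, 5, 6]
```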
Usage
Use this when fine-tuning Baichuan or Baichuan-2 models on multi-turn conversation data in ShareGPT-style JSON format. This is the dedicated training script for the Baichuan model family and handles model-specific prompt templates and tokenization.
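The expected data file is a JSON list of conversation records in ShareGPT style. A minimal, hypothetical example of one record (field names follow the convention used by FastChat's bundled `dummy_conversation.json`):

```python
# Hypothetical minimal ShareGPT-style record; "conversations", "from",
# and "value" are the conventional field names in this format.
import json

record = {
    "id": "example-0",
    "conversations": [
        {"from": "human", "value": "What is the capital of France?"},
        {"from": "gpt", "value": "The capital of France is Paris."},
    ],
}

# The data file is a JSON array of such records.
print(json.dumps([record], indent=2))
```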
Code Reference
Source Location
Key Functions
| Function | Description |
|---|---|
| `train()` | Main entry point: parses arguments, loads model/tokenizer, preprocesses data, runs the HuggingFace `Trainer` |
| `preprocess(sources, tokenizer)` | Applies the prompt template, tokenizes conversations, and masks targets to produce training-ready input |
| `apply_prompt_template(sources, systems)` | Wraps each conversation turn with Baichuan-specific prompt formatting tokens |
| `tokenize_conversations` | Converts formatted conversation text into token IDs using the Baichuan tokenizer |
| `mask_targets` | Applies loss masking so the trainer only computes loss on assistant response tokens |
| `make_supervised_data_module(tokenizer, data_args, train_ratio=0.98)` | Factory that loads JSON data, constructs datasets (eager or lazy), splits train/eval, and returns the data dict |
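The train/eval split performed by `make_supervised_data_module` can be sketched as below. Whether the real function shuffles before splitting is an assumption of this sketch; the ratio logic (first `train_ratio` fraction for training, remainder for evaluation) mirrors the documented default.

```python
# Simplified sketch of a train/eval split with train_ratio=0.98.
# Shuffling before the cut is an assumption for illustration.
import random

def split_data(records, train_ratio=0.98, seed=0):
    records = list(records)
    random.Random(seed).shuffle(records)
    cut = int(len(records) * train_ratio)
    return records[:cut], records[cut:]

train_set, eval_set = split_data(range(100))
print(len(train_set), len(eval_set))  # 98 2
```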
Dataclasses
| Dataclass | Description |
|---|---|
| `ModelArguments` | Model checkpoint path and configuration options |
| `DataArguments` | Data path, lazy preprocessing flag, and train/eval split ratio |
| `TrainingArguments` | HuggingFace `TrainingArguments` extended with additional custom fields |
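`HfArgumentParser` turns each dataclass field into a CLI flag automatically. A stdlib-only sketch of that pattern (the `parse` helper here is hypothetical, not part of the module, and handles only the string/boolean cases needed for illustration):

```python
# Stdlib sketch of the dataclass-to-CLI-flag pattern that
# HfArgumentParser automates for ModelArguments/DataArguments/etc.
import argparse
from dataclasses import dataclass, fields

@dataclass
class DataArguments:
    data_path: str = None
    lazy_preprocess: bool = False

def parse(dc, argv):
    parser = argparse.ArgumentParser()
    for f in fields(dc):
        if f.type is bool:
            # Boolean fields become store_true flags.
            parser.add_argument(f"--{f.name}", action="store_true")
        else:
            parser.add_argument(f"--{f.name}", default=f.default)
    return dc(**vars(parser.parse_args(argv)))

args = parse(DataArguments, ["--data_path", "x.json", "--lazy_preprocess"])
print(args.data_path, args.lazy_preprocess)  # x.json True
```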
Classes
| Class | Description |
|---|---|
| `SupervisedDataset` | Eagerly tokenizes all examples at init time; uses a multiprocessing `Pool` for large datasets |
| `LazySupervisedDataset` | Defers tokenization to `__getitem__` for a reduced memory footprint on large datasets |
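The eager/lazy trade-off can be illustrated with a minimal stand-in for `LazySupervisedDataset`. This is a sketch, not the actual class: tokenization happens on first access and the result is cached, so start-up is cheap and memory grows only with the examples actually touched.

```python
# Minimal sketch of lazy, cached per-item tokenization.
class LazyDatasetSketch:
    """Tokenizes one example at a time inside __getitem__."""

    def __init__(self, raw_examples, tokenize_fn):
        self.raw = raw_examples
        self.tokenize = tokenize_fn
        self.cache = {}  # index -> tokenized example

    def __len__(self):
        return len(self.raw)

    def __getitem__(self, i):
        if i not in self.cache:
            # Tokenize on first access only.
            self.cache[i] = self.tokenize(self.raw[i])
        return self.cache[i]

ds = LazyDatasetSketch(["hello world", "foo bar"], str.split)
print(ds[0])  # ['hello', 'world']
```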
Signature
Import
from fastchat.train.train_baichuan import train
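Because `train()` reads its options through `HfArgumentParser`, which falls back to `sys.argv`, the script can in principle be driven programmatically by setting `sys.argv` before the call. Illustrative only; the actual call loads the full model, so it is left commented out here.

```python
# Hypothetical programmatic launch: inject CLI flags via sys.argv.
import sys

sys.argv = [
    "train_baichuan",
    "--model_name_or_path", "baichuan-inc/Baichuan-13B-Chat",
    "--data_path", "data/dummy_conversation.json",
    "--output_dir", "./output_baichuan",
]
# from fastchat.train.train_baichuan import train
# train()  # heavyweight: loads the full model and starts training
```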
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| `--model_name_or_path` | str | Yes | HuggingFace model ID or local checkpoint path for a Baichuan model |
| `--data_path` | str | Yes | Path to training JSON data in ShareGPT conversation format |
| `--output_dir` | str | Yes | Directory for saving model checkpoints and final weights |
| `--lazy_preprocess` | bool | No | If set, use `LazySupervisedDataset` instead of eager preprocessing (default: False) |
| `--train_ratio` | float | No | Fraction of data used for training vs. evaluation (default: 0.98) |
| `--num_train_epochs` | int | No | Number of training epochs (default from `TrainingArguments`) |
| `--per_device_train_batch_size` | int | No | Batch size per GPU device during training |
| `--learning_rate` | float | No | Peak learning rate for the optimizer |
Outputs
| Name | Type | Description |
|---|---|---|
| checkpoints | Files | Model checkpoints saved in `output_dir` at configured intervals |
| final_model | Files | Final model weights and tokenizer saved at the end of training |
| trainer_state | JSON | Training state including loss history, learning rate schedule, and metrics |
Usage Examples
```shell
# Fine-tune Baichuan-13B-Chat on conversation data with 4 GPUs
torchrun --nproc_per_node=4 -m fastchat.train.train_baichuan \
    --model_name_or_path baichuan-inc/Baichuan-13B-Chat \
    --data_path data/dummy_conversation.json \
    --output_dir ./output_baichuan \
    --num_train_epochs 3 \
    --per_device_train_batch_size 2 \
    --learning_rate 2e-5 \
    --bf16 True \
    --lazy_preprocess True
```