Implementation: lm-sys/FastChat Make Supervised Data Module
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Title | Make Supervised Data Module |
| Repository | lm-sys/FastChat |
| Workflow | Vicuna SFT Finetuning |
| Domains | Supervised Fine-Tuning, Data Loading, Dataset Construction |
| Knowledge Sources | fastchat/train/train.py |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
This implementation documents the make_supervised_data_module function, which is the primary entry point for constructing training and evaluation datasets in the Vicuna SFT fine-tuning pipeline. The function reads raw JSON conversation data, selects the appropriate dataset class (eager or lazy), and returns a dictionary suitable for passing directly to the Hugging Face Trainer.
Description
The `make_supervised_data_module` function serves as a factory that:
- Determines whether to use eager (`SupervisedDataset`) or lazy (`LazySupervisedDataset`) preprocessing based on the `data_args.lazy_preprocess` flag.
- Loads the training data from the JSON file specified by `data_args.data_path` using `json.load()`.
- Constructs the training dataset by passing the raw JSON data and tokenizer to the selected dataset class.
- Optionally loads and constructs an evaluation dataset from `data_args.eval_data_path`, if provided.
- Returns a dictionary with `"train_dataset"` and `"eval_dataset"` keys, which can be unpacked directly into the `Trainer` constructor.
The function internally calls `json.load()` to parse the entire JSON file into memory. For very large datasets, lazy preprocessing defers tokenization to access time, but the raw JSON list is still fully loaded into memory either way.
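The selection-and-load flow described above can be sketched as follows. The two stub dataset classes are stand-ins for FastChat's real `SupervisedDataset` and `LazySupervisedDataset` (their tokenization logic is omitted here); only the factory's branching mirrors the documented behavior:

```python
import json

class SupervisedDataset:
    """Stand-in: the real class tokenizes all conversations eagerly in __init__."""
    def __init__(self, raw_data, tokenizer):
        self.raw_data = raw_data

class LazySupervisedDataset:
    """Stand-in: the real class defers tokenization to __getitem__."""
    def __init__(self, raw_data, tokenizer):
        self.raw_data = raw_data

def make_supervised_data_module_sketch(tokenizer, data_args):
    """Sketch of the factory logic: pick a dataset class, load JSON, build datasets."""
    dataset_cls = (LazySupervisedDataset
                   if data_args.lazy_preprocess
                   else SupervisedDataset)
    # The whole training file is parsed into memory, even in lazy mode.
    train_json = json.load(open(data_args.data_path))
    train_dataset = dataset_cls(train_json, tokenizer)
    if data_args.eval_data_path:
        eval_json = json.load(open(data_args.eval_data_path))
        eval_dataset = dataset_cls(eval_json, tokenizer)
    else:
        eval_dataset = None
    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
```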
Usage
Code Reference
Source Location
fastchat/train/train.py:L235-253
Signature
```python
def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
```
Import
```python
from fastchat.train.train import make_supervised_data_module
```
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `tokenizer` | `transformers.PreTrainedTokenizer` | Yes | A configured tokenizer instance with `model_max_length`, `pad_token`, and `padding_side` set. Used by the dataset classes to tokenize conversations. |
| `data_args` | `DataArguments` | Yes | A dataclass instance containing `data_path` (str, path to training JSON), `eval_data_path` (str or None, path to evaluation JSON), and `lazy_preprocess` (bool, whether to use lazy loading). |
Outputs
| Key | Type | Description |
|---|---|---|
| `"train_dataset"` | `SupervisedDataset` or `LazySupervisedDataset` | The constructed training dataset. Each item yields a dict with `input_ids`, `labels`, and `attention_mask` tensors. |
| `"eval_dataset"` | `SupervisedDataset`, `LazySupervisedDataset`, or `None` | The evaluation dataset, or `None` if `eval_data_path` was not provided. |
Usage Examples
Basic usage within the training pipeline:
```python
import transformers
from typing import Optional
from fastchat.train.train import make_supervised_data_module

# Assume tokenizer and data_args are already configured
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "lmsys/vicuna-7b-v1.5",
    model_max_length=2048,
    padding_side="right",
    use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token

# Create a DataArguments instance
from dataclasses import dataclass

@dataclass
class DataArguments:
    data_path: str = "data/sharegpt_clean.json"
    eval_data_path: Optional[str] = None
    lazy_preprocess: bool = False

data_args = DataArguments(
    data_path="data/sharegpt_clean.json",
    eval_data_path="data/sharegpt_eval.json",
    lazy_preprocess=True,
)

data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

# data_module can be unpacked into Trainer
trainer = transformers.Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    **data_module,  # train_dataset and eval_dataset
)
```
Inspecting the returned datasets:
```python
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
train_dataset = data_module["train_dataset"]

print(f"Training samples: {len(train_dataset)}")

sample = train_dataset[0]
print(f"input_ids shape: {sample['input_ids'].shape}")
print(f"labels shape: {sample['labels'].shape}")
print(f"attention_mask shape: {sample['attention_mask'].shape}")
```
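In the `labels` tensor yielded by each sample, positions that should not contribute to the loss (the prompt and non-assistant turns) are masked with `-100`, the standard Hugging Face ignore index. A quick way to sanity-check the masking is to compute the fraction of trainable positions; this hypothetical helper uses a plain list standing in for the tensor:

```python
IGNORE_TOKEN_ID = -100  # Hugging Face's default ignore index for the cross-entropy loss

def trainable_fraction(labels):
    """Fraction of label positions that contribute to the loss (labels != -100)."""
    labels = list(labels)
    if not labels:
        return 0.0
    return sum(tok != IGNORE_TOKEN_ID for tok in labels) / len(labels)

# Example: two masked prompt tokens followed by two assistant tokens
print(trainable_fraction([-100, -100, 5, 6]))  # → 0.5
```

A very low fraction across the dataset usually means the conversation template in the preprocessing step did not match the data, so nearly everything was masked.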