Overview
Supervised fine-tuning pipeline for Flan-T5 sequence-to-sequence models on multi-turn conversation data.
Description
Train FlanT5 implements a supervised fine-tuning pipeline specifically designed for encoder-decoder (seq2seq) Flan-T5 models. Unlike causal LM training scripts that produce a single concatenated sequence, this module reformats multi-turn conversations into explicit question-answer pairs suitable for the T5 encoder-decoder architecture. The _form_qa(sources) helper extracts alternating human/assistant turns and pairs them, while _add_speaker_and_signal(header, source) prepends speaker role tokens and signal markers to each turn. The preprocess(sources, tokenizer) function orchestrates the full preprocessing pipeline: it calls _form_qa to produce Q&A pairs, then uses _tokenize_fn(strings, tokenizer) to tokenize the input (question) and target (answer) sequences separately for the encoder and decoder.

The smart_tokenizer_and_embedding_resize(special_tokens_dict, other_tokens, tokenizer, model) utility safely adds special tokens and other tokens to the tokenizer while resizing the model embedding layer to match, initializing new token embeddings to the mean of the existing embeddings so training starts from a reasonable point.

The SupervisedDataset class loads JSON conversation data, runs preprocessing, and stores tokenized examples. The DataCollatorForSupervisedDataset handles dynamic padding of variable-length sequences within a batch, replacing padding positions in labels with IGNORE_INDEX (-100) so they do not contribute to the loss. The make_supervised_data_module(tokenizer, data_args) factory constructs the dataset and collator. The train() entry point parses ModelArguments, DataArguments, and TrainingArguments, loads the Flan-T5 model and tokenizer, optionally resizes embeddings, creates the data module, and launches the HuggingFace Trainer.
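The pairing step is the key difference from the causal-LM training scripts. The sketch below illustrates the idea in isolation; it is not the module's exact _form_qa implementation, and the pair_turns helper and the "from"/"value" field names are assumptions based on the ShareGPT-style input format.

```python
# Illustrative sketch of the Q&A pairing idea behind _form_qa (not the exact
# FastChat implementation; field names "from"/"value" follow the assumed
# ShareGPT-style input format).
from typing import Dict, List, Tuple


def pair_turns(conversation: List[Dict[str, str]]) -> List[Tuple[str, str]]:
    """Pair alternating human/assistant turns into (question, answer) tuples."""
    pairs = []
    for i in range(0, len(conversation) - 1, 2):
        first, second = conversation[i], conversation[i + 1]
        if first["from"] == "human" and second["from"] == "gpt":
            # The question feeds the T5 encoder; the answer is the decoder target.
            pairs.append((first["value"], second["value"]))
    return pairs


print(pair_turns([
    {"from": "human", "value": "What is the capital of France?"},
    {"from": "gpt", "value": "The capital of France is Paris."},
]))
# -> [('What is the capital of France?', 'The capital of France is Paris.')]
```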
Usage
Use this when fine-tuning Flan-T5 models (e.g., google/flan-t5-xl, google/flan-t5-xxl) on multi-turn conversation data. This script handles the seq2seq-specific formatting where human turns become encoder inputs and assistant turns become decoder targets.
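For reference, a minimal ShareGPT-style record looks roughly like the following, expressed here as a Python literal. The exact fields in your data may differ, so treat the keys below as an assumption rather than a strict schema.

```python
# Assumed ShareGPT-style conversation record: a list of such records, each
# holding an id and a list of alternating human/assistant turns.
example_record = {
    "id": "identity_0",
    "conversations": [
        {"from": "human", "value": "Who are you?"},
        {"from": "gpt", "value": "I am a language model fine-tuned from Flan-T5."},
        {"from": "human", "value": "What can you do?"},
        {"from": "gpt", "value": "I can answer questions and follow instructions."},
    ],
}
```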
Code Reference
Source Location
fastchat/train/train_flant5.py
Key Functions
| Function | Description |
|---|---|
| train() | Main entry point: loads the Flan-T5 model/tokenizer, preprocesses Q&A pairs, and trains with the HuggingFace Trainer |
| smart_tokenizer_and_embedding_resize(special_tokens_dict, other_tokens, tokenizer, model) | Safely adds special tokens and resizes model embeddings, initializing new embeddings to the mean of existing ones (sketched below) |
| preprocess(sources, tokenizer) | Converts multi-turn conversations into tokenized encoder-input / decoder-target pairs |
| _tokenize_fn(strings, tokenizer) | Tokenizes a list of strings and returns input_ids, labels, and attention masks |
| _form_qa(sources) | Extracts alternating human/assistant turns from conversations and pairs them as question-answer tuples |
| _add_speaker_and_signal(header, source) | Prepends speaker role identifiers and signal markers to each conversation turn |
| make_supervised_data_module(tokenizer, data_args) | Factory function that builds the SupervisedDataset and DataCollatorForSupervisedDataset |
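The embedding-resize step matters because newly added tokens would otherwise start from randomly initialized embedding rows. The following is a minimal sketch of the mean-initialization idea using standard HuggingFace APIs; it mirrors what smart_tokenizer_and_embedding_resize is described as doing, but is not a copy of the FastChat source.

```python
# Minimal sketch of mean-initialized embedding resize (assumed to mirror the
# intent of smart_tokenizer_and_embedding_resize; not the FastChat source).
def resize_with_mean_init(special_tokens_dict, other_tokens, tokenizer, model):
    num_new = tokenizer.add_special_tokens(special_tokens_dict)
    num_new += tokenizer.add_tokens(other_tokens)
    model.resize_token_embeddings(len(tokenizer))

    if num_new > 0:
        input_emb = model.get_input_embeddings().weight.data
        output_emb = model.get_output_embeddings().weight.data
        # Initialize the new rows to the mean of the pre-existing embeddings
        # so training starts from a reasonable point rather than random noise.
        input_emb[-num_new:] = input_emb[:-num_new].mean(dim=0, keepdim=True)
        output_emb[-num_new:] = output_emb[:-num_new].mean(dim=0, keepdim=True)
```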
Classes
| Class | Description |
|---|---|
| SupervisedDataset | Loads and tokenizes JSON conversation data into encoder-input and decoder-target pairs at initialization |
| DataCollatorForSupervisedDataset | Dynamically pads variable-length sequences in a batch, masking padding positions in labels with IGNORE_INDEX (-100); see the sketch below |
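To make the label masking concrete, here is a minimal sketch of the padding-and-masking step a seq2seq collator performs. It is an illustration of the IGNORE_INDEX convention, not the module's exact DataCollatorForSupervisedDataset.

```python
# Illustrative sketch of dynamic padding with IGNORE_INDEX label masking
# (simplified relative to DataCollatorForSupervisedDataset).
from torch.nn.utils.rnn import pad_sequence

IGNORE_INDEX = -100  # positions with this label are skipped by the loss


def collate(batch, pad_token_id):
    # batch: list of dicts, each holding 1-D LongTensors "input_ids" and "labels"
    input_ids = pad_sequence(
        [ex["input_ids"] for ex in batch], batch_first=True, padding_value=pad_token_id
    )
    labels = pad_sequence(
        [ex["labels"] for ex in batch], batch_first=True, padding_value=IGNORE_INDEX
    )
    return {
        "input_ids": input_ids,
        "labels": labels,
        # Mask out padding on the encoder side as well.
        "attention_mask": input_ids.ne(pad_token_id),
    }
```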
Signature
def train():
...
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
other_tokens: List,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
...
Import
from fastchat.train.train_flant5 import train, smart_tokenizer_and_embedding_resize
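A typical standalone use of the helper might look like the following minimal sketch. The extra marker tokens are purely illustrative assumptions, not tokens the FastChat script itself adds.

```python
import transformers

from fastchat.train.train_flant5 import smart_tokenizer_and_embedding_resize

# Hypothetical usage: load a Flan-T5 checkpoint and register extra tokens.
tokenizer = transformers.AutoTokenizer.from_pretrained("google/flan-t5-xl")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl")

smart_tokenizer_and_embedding_resize(
    special_tokens_dict={},                 # Flan-T5 tokenizers already define pad/eos/unk
    other_tokens=["<human>", "<bot>"],      # assumed role markers, for illustration only
    tokenizer=tokenizer,
    model=model,
)
```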
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| --model_name_or_path | str | Yes | HuggingFace model path for a Flan-T5 checkpoint (e.g., google/flan-t5-xl) |
| --data_path | str | Yes | Path to training JSON data in ShareGPT conversation format |
| --output_dir | str | Yes | Directory for saving model checkpoints and final weights |
| --num_train_epochs | int | No | Number of training epochs (default from TrainingArguments) |
| --per_device_train_batch_size | int | No | Batch size per GPU device during training |
| --learning_rate | float | No | Peak learning rate for the optimizer |
| --model_max_length | int | No | Maximum sequence length for tokenization; longer sequences are truncated (see the sketch after this table) |
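The --model_max_length flag bounds the tokenized sequence length. Per-string tokenization with truncation typically looks like the following generic HuggingFace-tokenizer sketch of what _tokenize_fn does; it is not the module's exact code.

```python
# Generic sketch of tokenization with truncation at model_max_length
# (simplified relative to _tokenize_fn in train_flant5.py).
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "google/flan-t5-xl", model_max_length=2048
)

encoded = tokenizer(
    "Question: What is the capital of France?",
    return_tensors="pt",
    padding="longest",
    max_length=tokenizer.model_max_length,
    truncation=True,
)
input_ids = encoded.input_ids[0]  # sequences longer than 2048 tokens are cut off
```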
Outputs
| Name | Type | Description |
|---|---|---|
| checkpoints | Files | Model checkpoints saved in output_dir at configured intervals |
| final_model | Files | Final Flan-T5 model weights and tokenizer saved at the end of training |
| trainer_state | JSON | Training state including loss curves, learning rate schedule, and metrics |
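After training, the saved directory can be loaded back with the standard HuggingFace seq2seq classes. A minimal sketch, assuming the path matches the --output_dir used in the example below:

```python
# Load the fine-tuned Flan-T5 checkpoint saved by the Trainer; "./output_flant5"
# is the example --output_dir value, substitute your own path.
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("./output_flant5")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("./output_flant5")

inputs = tokenizer("Who are you?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```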
Usage Examples
# Fine-tune Flan-T5-XL on conversation data with 4 GPUs
torchrun --nproc_per_node=4 -m fastchat.train.train_flant5 \
--model_name_or_path google/flan-t5-xl \
--data_path data/dummy_conversation.json \
--output_dir ./output_flant5 \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--learning_rate 2e-5 \
--bf16 True \
--model_max_length 2048