
Implementation: lm-sys/FastChat Train FlanT5

From Leeroopedia


Knowledge Sources
Domains Training, NLP
Last Updated 2026-02-07 06:00 GMT

Overview

Supervised fine-tuning pipeline for Flan-T5 sequence-to-sequence models on multi-turn conversation data.

Description

Train FlanT5 implements a supervised fine-tuning pipeline designed for encoder-decoder (seq2seq) Flan-T5 models. Unlike causal-LM training scripts, which concatenate a conversation into a single sequence, this module reformats multi-turn conversations into explicit question-answer pairs suited to the T5 encoder-decoder architecture. The _form_qa(sources) helper extracts alternating human/assistant turns and pairs them, while _add_speaker_and_signal(header, source) prepends speaker role tokens and signal markers to each turn. The preprocess(sources, tokenizer) function orchestrates the full preprocessing pipeline: it calls _form_qa to produce Q&A pairs, then uses _tokenize_fn(strings, tokenizer) to tokenize the input (question) and target (answer) sequences separately for the encoder and decoder.

The smart_tokenizer_and_embedding_resize(special_tokens_dict, other_tokens, tokenizer, model) utility safely adds special tokens and other tokens to the tokenizer while resizing the model embedding layer to match; new token embeddings are initialized to the mean of the existing embeddings so training starts from a reasonable point.

The SupervisedDataset class loads JSON conversation data, runs preprocessing, and stores the tokenized examples. DataCollatorForSupervisedDataset dynamically pads variable-length sequences within a batch, replacing padding positions in the labels with IGNORE_INDEX (-100) so they do not contribute to the loss. The make_supervised_data_module(tokenizer, data_args) factory constructs the dataset and collator. Finally, the train() entry point parses ModelArguments, DataArguments, and TrainingArguments, loads the Flan-T5 model and tokenizer, optionally resizes the embeddings, creates the data module, and launches the HuggingFace Trainer.
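To make the turn-pairing step concrete, here is a minimal sketch of the idea behind _form_qa and _add_speaker_and_signal. The function names (without the leading underscore), the role keys, and the "### " signal marker are assumptions for illustration, not the module's exact implementation:

```python
def add_speaker_and_signal(header, source, signal="### "):
    """Prepend a speaker role and a signal marker to each turn (sketch)."""
    roles = {"human": "Human", "gpt": "Assistant"}  # assumed role mapping
    turns = [f"{signal}{roles.get(t['from'], t['from'])}: {t['value']}"
             for t in source]
    return header + "\n".join(turns)

def form_qa(sources):
    """Pair alternating human/assistant turns into (question, answer) tuples."""
    pairs = []
    for source in sources:
        # zip even-indexed (human) turns with odd-indexed (assistant) turns
        for q, a in zip(source[::2], source[1::2]):
            if q["from"] == "human" and a["from"] == "gpt":
                pairs.append((q["value"], a["value"]))
    return pairs

conv = [[{"from": "human", "value": "Hi"},
         {"from": "gpt", "value": "Hello!"},
         {"from": "human", "value": "What is T5?"},
         {"from": "gpt", "value": "An encoder-decoder model."}]]
print(form_qa(conv))
# → [('Hi', 'Hello!'), ('What is T5?', 'An encoder-decoder model.')]
```

The key point is that each Q&A tuple becomes an independent training example, unlike the causal-LM scripts, which keep the whole conversation as one sequence.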

Usage

Use this when fine-tuning Flan-T5 models (e.g., google/flan-t5-xl, google/flan-t5-xxl) on multi-turn conversation data. This script handles the seq2seq-specific formatting where human turns become encoder inputs and assistant turns become decoder targets.

Code Reference

Source Location

Key Functions

| Function | Description |
|---|---|
| train() | Main entry point: loads the Flan-T5 model/tokenizer, preprocesses Q&A pairs, trains with the HuggingFace Trainer |
| smart_tokenizer_and_embedding_resize(special_tokens_dict, other_tokens, tokenizer, model) | Safely adds special tokens and resizes model embeddings, initializing new embeddings to the mean of existing ones |
| preprocess(sources, tokenizer) | Converts multi-turn conversations into tokenized encoder-input / decoder-target pairs |
| _tokenize_fn(strings, tokenizer) | Tokenizes a list of strings and returns input_ids, labels, and attention masks |
| _form_qa(sources) | Extracts alternating human/assistant turns from conversations and pairs them as question-answer tuples |
| _add_speaker_and_signal(header, source) | Prepends speaker role identifiers and signal markers to each conversation turn |
| make_supervised_data_module(tokenizer, data_args) | Factory that builds the SupervisedDataset and DataCollatorForSupervisedDataset |
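The table above describes how preprocess tokenizes questions and answers separately for the encoder and decoder. The following toy sketch illustrates that split; toy_tokenize and preprocess_pairs are hypothetical stand-ins for _tokenize_fn and a real HuggingFace tokenizer, not the module's actual code:

```python
def toy_tokenize(strings, vocab):
    """Stand-in tokenizer: assign each new word the next integer id."""
    return [[vocab.setdefault(w, len(vocab)) for w in s.split()]
            for s in strings]

def preprocess_pairs(pairs, vocab):
    """Tokenize questions (encoder inputs) and answers (decoder targets)
    independently, as a seq2seq pipeline requires."""
    questions = [q for q, _ in pairs]
    answers = [a for _, a in pairs]
    return {"input_ids": toy_tokenize(questions, vocab),
            "labels": toy_tokenize(answers, vocab)}

vocab = {}
out = preprocess_pairs([("what is t5", "a seq2seq model")], vocab)
# out["input_ids"] holds encoder ids; out["labels"] holds decoder target ids
```

The separation matters because T5's encoder never sees the answer tokens; the decoder is trained to emit them conditioned on the encoded question.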

Classes

| Class | Description |
|---|---|
| SupervisedDataset | Loads and tokenizes JSON conversation data into encoder-input and decoder-target pairs at initialization |
| DataCollatorForSupervisedDataset | Dynamically pads variable-length sequences in a batch, masking padding positions in labels with IGNORE_INDEX (-100) |
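A minimal sketch of the collator's padding logic, using plain Python lists rather than the tensors the real DataCollatorForSupervisedDataset produces; the function name and the pad-token id of 0 (Flan-T5's <pad>) are assumptions:

```python
IGNORE_INDEX = -100  # positions with this label are skipped by the loss
PAD_TOKEN_ID = 0     # assumed pad id; T5 tokenizers use <pad> = 0

def collate(batch):
    """Pad input_ids with the pad token and labels with IGNORE_INDEX."""
    max_in = max(len(ex["input_ids"]) for ex in batch)
    max_lb = max(len(ex["labels"]) for ex in batch)
    input_ids, labels, attention_mask = [], [], []
    for ex in batch:
        n = len(ex["input_ids"])
        input_ids.append(ex["input_ids"] + [PAD_TOKEN_ID] * (max_in - n))
        attention_mask.append([1] * n + [0] * (max_in - n))
        labels.append(ex["labels"] + [IGNORE_INDEX] * (max_lb - len(ex["labels"])))
    return {"input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask}

batch = collate([{"input_ids": [5, 6, 7], "labels": [8, 9]},
                 {"input_ids": [5], "labels": [8, 9, 10]}])
# batch["labels"][0] == [8, 9, -100]; batch["attention_mask"][1] == [1, 0, 0]
```

Masking labels with -100 rather than the pad id is what keeps padded decoder positions out of the cross-entropy loss.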

Signature

def train():
    ...

def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    other_tokens: List,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    ...
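The mean-initialization behavior of smart_tokenizer_and_embedding_resize can be sketched in isolation with NumPy. This is an illustration of the arithmetic only; the real function operates on the model's input and output embedding matrices via the transformers API:

```python
import numpy as np

def resize_with_mean_init(embeddings, num_new_tokens):
    """Append rows for new tokens, each initialized to the column-wise
    mean of the existing embedding rows (sketch of the mean-init trick)."""
    mean_vec = embeddings.mean(axis=0, keepdims=True)
    new_rows = np.repeat(mean_vec, num_new_tokens, axis=0)
    return np.concatenate([embeddings, new_rows], axis=0)

emb = np.array([[1.0, 3.0],
                [3.0, 5.0]])
resized = resize_with_mean_init(emb, 2)
# resized[2] == resized[3] == [2.0, 4.0], the mean of the existing rows
```

Starting new tokens at the embedding mean keeps their initial logits in a plausible range, which is why training converges faster than with random initialization.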

Import

from fastchat.train.train_flant5 import train, smart_tokenizer_and_embedding_resize

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| --model_name_or_path | str | Yes | HuggingFace model path for a Flan-T5 checkpoint (e.g., google/flan-t5-xl) |
| --data_path | str | Yes | Path to training JSON data in ShareGPT conversation format |
| --output_dir | str | Yes | Directory for saving model checkpoints and final weights |
| --num_train_epochs | int | No | Number of training epochs (default from TrainingArguments) |
| --per_device_train_batch_size | int | No | Batch size per GPU device during training |
| --learning_rate | float | No | Peak learning rate for the optimizer |
| --model_max_length | int | No | Maximum sequence length for tokenization (longer sequences are truncated) |

Outputs

| Name | Type | Description |
|---|---|---|
| checkpoints | Files | Model checkpoints saved in output_dir at configured intervals |
| final_model | Files | Final Flan-T5 model weights and tokenizer saved at the end of training |
| trainer_state | JSON | Training state including loss curves, learning-rate schedule, and metrics |

Usage Examples

# Fine-tune Flan-T5-XL on conversation data with 4 GPUs
torchrun --nproc_per_node=4 -m fastchat.train.train_flant5 \
    --model_name_or_path google/flan-t5-xl \
    --data_path data/dummy_conversation.json \
    --output_dir ./output_flant5 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --learning_rate 2e-5 \
    --bf16 True \
    --model_max_length 2048
