Implementation:Microsoft LoRA Run TF Text Classification
Overview
`run_tf_text_classification.py` is a TensorFlow-based text-classification fine-tuning script built on `TFAutoModelForSequenceClassification` and the `TFTrainer` API, with custom CSV data loading.
Description
This script fine-tunes TensorFlow transformer models on text classification tasks using CSV data files. Unlike the PyTorch-based scripts that use the `datasets` library directly, this script includes a custom `get_tfds()` function that handles the complete data pipeline from CSV files to `tf.data.Dataset` objects.
Key implementation details:
- `get_tfds()` function: the core data-loading utility that:
  - Loads CSV files using `datasets.load_dataset("csv", data_files=files)`
  - Identifies feature columns and the label column by index (`label_column_id`)
  - Automatically detects single-sentence vs. sentence-pair classification based on the number of non-label columns
  - Tokenizes with `tokenizer.batch_encode_plus()` using max-length padding and truncation
  - Creates `tf.data.Dataset` objects from generators with the proper tensor shapes and types
  - Returns `(train_ds, val_ds, test_ds, label2id)`
- `TFTrainer` and `TFTrainingArguments`: uses the TensorFlow-specific trainer, which operates within `training_args.strategy.scope()` for TPU and multi-GPU support.
- Label handling: automatically builds the `label2id` mapping from the unique label values and sets both `label2id` and `id2label` in the model config.
- Metrics: simple accuracy computed by comparing `np.argmax` predictions to the gold labels.
- Model loading: supports loading PyTorch weights via `from_pt=True` when the model path contains `.bin`.
- Argument structure: uses `HfArgumentParser` with `ModelArguments`, `DataTrainingArguments`, and `TFTrainingArguments` dataclasses.
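The column-detection and label-mapping steps above can be sketched in plain Python. The helper names below are illustrative, not the script's actual functions:

```python
def split_columns(column_names, label_column_id):
    # Pick the label column by index; everything else is a feature column.
    label_name = column_names[label_column_id]
    features = [c for c in column_names if c != label_name]
    # One feature column -> single-sentence task; two -> sentence-pair task.
    if len(features) not in (1, 2):
        raise ValueError(f"expected 1 or 2 non-label columns, got {len(features)}")
    return features, label_name


def build_label2id(label_values):
    # Map each unique label value to a contiguous integer id, in first-seen order.
    label2id = {}
    for value in label_values:
        if value not in label2id:
            label2id[value] = len(label2id)
    return label2id
```

For example, `split_columns(["sentence1", "sentence2", "label"], 2)` returns `(["sentence1", "sentence2"], "label")`, which the script would treat as a sentence-pair task.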
Note: This script requires TensorFlow to be installed and trains with the TF backend; PyTorch is only needed if you convert a PyTorch checkpoint via `from_pt=True`.
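The accuracy metric described above is a one-liner over the logits. A minimal sketch (the function name is illustrative; the script wires an equivalent into the trainer's metrics hook):

```python
import numpy as np


def compute_accuracy(logits, label_ids):
    # Argmax over the class dimension, then the fraction matching the gold labels.
    preds = np.argmax(logits, axis=1)
    return {"acc": float((preds == label_ids).mean())}
```

For instance, `compute_accuracy(np.array([[0.1, 0.9], [0.8, 0.2]]), np.array([1, 1]))` returns `{"acc": 0.5}`: the first prediction (class 1) matches, the second (class 0) does not.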
Usage
Use this script when you need to:
- Fine-tune transformer models using the TensorFlow backend
- Classify text from CSV files with flexible label column positioning
- Deploy on TPU infrastructure using TF distribution strategies
Code Reference
Source Location
| Property | Value |
|---|---|
| File | `examples/NLU/examples/text-classification/run_tf_text_classification.py` |
| Lines | 312 |
| Module | `run_tf_text_classification` |
| Entry Point | `main()` |
Signature/CLI
```shell
python run_tf_text_classification.py \
  --model_name_or_path MODEL_NAME \
  --output_dir OUTPUT_DIR \
  --label_column_id LABEL_COL_IDX \
  --do_train \
  --do_eval \
  [--train_file TRAIN_CSV] \
  [--dev_file DEV_CSV] \
  [--test_file TEST_CSV] \
  [--max_seq_length 128] \
  [--per_device_train_batch_size 32] \
  [--learning_rate 5e-5] \
  [--num_train_epochs 3]
```
Import
```python
import tensorflow as tf
from transformers import (
    AutoConfig,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    PreTrainedTokenizer,
    TFAutoModelForSequenceClassification,
    TFTrainer,
    TFTrainingArguments,
)
```
I/O Contract
Inputs
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| `--model_name_or_path` | str | Yes | - | Pretrained model name or path |
| `--output_dir` | str | Yes | - | Directory for model output |
| `--label_column_id` | int | Yes | - | Zero-based index of the label column in the CSV |
| `--train_file` | str | No | None | Path to training CSV file |
| `--dev_file` | str | No | None | Path to development/validation CSV file |
| `--test_file` | str | No | None | Path to test CSV file |
| `--max_seq_length` | int | No | 128 | Maximum tokenized sequence length |
| `--use_fast` | flag | No | False | Use the fast tokenizer backend |
| `--overwrite_cache` | flag | No | False | Overwrite cached data |
Outputs
| Output | Location | Description |
|---|---|---|
| Trained model | `{output_dir}/` | Saved TF model and tokenizer |
| Evaluation results | `{output_dir}/eval_results.txt` | Accuracy metric written to file |
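Based on the output contract above, writing the evaluation results can be sketched as follows. The `key = value` line format is an assumption for illustration, not verified against the script:

```python
import os


def write_eval_results(output_dir, metrics):
    # Write each metric as a "key = value" line to eval_results.txt
    # under the output directory (assumed format).
    path = os.path.join(output_dir, "eval_results.txt")
    with open(path, "w") as f:
        for key, value in metrics.items():
            f.write(f"{key} = {value}\n")
    return path
```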
Usage Examples
Fine-tune BERT on custom CSV classification
```shell
python examples/NLU/examples/text-classification/run_tf_text_classification.py \
  --model_name_or_path bert-base-uncased \
  --train_file /path/to/train.csv \
  --dev_file /path/to/dev.csv \
  --label_column_id 2 \
  --max_seq_length 128 \
  --do_train \
  --do_eval \
  --per_device_train_batch_size 32 \
  --learning_rate 5e-5 \
  --num_train_epochs 3 \
  --output_dir /tmp/tf_classification
```
Load PyTorch model weights in TensorFlow
```shell
python examples/NLU/examples/text-classification/run_tf_text_classification.py \
  --model_name_or_path /path/to/pytorch_model.bin \
  --train_file /path/to/train.csv \
  --dev_file /path/to/dev.csv \
  --label_column_id 0 \
  --do_train \
  --do_eval \
  --output_dir /tmp/tf_from_pt
```
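The `.bin` detection behind this second example can be sketched as follows. The helper name is illustrative; the script performs an equivalent check inline when deciding whether to pass `from_pt=True`:

```python
def needs_from_pt(model_name_or_path):
    # A path containing ".bin" signals a PyTorch checkpoint, which the
    # TF model must load with from_pt=True (this requires PyTorch installed).
    return ".bin" in model_name_or_path


# The corresponding (hedged) load call would then look roughly like:
# model = TFAutoModelForSequenceClassification.from_pretrained(
#     model_name_or_path, config=config, from_pt=needs_from_pt(model_name_or_path))
```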
Related Pages
- Environment:Microsoft_LoRA_NLU_Conda_Environment
- Implementation:Microsoft_LoRA_Run_GLUE_No_Trainer - PyTorch GLUE classification without Trainer
- Implementation:Microsoft_LoRA_Run_XNLI - Multilingual classification with PyTorch Trainer