Implementation:Microsoft LoRA Run TF Text Classification

From Leeroopedia



Overview

run_tf_text_classification.py is a TensorFlow-based text classification fine-tuning script using TFAutoModelForSequenceClassification and the TFTrainer API with custom CSV data loading.

Description

This script fine-tunes TensorFlow transformer models on text classification tasks using CSV data files. Unlike the PyTorch-based scripts, which use the datasets library directly, this script wraps the whole data pipeline in a custom get_tfds() function. That function still loads the CSV files through the datasets library, but it also tokenizes them and converts each split into a tf.data.Dataset object.
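The front half of that pipeline (splitting the CSV columns into text features and a label, then building label2id and id2label) can be sketched in plain Python. This is a hypothetical illustration: it uses the standard csv module instead of datasets.load_dataset, so every value stays a string (the datasets loader may infer other types, such as integer labels). The helper name split_columns is made up, and sorting the labels is this sketch's choice for deterministic ids, not necessarily the ordering the script uses.

```python
import csv
import io


def split_columns(csv_text, label_column_id):
    """Hypothetical helper: separate text features from the label by column index."""
    reader = csv.DictReader(io.StringIO(csv_text))
    rows = list(reader)
    feature_names = list(reader.fieldnames)
    # Pop the label column by its zero-based index; the remaining columns are text.
    label_name = feature_names.pop(label_column_id)
    # One text column -> single-sentence task; two -> sentence-pair task.
    if len(feature_names) == 1:
        mode = "single"
    elif len(feature_names) == 2:
        mode = "pair"
    else:
        raise ValueError(f"expected 1 or 2 text columns, got {len(feature_names)}")
    # Map each distinct label value to an integer id (sorted for determinism).
    label2id = {label: i for i, label in enumerate(sorted({r[label_name] for r in rows}))}
    id2label = {i: label for label, i in label2id.items()}
    texts = [tuple(r[name] for name in feature_names) for r in rows]
    label_ids = [label2id[r[label_name]] for r in rows]
    return mode, texts, label_ids, label2id, id2label


sample = (
    "sentence1,sentence2,label\n"
    "A man eats.,Someone eats.,entailment\n"
    "It rains.,It is sunny.,contradiction\n"
)
mode, texts, label_ids, label2id, id2label = split_columns(sample, label_column_id=2)
# mode == "pair"; label2id == {"contradiction": 0, "entailment": 1}; label_ids == [1, 0]
```

In the script, the resulting text columns would then go to tokenizer.batch_encode_plus(), either as single sentences or as sentence pairs depending on the detected mode.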

Key implementation details:

  • get_tfds() function: The core data loading utility that:
    • Loads CSV files using datasets.load_dataset("csv", data_files=files)
    • Identifies feature columns and the label column by index (label_column_id)
    • Automatically detects single-sentence vs. sentence-pair classification based on the number of non-label columns
    • Tokenizes using tokenizer.batch_encode_plus() with max_length padding and truncation
    • Wraps each split in a tf.data.Dataset built from a Python generator, declaring the dtype and shape of every model input and of the scalar label up front
    • Returns (train_ds, val_ds, test_ds, label2id)
  • TFTrainer and TFTrainingArguments: Uses the TensorFlow-specific trainer. The model is created inside training_args.strategy.scope(), so the active TPU or multi-GPU distribution strategy applies to it.
  • Label handling: Builds a label2id mapping from the unique label values and sets both label2id and its inverse, id2label, in the model config.
  • Metrics: Accuracy, computed by taking np.argmax over the predicted logits and comparing the result with the gold label ids.
  • Model loading: Supports loading PyTorch models via from_pt=True when the model path contains .bin.
  • Argument structure: Uses HfArgumentParser with ModelArguments, DataTrainingArguments, and TFTrainingArguments dataclasses.
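The accuracy metric listed above can be sketched as follows, assuming NumPy and taking the predictions and label_ids arrays that the script reads from an EvalPrediction. The function name compute_accuracy and the "acc" key are illustrative choices, not confirmed names from the script.

```python
import numpy as np


def compute_accuracy(predictions, label_ids):
    """Accuracy from raw logits: argmax over the class axis vs. gold label ids."""
    preds = np.argmax(np.asarray(predictions), axis=1)
    return {"acc": float((preds == np.asarray(label_ids)).mean())}


logits = [[0.1, 2.0], [1.5, -0.3], [0.2, 0.9]]  # 3 examples, 2 classes
metrics = compute_accuracy(logits, [1, 0, 0])
# argmax -> [1, 0, 1]; two of the three predictions match the gold ids
```

Because only the argmax is used, the logits never need to be turned into probabilities.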

Note: This script requires TensorFlow and trains entirely in TensorFlow. PyTorch is needed only when converting a PyTorch checkpoint with from_pt=True.

Usage

Use this script when you need to:

  • Fine-tune transformer models using the TensorFlow backend
  • Classify text from CSV files with flexible label column positioning
  • Train on TPUs or multiple GPUs using TF distribution strategies

Code Reference

Source Location

Property     Value
File         examples/NLU/examples/text-classification/run_tf_text_classification.py
Length       312 lines
Module       run_tf_text_classification
Entry point  main()

Signature/CLI

python run_tf_text_classification.py \
    --model_name_or_path MODEL_NAME \
    --output_dir OUTPUT_DIR \
    --label_column_id LABEL_COL_IDX \
    --do_train \
    --do_eval \
    [--train_file TRAIN_CSV] \
    [--dev_file DEV_CSV] \
    [--test_file TEST_CSV] \
    [--max_seq_length 128] \
    [--per_device_train_batch_size 32] \
    [--learning_rate 5e-5] \
    [--num_train_epochs 3]

Import

import tensorflow as tf
from transformers import (
    AutoConfig,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    PreTrainedTokenizer,
    TFAutoModelForSequenceClassification,
    TFTrainer,
    TFTrainingArguments,
)

I/O Contract

Inputs

Parameter              Type  Required  Default  Description
--model_name_or_path   str   Yes       -        Pretrained model name or path
--output_dir           str   Yes       -        Directory for model output
--label_column_id      int   Yes       -        Zero-based index of the label column in the CSV
--train_file           str   No        None     Path to training CSV file
--dev_file             str   No        None     Path to development/validation CSV file
--test_file            str   No        None     Path to test CSV file
--max_seq_length       int   No        128      Maximum tokenized sequence length
--use_fast             flag  No        False    Use the fast tokenizer backend
--overwrite_cache      flag  No        False    Overwrite cached data
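To picture the required/optional split and the defaults in this table, here is a hypothetical argparse stand-in. The script itself builds its CLI with HfArgumentParser from dataclasses and also accepts the TFTrainingArguments flags (--do_train, --learning_rate, and so on), none of which are reproduced here; build_parser is a made-up name.

```python
import argparse


def build_parser():
    """Hypothetical argparse mirror of the Inputs table (not the script's real parser)."""
    p = argparse.ArgumentParser(description="run_tf_text_classification.py flags (sketch)")
    # Required: no default, argparse exits with an error if they are missing.
    p.add_argument("--model_name_or_path", type=str, required=True)
    p.add_argument("--output_dir", type=str, required=True)
    p.add_argument("--label_column_id", type=int, required=True)
    # Optional data files default to None.
    p.add_argument("--train_file", type=str, default=None)
    p.add_argument("--dev_file", type=str, default=None)
    p.add_argument("--test_file", type=str, default=None)
    p.add_argument("--max_seq_length", type=int, default=128)
    # Boolean flags: False unless present on the command line.
    p.add_argument("--use_fast", action="store_true")
    p.add_argument("--overwrite_cache", action="store_true")
    return p


args = build_parser().parse_args([
    "--model_name_or_path", "bert-base-uncased",
    "--output_dir", "/tmp/tf_classification",
    "--label_column_id", "2",
    "--train_file", "/path/to/train.csv",
])
```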

Outputs

Output              Location                       Description
Trained model       {output_dir}/                  Saved TF model and tokenizer
Evaluation results  {output_dir}/eval_results.txt  Evaluation metrics, including accuracy, written to file
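To read the evaluation output back programmatically, a small parser can be sketched. It assumes eval_results.txt holds one "key = value" pair per line; check the file your own run produces before relying on that format. The metric names and values in the demo are made up, and read_eval_results is a hypothetical name.

```python
import tempfile
from pathlib import Path


def read_eval_results(path):
    """Parse 'key = value' lines into a dict, converting numeric values to float."""
    results = {}
    for line in Path(path).read_text(encoding="utf-8").splitlines():
        if " = " not in line:
            continue  # skip blank or unrecognised lines
        key, value = line.split(" = ", 1)
        try:
            results[key.strip()] = float(value)
        except ValueError:
            results[key.strip()] = value.strip()  # keep non-numeric values as text
    return results


# Demo with an invented file; real metric names and values come from your own run.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "eval_results.txt"
    demo_path.write_text("acc = 0.875\nloss = 0.31\n", encoding="utf-8")
    demo = read_eval_results(demo_path)
```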

Usage Examples

Fine-tune BERT on custom CSV classification

python examples/NLU/examples/text-classification/run_tf_text_classification.py \
    --model_name_or_path bert-base-uncased \
    --train_file /path/to/train.csv \
    --dev_file /path/to/dev.csv \
    --label_column_id 2 \
    --max_seq_length 128 \
    --do_train \
    --do_eval \
    --per_device_train_batch_size 32 \
    --learning_rate 5e-5 \
    --num_train_epochs 3 \
    --output_dir /tmp/tf_classification
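Here --label_column_id 2 means the label sits in the third column of each CSV row. A toy training file with that layout could be produced as below. The rows are invented, and the header row is included because the datasets CSV loader takes column names from the first row by default.

```python
import csv
import tempfile
from pathlib import Path

# Invented rows: two text columns, then the label at zero-based index 2,
# matching --label_column_id 2. The first row holds the column names.
rows = [
    ("sentence1", "sentence2", "label"),
    ("A man is playing a guitar.", "A person plays music.", "entailment"),
    ("The cat sleeps.", "The cat is running.", "contradiction"),
]


def write_csv(path, rows):
    # newline="" lets the csv module control line endings, as its docs recommend.
    with open(path, "w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(rows)


train_path = Path(tempfile.mkdtemp()) / "train.csv"
write_csv(train_path, rows)
```

With two text columns left after removing the label, the script would treat this file as a sentence-pair task.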

Load PyTorch model weights in TensorFlow

python examples/NLU/examples/text-classification/run_tf_text_classification.py \
    --model_name_or_path /path/to/pytorch_model.bin \
    --train_file /path/to/train.csv \
    --dev_file /path/to/dev.csv \
    --label_column_id 0 \
    --do_train \
    --do_eval \
    --output_dir /tmp/tf_from_pt
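The PyTorch-loading rule from the implementation details reduces to a substring test on the model path. A minimal sketch, with a hypothetical helper name:

```python
def should_load_from_pt(model_name_or_path):
    """Hypothetical helper: True when the path contains '.bin' (the documented from_pt rule)."""
    return ".bin" in model_name_or_path


print(should_load_from_pt("/path/to/pytorch_model.bin"))  # True
print(should_load_from_pt("bert-base-uncased"))           # False
print(should_load_from_pt("/models/model.binary_dir"))    # True: plain substring match
```

Because the check is a plain substring match, a directory name such as model.binary_dir also turns on from_pt=True, so it is worth keeping ".bin" out of TensorFlow checkpoint paths.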
