Implementation:Microsoft LoRA Run SWAG
Overview
run_swag.py is a modern fine-tuning script for multiple choice tasks on the SWAG dataset using the HuggingFace Trainer API with AutoModelForMultipleChoice.
Description
This script fine-tunes pretrained transformer models on the SWAG (Situations With Adversarial Generations) commonsense reasoning benchmark, where the model must select the most plausible continuation of a given sentence from four candidate endings. It uses the modern HuggingFace datasets library for data loading and the Trainer class for the training loop.
The script defines two dataclass-based argument groups (ModelArguments and DataTrainingArguments) and parses them together with the TrainingArguments class from transformers via HfArgumentParser. It includes a custom DataCollatorForMultipleChoice that pads inputs dynamically: it flattens each example's choices into individual sequences, pads them with the tokenizer, and then reshapes the result back to (batch_size, num_choices, seq_len). A check_min_version("4.4.0") guard ensures compatibility with the required transformers version.
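The flatten, pad, and un-flatten steps of the collator can be illustrated with a minimal pure-Python sketch. It is not the script's DataCollatorForMultipleChoice: the real collator delegates padding to the tokenizer (which also handles token_type_ids) and returns tensors. The function name collate_multiple_choice and the default pad_token_id=0 (BERT's padding id) are assumptions made for illustration.

```python
from typing import Dict, List


def collate_multiple_choice(features: List[Dict], pad_token_id: int = 0) -> Dict[str, list]:
    """Hypothetical stand-in for DataCollatorForMultipleChoice.

    Each feature holds "input_ids" as a list of num_choices token-id lists
    and an integer "label".
    """
    batch_size = len(features)
    num_choices = len(features[0]["input_ids"])
    labels = [f["label"] for f in features]

    # 1) Flatten: batch_size x num_choices sequences -> one flat list of sequences
    flat = [choice for f in features for choice in f["input_ids"]]

    # 2) Pad every sequence to the longest one in this batch (dynamic padding)
    max_len = max(len(seq) for seq in flat)
    padded = [seq + [pad_token_id] * (max_len - len(seq)) for seq in flat]
    mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in flat]

    # 3) Un-flatten back to (batch_size, num_choices, seq_len)
    def regroup(rows: List[List[int]]) -> List[List[List[int]]]:
        return [rows[i * num_choices:(i + 1) * num_choices] for i in range(batch_size)]

    return {"input_ids": regroup(padded), "attention_mask": regroup(mask), "labels": labels}
```

Padding per batch rather than to a global maximum keeps short batches small. In this sketch, a choice of length 2 in a batch whose longest choice has length 4 gains two pad ids and two zero mask entries.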
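The preprocessing pipeline pairs each example's context (sent1) with each of its four candidate endings, with every ending prefixed by the sent2 header. It tokenizes these first/second sentence pairs and groups them back into sets of four per example. Evaluation reports accuracy: the fraction of examples whose highest-scoring choice (argmax over the choice logits) matches the gold label.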
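The pairing and regrouping logic can be sketched as follows, assuming the SWAG column names sent1, sent2, and ending0 to ending3. As far as we can tell, the script tokenizes all flattened pairs in a single batched tokenizer call with truncation; here a hypothetical toy_tokenize (whitespace split plus a separator) stands in so the sketch runs without transformers.

```python
from typing import Callable, Dict, List

ENDING_NAMES = [f"ending{i}" for i in range(4)]


def build_pairs(examples: Dict[str, List[str]]):
    """Return (first, second): per-example lists of four sentence halves."""
    # First half: the context (sent1) repeated once per candidate ending
    first = [[context] * 4 for context in examples["sent1"]]
    # Second half: the sent2 header followed by each candidate ending
    second = [
        [f"{header} {examples[end][i]}" for end in ENDING_NAMES]
        for i, header in enumerate(examples["sent2"])
    ]
    return first, second


def preprocess(examples: Dict[str, List[str]],
               tokenize: Callable[[str, str], list]) -> Dict[str, list]:
    first, second = build_pairs(examples)
    # Flatten so each (first, second) pair is encoded as one sequence
    flat_first = [s for group in first for s in group]
    flat_second = [s for group in second for s in group]
    encoded = [tokenize(a, b) for a, b in zip(flat_first, flat_second)]
    # Un-flatten: every 4 consecutive encodings belong to one example
    return {"input_ids": [encoded[i:i + 4] for i in range(0, len(encoded), 4)]}


def toy_tokenize(a: str, b: str) -> List[str]:
    # Hypothetical tokenizer stand-in: whitespace split with a separator token
    return a.split() + ["[SEP]"] + b.split()
```

Flattening before tokenization lets one batched call encode every pair. The regrouping step then restores the (num_examples, 4, seq_len) layout that AutoModelForMultipleChoice expects.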
Usage
Use this script when you need to:
- Fine-tune a pretrained model (e.g., BERT, RoBERTa) on the SWAG multiple choice benchmark
- Adapt the pipeline for custom multiple choice tasks using CSV/JSON files
- Leverage the modern Trainer API with checkpoint resumption, distributed training, and FP16 support
Code Reference
Source Location
| Property | Value |
|---|---|
| File | examples/NLU/examples/multiple-choice/run_swag.py |
| Lines | 438 |
| Module | run_swag |
| Entry Point | main() |
Signature/CLI
```bash
python run_swag.py \
  --model_name_or_path MODEL_NAME \
  --output_dir OUTPUT_DIR \
  --do_train \
  --do_eval \
  [--train_file TRAIN_FILE] \
  [--validation_file VALIDATION_FILE] \
  [--max_seq_length MAX_SEQ_LENGTH] \
  [--pad_to_max_length] \
  [--preprocessing_num_workers NUM_WORKERS] \
  [--max_train_samples N] \
  [--max_val_samples N] \
  [--per_device_train_batch_size BATCH_SIZE] \
  [--learning_rate LR] \
  [--num_train_epochs EPOCHS] \
  [--fp16]
```
Import
```python
from transformers import (
    AutoConfig,
    AutoModelForMultipleChoice,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from datasets import load_dataset
```
I/O Contract
Inputs
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| --model_name_or_path | str | Yes | - | Pretrained model name or path (e.g., bert-base-uncased) |
| --output_dir | str | Yes | - | Directory for model checkpoints and results |
| --do_train | flag | No | False | Enable training |
| --do_eval | flag | No | False | Enable evaluation |
| --train_file | str | No | None | Custom CSV/JSON training file (if not using SWAG from the Hub) |
| --validation_file | str | No | None | Custom CSV/JSON validation file |
| --max_seq_length | int | No | None | Maximum tokenized sequence length; if unset, uses the tokenizer's maximum length, capped at 1024 |
| --pad_to_max_length | flag | No | False | Pad all samples to the maximum length instead of padding dynamically per batch (recommended for TPU) |
| --max_train_samples | int | No | None | Truncate the training set to N examples (for debugging) |
| --max_val_samples | int | No | None | Truncate the validation set to N examples (for debugging) |
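The --max_seq_length fallback (use the model's limit when unset, capped at 1024) can be sketched as a small helper. The name resolve_max_seq_length is hypothetical, and the clamping of an explicit value to the model's limit is our reading of the script's behavior rather than something stated in this page.

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)


def resolve_max_seq_length(requested: Optional[int], model_max_length: int,
                           cap: int = 1024) -> int:
    """Hypothetical helper choosing the effective max_seq_length."""
    if requested is None:
        # No value given: fall back to the tokenizer's limit, capping huge defaults
        if model_max_length > cap:
            logger.warning("model_max_length=%d is very large; using %d", model_max_length, cap)
            return cap
        return model_max_length
    if requested > model_max_length:
        logger.warning("max_seq_length=%d exceeds the model limit %d", requested, model_max_length)
    # Assumed: never exceed what the model supports
    return min(requested, model_max_length)
```

The cap matters because some tokenizers report a placeholder maximum length (a very large integer) when no real limit is configured.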
Outputs
| Output | Location | Description |
|---|---|---|
| Trained model | {output_dir}/ | Saved model weights, config, and tokenizer |
| Training metrics | {output_dir}/train_results.json | Loss, runtime, samples per second |
| Evaluation metrics | {output_dir}/eval_results.json | Accuracy on the validation set |
| Trainer state | {output_dir}/trainer_state.json | Checkpoint and training state info |
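The accuracy reported in the evaluation metrics (argmax over the choice logits compared with the gold label) can be sketched in pure Python. This is an illustrative stand-in, not the script's metric function, which works on NumPy arrays. The Trainer is expected to store the result under an eval_ prefix (e.g. eval_accuracy).

```python
from typing import Sequence


def multiple_choice_accuracy(logits: Sequence[Sequence[float]],
                             labels: Sequence[int]) -> float:
    """Fraction of examples whose highest-scoring choice equals the label."""
    # argmax over the choice dimension; ties resolve to the first index
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return correct / len(labels)
```

With logits [[0.1, 2.0, -1.0, 0.3], [1.5, 0.2, 0.1, 0.0], [0.0, 0.0, 0.9, 1.2]] and labels [1, 0, 2], the predictions are [1, 0, 3], so two of three examples are correct.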
Usage Examples
Fine-tune BERT on SWAG
```bash
python examples/NLU/examples/multiple-choice/run_swag.py \
  --model_name_or_path bert-base-uncased \
  --do_train \
  --do_eval \
  --output_dir /tmp/swag_output \
  --per_device_train_batch_size 16 \
  --learning_rate 5e-5 \
  --num_train_epochs 3 \
  --max_seq_length 128
```
Fine-tune with custom data and FP16
```bash
python examples/NLU/examples/multiple-choice/run_swag.py \
  --model_name_or_path roberta-base \
  --train_file /path/to/train.json \
  --validation_file /path/to/val.json \
  --do_train \
  --do_eval \
  --output_dir /tmp/custom_mc_output \
  --pad_to_max_length \
  --fp16 \
  --max_train_samples 1000
```
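Custom files need to expose the same columns the preprocessing reads. The sketch below serializes records as JSON Lines. It assumes the SWAG layout (sent1, sent2, ending0 to ending3, plus an integer label) and that the datasets JSON loader accepts one JSON object per line. The helper name to_jsonl and the example record are hypothetical.

```python
import json
from typing import Dict, Iterable

# Assumed column layout, mirroring the SWAG fields the script reads
REQUIRED_COLUMNS = ["sent1", "sent2", "ending0", "ending1", "ending2", "ending3", "label"]


def to_jsonl(records: Iterable[Dict]) -> str:
    """Serialize multiple choice records as JSON Lines, checking required columns."""
    lines = []
    for rec in records:
        missing = [k for k in REQUIRED_COLUMNS if k not in rec]
        if missing:
            raise ValueError(f"record is missing columns: {missing}")
        lines.append(json.dumps(rec))
    return "\n".join(lines) + "\n"
```

For example, writing to_jsonl([{"sent1": "A man sits.", "sent2": "He", "ending0": "reads.", "ending1": "eats.", "ending2": "runs.", "ending3": "sleeps.", "label": 0}]) to /path/to/train.json would produce a one-example training file in this assumed format.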
Load from JSON config
```bash
python examples/NLU/examples/multiple-choice/run_swag.py config.json
```
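When the only argument is a .json file, HfArgumentParser fills the argument dataclasses from it. Its keys are the dataclass field names, which match the CLI flags without the leading "--". The sketch below builds such a file with values taken from the BERT example. The helper make_swag_config and its defaults are illustrative assumptions, not part of the script.

```python
import json


def make_swag_config(model: str, output_dir: str, **overrides) -> str:
    """Build a JSON config whose keys mirror run_swag.py's CLI flag names."""
    cfg = {
        "model_name_or_path": model,
        "output_dir": output_dir,
        "do_train": True,
        "do_eval": True,
        "max_seq_length": 128,
        "per_device_train_batch_size": 16,
        "learning_rate": 5e-5,
        "num_train_epochs": 3.0,
    }
    # Extra flags (e.g. fp16=True) override or extend the defaults above
    cfg.update(overrides)
    return json.dumps(cfg, indent=2)
```

Saving make_swag_config("bert-base-uncased", "/tmp/swag_output") as config.json should be equivalent to the first command-line example. Boolean flags such as do_train become true/false values.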
Related Pages
- Environment:Microsoft_LoRA_NLU_Conda_Environment
- Implementation:Microsoft_LoRA_Utils_Multiple_Choice - Utility classes for multiple choice data processing