

Implementation:Microsoft LoRA Run SWAG

From Leeroopedia



Overview

run_swag.py fine-tunes models for multiple-choice tasks on the SWAG dataset using the HuggingFace Trainer API and AutoModelForMultipleChoice.

Description

This script fine-tunes pretrained transformer models on the SWAG (Situations With Adversarial Generations) commonsense reasoning benchmark, where the model must select the most plausible continuation of a given sentence from four candidate endings. It loads data with the HuggingFace datasets library and runs the training loop with the Trainer class.

The script defines three dataclass-based argument groups (ModelArguments, DataTrainingArguments, TrainingArguments) and parses them with HfArgumentParser. It includes a custom DataCollatorForMultipleChoice that pads inputs dynamically: it flattens the choices in a batch into individual sequences, pads them to the longest sequence in that batch, then reshapes the result back to (batch_size, num_choices, seq_len). A check_min_version("4.4.0") guard stops the script early if the installed transformers version is older than required.
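The flatten, pad, un-flatten sequence described above can be sketched in plain Python. This is an illustration of the shape logic only, not the script's collator: the real DataCollatorForMultipleChoice delegates padding to the tokenizer and returns PyTorch tensors, while this hypothetical `collate_multiple_choice` works on nested lists and assumes a pad id of 0 (BERT's [PAD] id). The token ids in the usage example are toy values.

```python
from typing import Dict, List


def collate_multiple_choice(features: List[Dict], pad_id: int = 0) -> Dict[str, list]:
    """Flatten -> pad -> un-flatten a batch of multiple-choice features.

    Each feature maps a field name (e.g. "input_ids") to one token list per
    choice, plus an integer "label".
    """
    labels = [f["label"] for f in features]
    keys = [k for k in features[0] if k != "label"]
    batch_size = len(features)
    num_choices = len(features[0]["input_ids"])

    # 1. Flatten: batch_size examples x num_choices -> one row per choice.
    flat = [{k: f[k][c] for k in keys} for f in features for c in range(num_choices)]

    # 2. Dynamic padding: pad every row to the longest row in *this* batch.
    #    input_ids use pad_id; other fields (attention_mask, token_type_ids) use 0.
    max_len = max(len(row["input_ids"]) for row in flat)
    padded = [
        {k: row[k] + [pad_id if k == "input_ids" else 0] * (max_len - len(row[k])) for k in keys}
        for row in flat
    ]

    # 3. Un-flatten back to (batch_size, num_choices, max_len).
    batch = {
        k: [[padded[b * num_choices + c][k] for c in range(num_choices)] for b in range(batch_size)]
        for k in keys
    }
    batch["labels"] = labels
    return batch


features = [
    {"input_ids": [[101, 7, 102], [101, 8, 9, 102]], "attention_mask": [[1, 1, 1], [1, 1, 1, 1]], "label": 1},
    {"input_ids": [[101, 5, 102], [101, 102]], "attention_mask": [[1, 1, 1], [1, 1]], "label": 0},
]
batch = collate_multiple_choice(features)
print(batch["input_ids"])
# [[[101, 7, 102, 0], [101, 8, 9, 102]], [[101, 5, 102, 0], [101, 102, 0, 0]]]
```

Padding to the longest sequence in each batch, rather than to a fixed maximum, is what keeps short batches cheap on GPU.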

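The preprocessing pipeline pairs each example's context (sent1) with each of its four candidate continuations (sent2 followed by ending0 through ending3), tokenizes all four pairs, and regroups the tokenized outputs into sets of four per example. Evaluation reports accuracy: the fraction of examples whose highest-scoring choice (argmax over the choice logits) matches the gold label.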
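The pair construction and regrouping can be illustrated without a tokenizer. This is a minimal sketch assuming SWAG's column names (sent1, sent2, ending0 through ending3). `build_pairs` and `group_by_four` are hypothetical helper names, and the example sentences are made up. In run_swag.py the flattened lists go to `tokenizer(first, second, truncation=True, ...)` and every returned field is regrouped the same way.

```python
from typing import Dict, List, Tuple

ENDING_NAMES = [f"ending{i}" for i in range(4)]


def build_pairs(examples: Dict[str, list]) -> Tuple[List[str], List[str]]:
    """Build flat (first, second) sentence lists for a batch of SWAG-style examples."""
    # First sentence: the context (sent1), repeated once per candidate ending.
    first = [[ctx] * 4 for ctx in examples["sent1"]]
    # Second sentence: the question header (sent2) followed by each ending.
    second = [
        [f"{header} {examples[end][i]}" for end in ENDING_NAMES]
        for i, header in enumerate(examples["sent2"])
    ]
    # Flatten so a tokenizer can process every pair in a single call.
    return sum(first, []), sum(second, [])


def group_by_four(values: List) -> List[List]:
    """Un-flatten a per-pair list back into one group of four per example."""
    return [values[i : i + 4] for i in range(0, len(values), 4)]


batch = {
    "sent1": ["She opens the fridge."],
    "sent2": ["She"],
    "ending0": ["takes out milk."],
    "ending1": ["flies away."],
    "ending2": ["paints the moon."],
    "ending3": ["sings to the door."],
}
first, second = build_pairs(batch)
grouped = group_by_four(second)
print(second)
```

Repeating the context four times keeps the first and second lists aligned index by index, which is what lets a single batched tokenizer call handle all choices.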

Usage

Use this script when you need to:

  • Fine-tune a pretrained model (e.g., BERT, RoBERTa) on the SWAG multiple choice benchmark
  • Adapt the pipeline for custom multiple choice tasks using CSV/JSON files
  • Use the Trainer API's built-in checkpoint resumption, distributed training, and FP16 mixed-precision support

Code Reference

Source Location

Property | Value
File | examples/NLU/examples/multiple-choice/run_swag.py
Length | 438 lines
Module | run_swag
Entry point | main()

Signature/CLI

python run_swag.py \
    --model_name_or_path MODEL_NAME \
    --output_dir OUTPUT_DIR \
    --do_train \
    --do_eval \
    [--train_file TRAIN_FILE] \
    [--validation_file VALIDATION_FILE] \
    [--max_seq_length MAX_SEQ_LENGTH] \
    [--pad_to_max_length] \
    [--preprocessing_num_workers NUM_WORKERS] \
    [--max_train_samples N] \
    [--max_val_samples N] \
    [--per_device_train_batch_size BATCH_SIZE] \
    [--learning_rate LR] \
    [--num_train_epochs EPOCHS] \
    [--fp16]
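These flags are not declared one by one: run_swag.py declares them as dataclass fields, and HfArgumentParser turns each field into a `--field_name` option. The sketch below mimics that mapping with the standard library's argparse. It is a simplified, hypothetical stand-in (`parse_into_dataclasses` is not a transformers function), it covers only a few fields from two of the three groups, and it treats every bool as an off-by-default switch, whereas HfArgumentParser handles more cases.

```python
import argparse
import dataclasses
import typing
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelArguments:
    model_name_or_path: str


@dataclass
class DataTrainingArguments:
    max_seq_length: Optional[int] = None
    pad_to_max_length: bool = False


def _base_type(tp):
    """Unwrap Optional[X] to X so argparse knows how to convert the string."""
    args = [a for a in typing.get_args(tp) if a is not type(None)]
    return args[0] if args else tp


def parse_into_dataclasses(types, argv: List[str]):
    """Toy stand-in for HfArgumentParser: one CLI flag per dataclass field."""
    parser = argparse.ArgumentParser()
    for dc in types:
        for f in dataclasses.fields(dc):
            kind = _base_type(f.type)
            if kind is bool:
                # Boolean fields become switches such as --pad_to_max_length.
                parser.add_argument(f"--{f.name}", action="store_true")
            elif f.default is dataclasses.MISSING:
                parser.add_argument(f"--{f.name}", type=kind, required=True)
            else:
                parser.add_argument(f"--{f.name}", type=kind, default=f.default)
    parsed = vars(parser.parse_args(argv))
    # Route each parsed value back to the dataclass that declared the field.
    return tuple(dc(**{f.name: parsed[f.name] for f in dataclasses.fields(dc)}) for dc in types)


model_args, data_args = parse_into_dataclasses(
    (ModelArguments, DataTrainingArguments),
    ["--model_name_or_path", "bert-base-uncased", "--max_seq_length", "128", "--pad_to_max_length"],
)
print(model_args, data_args)
```

Grouping options by dataclass lets the script pass model, data, and training settings around as separate typed objects.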

Import

from transformers import (
    AutoConfig,
    AutoModelForMultipleChoice,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from datasets import load_dataset

I/O Contract

Inputs

Parameter | Type | Required | Default | Description
--model_name_or_path | str | Yes | - | Pretrained model name or path (e.g., bert-base-uncased)
--output_dir | str | Yes | - | Directory for model checkpoints and results
--do_train | flag | No | False | Enable training
--do_eval | flag | No | False | Enable evaluation
--train_file | str | No | None | Custom CSV/JSON training file (if omitted, SWAG is loaded from the Hugging Face Hub)
--validation_file | str | No | None | Custom CSV/JSON validation file
--max_seq_length | int | No | None | Maximum tokenized sequence length; if unset, taken from the tokenizer's model_max_length and capped at 1024
--pad_to_max_length | flag | No | False | Pad every sample to max_seq_length instead of padding dynamically per batch (recommended on TPU; dynamic padding is more efficient on GPU)
--max_train_samples | int | No | None | Truncate the training set to N examples (for debugging or quick runs)
--max_val_samples | int | No | None | Truncate the validation set to N examples (for debugging or quick runs)
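The --max_seq_length default can be sketched as a small resolution function. The `None` branch follows the table: use the tokenizer's `model_max_length`, capped at 1024. Some tokenizers report a very large sentinel value there, which is why the cap matters. The second branch, which clamps an explicit value to the model limit with a warning, is an assumption based on the pattern in the HuggingFace example scripts and is not confirmed by this page. `resolve_max_seq_length` is a hypothetical name.

```python
import warnings
from typing import Optional


def resolve_max_seq_length(requested: Optional[int], model_max_length: int, cap: int = 1024) -> int:
    """Sketch of how an effective max_seq_length could be chosen."""
    if requested is None:
        # No value given: fall back to the tokenizer's limit, but never exceed the cap.
        if model_max_length > cap:
            warnings.warn(f"model_max_length={model_max_length} is very large; using {cap}.")
            return cap
        return model_max_length
    # Assumed behaviour: an explicit value larger than the model limit is clamped.
    if requested > model_max_length:
        warnings.warn(f"max_seq_length={requested} exceeds the model limit {model_max_length}.")
    return min(requested, model_max_length)


print(resolve_max_seq_length(None, 512))    # BERT-style limit, no cap needed
print(resolve_max_seq_length(None, 10**30))  # sentinel "unlimited" value, capped
print(resolve_max_seq_length(128, 512))     # explicit value within the limit
```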

Outputs

Output | Location | Description
Trained model | {output_dir}/ | Saved model weights, config, and tokenizer
Training metrics | {output_dir}/train_results.json | Loss, runtime, samples per second
Evaluation metrics | {output_dir}/eval_results.json | Accuracy on the validation set
Trainer state | {output_dir}/trainer_state.json | Checkpoint and training state info
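The accuracy reported in eval_results.json is an argmax over each example's choice logits compared with the gold label. The script computes this with NumPy. The sketch below is a pure-Python equivalent with a hypothetical function name and made-up logits. Like NumPy's argmax, `max` with a key returns the first index on ties.

```python
from typing import Sequence


def multiple_choice_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of examples whose highest-scoring choice matches the gold label."""
    # argmax over the choice axis: one predicted choice index per example
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return correct / len(labels)


logits = [
    [0.1, 2.3, -0.5, 0.0],  # predicts choice 1
    [1.5, 0.2, 0.3, 0.9],   # predicts choice 0
    [0.0, 0.1, 0.2, 3.0],   # predicts choice 3
]
labels = [1, 2, 3]
print(multiple_choice_accuracy(logits, labels))  # 0.6666666666666666
```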

Usage Examples

Fine-tune BERT on SWAG

python examples/NLU/examples/multiple-choice/run_swag.py \
    --model_name_or_path bert-base-uncased \
    --do_train \
    --do_eval \
    --output_dir /tmp/swag_output \
    --per_device_train_batch_size 16 \
    --learning_rate 5e-5 \
    --num_train_epochs 3 \
    --max_seq_length 128
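Because every example expands into one sequence per choice, a per-device batch of 16 SWAG examples is really 16 × 4 = 64 sequences of up to 128 tokens. The hypothetical helper below does that arithmetic and estimates optimizer steps. Treat it as an approximation: the step count roughly follows how Trainer derives steps from the dataloader length and gradient accumulation, and the exact figure can vary with the transformers version and dataloader settings. The dataset size in the example is a placeholder, not SWAG's real size.

```python
import math


def training_budget(num_examples: int, per_device_batch: int, num_devices: int = 1,
                    grad_accum: int = 1, epochs: float = 3.0,
                    num_choices: int = 4, seq_len: int = 128) -> dict:
    """Back-of-the-envelope sizes for a multiple-choice fine-tuning run."""
    # Each example expands into num_choices sequences of up to seq_len tokens.
    sequences = per_device_batch * num_choices
    # Dataloader batches per epoch on each device (data is split across devices).
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_devices))
    # Optimizer updates per epoch after gradient accumulation.
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return {
        "sequences_per_device_batch": sequences,
        "tokens_per_device_batch": sequences * seq_len,
        "updates_per_epoch": updates_per_epoch,
        "total_updates": math.ceil(epochs * updates_per_epoch),
    }


# Hypothetical dataset size; substitute the real length of your training split.
print(training_budget(num_examples=10_000, per_device_batch=16))
# {'sequences_per_device_batch': 64, 'tokens_per_device_batch': 8192, 'updates_per_epoch': 625, 'total_updates': 1875}
```

If the effective batch does not fit in memory, lowering --per_device_train_batch_size or --max_seq_length shrinks the tensor in proportion.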

Fine-tune with custom data and FP16

python examples/NLU/examples/multiple-choice/run_swag.py \
    --model_name_or_path roberta-base \
    --train_file /path/to/train.json \
    --validation_file /path/to/val.json \
    --do_train \
    --do_eval \
    --output_dir /tmp/custom_mc_output \
    --pad_to_max_length \
    --fp16 \
    --max_train_samples 1000
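Custom files must use the same column names the script reads from SWAG: sent1 (context), sent2 (start of the continuation), ending0 through ending3, and label (index of the correct ending). The script picks the datasets loader from the file extension, so a `.json` file is read by the JSON loader, which accepts one JSON object per line. The sketch below validates and writes such a file. The helper name and the example row are hypothetical.

```python
import json
import tempfile
from pathlib import Path

# Columns run_swag.py reads: context, question header, four endings, gold label.
REQUIRED_COLUMNS = ["sent1", "sent2", "ending0", "ending1", "ending2", "ending3", "label"]


def write_mc_jsonl(rows, path):
    """Validate rows, then write them as JSON Lines (one object per line)."""
    for i, row in enumerate(rows):
        missing = [c for c in REQUIRED_COLUMNS if c not in row]
        if missing:
            raise ValueError(f"row {i} is missing columns: {missing}")
        if row["label"] not in (0, 1, 2, 3):
            raise ValueError(f"row {i} has label {row['label']!r}; expected 0-3")
    with open(path, "w", encoding="utf-8") as fh:
        for row in rows:
            fh.write(json.dumps(row) + "\n")


# Hypothetical example row; replace with your own data.
rows = [
    {
        "sent1": "The chef cracks two eggs into a bowl.",
        "sent2": "The chef",
        "ending0": "whisks them with a fork.",
        "ending1": "mails the bowl to a friend.",
        "ending2": "plants the eggs in the garden.",
        "ending3": "reads the bowl a bedtime story.",
        "label": 0,
    }
]
train_path = Path(tempfile.mkdtemp()) / "train.json"
write_mc_jsonl(rows, train_path)
print(train_path.read_text(encoding="utf-8"))
```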

Load from JSON config

python examples/NLU/examples/multiple-choice/run_swag.py config.json
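When the script receives exactly one argument ending in `.json`, it reads all three argument groups from that file through HfArgumentParser's JSON parsing instead of the command line, following the pattern of the other HuggingFace example scripts. The keys are the dataclass field names, which are the flag names without the leading dashes. The sketch below writes a config equivalent to the "Fine-tune BERT on SWAG" command. The temporary path is only for illustration.

```python
import json
import tempfile
from pathlib import Path

# Same settings as the "Fine-tune BERT on SWAG" command, expressed as JSON.
# Keys are the flag names without "--"; switches such as --do_train become true/false.
config = {
    "model_name_or_path": "bert-base-uncased",
    "output_dir": "/tmp/swag_output",
    "do_train": True,
    "do_eval": True,
    "per_device_train_batch_size": 16,
    "learning_rate": 5e-5,
    "num_train_epochs": 3,
    "max_seq_length": 128,
}

config_path = Path(tempfile.mkdtemp()) / "config.json"
config_path.write_text(json.dumps(config, indent=2), encoding="utf-8")
print(f"python examples/NLU/examples/multiple-choice/run_swag.py {config_path}")
```

Keeping runs in JSON files makes hyperparameter settings easy to version alongside results.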
