Implementation:Microsoft LoRA Run Summarization
Overview
run_summarization.py is a sequence-to-sequence fine-tuning script for text summarization. It is built on AutoModelForSeq2SeqLM and Seq2SeqTrainer, and it evaluates with ROUGE after splitting outputs into sentences with NLTK.
Description
This script fine-tunes encoder-decoder models (e.g., BART, T5, Pegasus) on abstractive summarization datasets. It uses Seq2SeqTrainer and Seq2SeqTrainingArguments, which extend the standard Trainer with generation-specific settings such as predict_with_generate, num_beams, and max_length. These settings control the model.generate() call used during evaluation and prediction.
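As a rough sketch of how the generation settings flow into evaluation, the stand-in dataclass below mirrors three of the data arguments and resolves the keyword arguments that would be handed to trainer.evaluate() or trainer.predict(), and from there to model.generate(). The class and helper names are hypothetical; the fallback of val_max_target_length to max_target_length follows the defaults listed in the Inputs table.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class GenerationDataArgs:
    """Hypothetical subset of the script's data arguments."""
    max_target_length: int = 128
    val_max_target_length: Optional[int] = None
    num_beams: Optional[int] = None

    def __post_init__(self) -> None:
        # Fall back to the training target length when no eval length is given.
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length


def eval_generation_kwargs(args: GenerationDataArgs) -> dict:
    """Build the keyword arguments forwarded to generation during eval/predict."""
    # num_beams=None is assumed to leave the beam width to the model config's default.
    return {"max_length": args.val_max_target_length, "num_beams": args.num_beams}


args = GenerationDataArgs(max_target_length=64, num_beams=4)
print(eval_generation_kwargs(args))  # {'max_length': 64, 'num_beams': 4}
```

With a real trainer, these kwargs would be unpacked into the evaluation call; they only affect the output text when predict_with_generate is enabled, since otherwise no sequences are generated.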
Key implementation details:
- NLTK dependency: Downloads the punkt tokenizer at startup if it is not already installed, checking for offline mode first and guarding the download with a FileLock for distributed safety. Uses nltk.sent_tokenize() to split predictions and labels into sentences separated by newlines, as required by the rougeLsum metric.
- Dataset name mapping: A summarization_name_mapping dictionary maps known dataset names to their (text_column, summary_column) pairs, supporting datasets like cnn_dailymail ("article", "highlights"), xsum ("document", "summary"), samsum ("dialogue", "summary"), and others.
- T5 source prefix: Warns if a T5 model is run without --source_prefix (e.g., "summarize: ").
- Data collation: Uses DataCollatorForSeq2Seq for dynamic padding, with the label pad token ID set to -100 when ignore_pad_token_for_loss is enabled so that padded label positions are ignored by the loss. When --pad_to_max_length is set, default_data_collator is used instead.
- Metrics: Computes ROUGE scores (rouge1, rouge2, rougeL, rougeLsum) using the rouge metric from the datasets library, with stemming enabled. Reports the F-measure scaled to the 0-100 range, plus the average generation length.
- Three-phase pipeline: Supports do_train, do_eval, and do_predict for training, validation, and test-set prediction. With predict_with_generate, the decoded test predictions are saved to test_generations.txt.
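Before scoring, the metric step has to undo the -100 label padding and reformat text for rougeLsum. The sketch below shows both steps on plain Python lists. The helper names are hypothetical, and a naive regex splitter stands in for nltk.sent_tokenize(), which handles abbreviations and other cases this regex does not.

```python
import re
from typing import List, Tuple

IGNORE_INDEX = -100  # label value ignored by the loss


def restore_pad_ids(labels: List[List[int]], pad_token_id: int) -> List[List[int]]:
    """Swap -100 back to the pad ID so the labels can be decoded to text."""
    return [[pad_token_id if tok == IGNORE_INDEX else tok for tok in seq] for seq in labels]


def split_sentences(text: str) -> List[str]:
    """Naive regex splitter used here as a stand-in for nltk.sent_tokenize()."""
    return [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]


def postprocess_text(preds: List[str], labels: List[str]) -> Tuple[List[str], List[str]]:
    """Put one sentence per line, the layout rougeLsum expects."""
    preds = ["\n".join(split_sentences(pred)) for pred in preds]
    labels = ["\n".join(split_sentences(lab)) for lab in labels]
    return preds, labels


print(restore_pad_ids([[42, 7, -100, -100]], pad_token_id=1))  # [[42, 7, 1, 1]]
print(postprocess_text(["A cat sat. It slept."], ["The cat slept!"]))
# (['A cat sat.\nIt slept.'], ['The cat slept!'])
```

In the actual pipeline the restored label IDs would be decoded with the tokenizer before postprocess_text is applied.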
Usage
Use this script when you need to:
- Fine-tune seq2seq models on summarization benchmarks (CNN/DailyMail, XSum, etc.)
- Train on custom summarization datasets in CSV/JSON format
- Generate and evaluate summaries using ROUGE metrics with beam search
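For the custom-dataset case, a minimal sketch of preparing a JSON Lines file with one source/summary pair per line follows. The records, file location, and the text/summary field names are assumptions for illustration. Because these names are not in the dataset name mapping, they would need to be passed explicitly with --text_column and --summary_column. The flag used to point the script at the file is not documented on this page, so check the script's argument definitions.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List

# Toy records for illustration only; field names are hypothetical.
records: List[Dict[str, str]] = [
    {"text": "The council approved the new park budget on Monday.", "summary": "Park budget approved."},
    {"text": "Heavy rain closed several roads across the region.", "summary": "Rain closes roads."},
]


def write_jsonl(path: Path, rows: List[Dict[str, str]]) -> Path:
    """Write one JSON object per line (JSON Lines) and return the path."""
    with path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    return path


# Hypothetical location; the .json suffix matches the JSON format mentioned above.
train_path = write_jsonl(Path(tempfile.gettempdir()) / "summarization_train.json", records)
print(train_path.read_text(encoding="utf-8").splitlines()[0])
```

The resulting file would be used together with --text_column text --summary_column summary.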
Code Reference
Source Location
| Property | Value |
|---|---|
| File | examples/NLU/examples/seq2seq/run_summarization.py |
| Lines | 595 |
| Module | run_summarization |
| Entry Point | main() |
Signature/CLI
```bash
python run_summarization.py \
    --model_name_or_path MODEL_NAME \
    --dataset_name DATASET_NAME \
    --output_dir OUTPUT_DIR \
    --do_train \
    --do_eval \
    [--do_predict] \
    [--dataset_config_name CONFIG] \
    [--text_column TEXT_COL] \
    [--summary_column SUMMARY_COL] \
    [--source_prefix "summarize: "] \
    [--max_source_length 1024] \
    [--max_target_length 128] \
    [--val_max_target_length 128] \
    [--num_beams 4] \
    [--pad_to_max_length] \
    [--ignore_pad_token_for_loss] \
    [--predict_with_generate]
```
Import
```python
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
    set_seed,
)
from datasets import load_dataset, load_metric
import nltk
```
I/O Contract
Inputs
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| --model_name_or_path | str | Yes | - | Pretrained seq2seq model (e.g., facebook/bart-large-cnn) |
| --output_dir | str | Yes | - | Directory for checkpoints and results |
| --dataset_name | str | No | None | HuggingFace dataset name (e.g., cnn_dailymail) |
| --text_column | str | No | None | Source text column name (auto-detected from dataset mapping) |
| --summary_column | str | No | None | Summary column name (auto-detected from dataset mapping) |
| --source_prefix | str | No | None | Prefix for source text (e.g., "summarize: " for T5) |
| --max_source_length | int | No | 1024 | Max source tokenized length |
| --max_target_length | int | No | 128 | Max target tokenized length for training |
| --val_max_target_length | int | No | max_target_length | Max target length for eval/predict generation |
| --num_beams | int | No | None | Beam search width for evaluation generation |
| --ignore_pad_token_for_loss | bool | No | True | Replace pad tokens with -100 in labels |
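The column auto-detection in the table above can be sketched as follows. Explicit flags take priority, then the dataset name mapping. The fallback to the first two dataset columns for unmapped datasets and the membership check are assumptions about the script's behavior, resolve_columns is a hypothetical helper, and the mapping shown is only the excerpt listed in the Description.

```python
from typing import Dict, List, Optional, Tuple

# Excerpt of the pairs listed in the Description; the real mapping has more entries.
SUMMARIZATION_NAME_MAPPING: Dict[str, Tuple[str, str]] = {
    "cnn_dailymail": ("article", "highlights"),
    "xsum": ("document", "summary"),
    "samsum": ("dialogue", "summary"),
}


def resolve_columns(
    dataset_name: Optional[str],
    column_names: List[str],
    text_column: Optional[str] = None,
    summary_column: Optional[str] = None,
) -> Tuple[str, str]:
    """Pick (text, summary) columns: explicit flags win, then the mapping, then column order."""
    mapped = SUMMARIZATION_NAME_MAPPING.get(dataset_name) if dataset_name else None
    if text_column is None:
        text_column = mapped[0] if mapped else column_names[0]
    if summary_column is None:
        summary_column = mapped[1] if mapped else column_names[1]
    # Fail early if a chosen column does not exist in the loaded dataset.
    for col in (text_column, summary_column):
        if col not in column_names:
            raise ValueError(f"column {col!r} not found in {column_names}")
    return text_column, summary_column


print(resolve_columns("xsum", ["document", "summary", "id"]))  # ('document', 'summary')
```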
Outputs
| Output | Location | Description |
|---|---|---|
| Trained model | {output_dir}/ | Saved model, config, and tokenizer |
| Training metrics | {output_dir}/train_results.json | Training loss and throughput |
| Evaluation metrics | {output_dir}/eval_results.json | ROUGE-1, ROUGE-2, ROUGE-L, ROUGE-Lsum, gen_len |
| Test metrics | {output_dir}/test_results.json | ROUGE scores on test set |
| Test generations | {output_dir}/test_generations.txt | Decoded summaries, one per line |
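The values in the metric files are post-processed as described under Metrics. A minimal sketch of that step follows, assuming the ROUGE F-measures have already been computed (here passed as plain floats between 0 and 1) and using toy token IDs rather than real results. format_metrics is a hypothetical helper, and rounding to four decimals is an assumption.

```python
from statistics import mean
from typing import Dict, List


def format_metrics(
    rouge_fmeasures: Dict[str, float],
    generated_ids: List[List[int]],
    pad_token_id: int,
) -> Dict[str, float]:
    """Scale ROUGE F-measures to 0-100 and add the average generation length."""
    result = {key: value * 100 for key, value in rouge_fmeasures.items()}
    # Generation length = number of non-pad tokens in each generated sequence.
    result["gen_len"] = mean(sum(tok != pad_token_id for tok in seq) for seq in generated_ids)
    return {key: round(value, 4) for key, value in result.items()}


# Toy inputs: two generated sequences of 3 and 2 non-pad tokens (pad ID 0).
print(format_metrics({"rouge1": 0.5, "rouge2": 0.25}, [[5, 6, 7, 0], [8, 9, 0, 0]], pad_token_id=0))
# {'rouge1': 50.0, 'rouge2': 25.0, 'gen_len': 2.5}
```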
Usage Examples
Fine-tune BART on CNN/DailyMail
```bash
python examples/NLU/examples/seq2seq/run_summarization.py \
    --model_name_or_path facebook/bart-large \
    --dataset_name cnn_dailymail \
    --dataset_config_name "3.0.0" \
    --do_train \
    --do_eval \
    --predict_with_generate \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --learning_rate 3e-5 \
    --num_train_epochs 3 \
    --max_source_length 1024 \
    --max_target_length 128 \
    --num_beams 4 \
    --output_dir /tmp/bart_cnn_output
```
Fine-tune T5 on XSum with source prefix
```bash
python examples/NLU/examples/seq2seq/run_summarization.py \
    --model_name_or_path t5-base \
    --dataset_name xsum \
    --do_train \
    --do_eval \
    --do_predict \
    --source_prefix "summarize: " \
    --predict_with_generate \
    --per_device_train_batch_size 8 \
    --output_dir /tmp/t5_xsum_output
```
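The --source_prefix in this example is prepended to each source text before tokenization. The sketch below shows that step together with the missing-prefix warning for T5 models. The helper name and the exact set of checkpoint names that trigger the warning are assumptions.

```python
import logging
from typing import List, Optional

logger = logging.getLogger(__name__)

# Assumed set of checkpoint names that trigger the warning.
T5_CHECKPOINTS = {"t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"}


def add_source_prefix(texts: List[str], model_name: str, source_prefix: Optional[str]) -> List[str]:
    """Prepend the prefix to every source text, warning if a T5 model runs without one."""
    if source_prefix is None and model_name in T5_CHECKPOINTS:
        logger.warning("T5 model without --source_prefix; expected something like 'summarize: '.")
    prefix = source_prefix if source_prefix is not None else ""
    return [prefix + text for text in texts]


print(add_source_prefix(["Storms hit the coast."], "t5-base", "summarize: "))
# ['summarize: Storms hit the coast.']
```

The trailing space in "summarize: " matters, since the prefix is concatenated directly with the source text.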
Related Pages
- Environment:Microsoft_LoRA_NLU_Conda_Environment
- Implementation:Microsoft_LoRA_Run_Translation - Similar seq2seq fine-tuning for translation tasks