Implementation:Microsoft LoRA Run Summarization

From Leeroopedia


Overview

run_summarization.py is a sequence-to-sequence fine-tuning script for text summarization. It builds on AutoModelForSeq2SeqLM and Seq2SeqTrainer, and evaluates with ROUGE metrics computed after NLTK sentence tokenization.

Description

This script fine-tunes encoder-decoder models (e.g., BART, T5, Pegasus) on abstractive summarization datasets. It uses Seq2SeqTrainer and Seq2SeqTrainingArguments, which extend the standard Trainer and TrainingArguments with generation-specific options: predict_with_generate switches evaluation to use model.generate(), and num_beams and max_length are passed to that call.

Key implementation details:

  • NLTK dependency: Downloads the punkt tokenizer at startup (with FileLock for distributed safety and offline mode detection). Uses nltk.sent_tokenize() to split predictions and labels into newline-separated sentences, the format the rougeLsum metric expects.
  • Dataset name mapping: A summarization_name_mapping dictionary maps known dataset names to their (text_column, summary_column) pairs, supporting datasets like cnn_dailymail ("article", "highlights"), xsum ("document", "summary"), samsum ("dialogue", "summary"), and others.
  • T5 source prefix: Warns if running a T5 model without --source_prefix (e.g., "summarize: ").
  • Data collation: Uses DataCollatorForSeq2Seq for dynamic padding with label pad token ID set to -100 when ignore_pad_token_for_loss is enabled.
  • Metrics: Computes ROUGE scores (rouge1, rouge2, rougeL, rougeLsum) using the rouge metric from the datasets library, with stemming enabled. Reports the F-measure scaled to the 0-100 range, plus the average generation length (gen_len).
  • Three-phase pipeline: Supports do_train, do_eval, and do_predict for training, validation, and test set prediction with decoded text output saved to test_generations.txt.
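The sentence-splitting step in the NLTK bullet above can be sketched as follows. This is a minimal illustration, not the script's code: to stay dependency-free it swaps nltk.sent_tokenize() for a naive regex splitter (split_sentences is a hypothetical stand-in), and the function name postprocess_text is an assumption.

```python
import re


def split_sentences(text):
    # Stand-in for nltk.sent_tokenize (assumption): split on whitespace that
    # follows '.', '!' or '?'. The real tokenizer handles abbreviations etc.
    return [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]


def postprocess_text(preds, labels):
    # Strip surrounding whitespace, then put each sentence on its own line,
    # which is the layout rougeLsum uses to find sentence boundaries.
    preds = ["\n".join(split_sentences(p)) for p in preds]
    labels = ["\n".join(split_sentences(l)) for l in labels]
    return preds, labels


preds, labels = postprocess_text(
    [" The cat sat. It purred loudly! "],
    ["A cat sat down. It purred."],
)
print(repr(preds[0]))   # 'The cat sat.\nIt purred loudly!'
print(repr(labels[0]))  # 'A cat sat down.\nIt purred.'
```

Without the newline separators, rougeLsum would treat each summary as a single sentence and report the same value as rougeL.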

Usage

Use this script when you need to:

  • Fine-tune seq2seq models on summarization benchmarks (CNN/DailyMail, XSum, etc.)
  • Train on custom summarization datasets in CSV/JSON format
  • Generate and evaluate summaries using ROUGE metrics with beam search
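For custom CSV/JSON files there is no entry in summarization_name_mapping, so the column names matter. The sketch below shows one plausible resolution order: explicit --text_column / --summary_column first, then the mapping, then the first two columns of the file. The helper resolve_columns is hypothetical, and the positional fallback is an assumption about the script's behavior; only the three mapping entries come from the description above.

```python
# Three entries quoted from the description; the real dictionary has more.
summarization_name_mapping = {
    "cnn_dailymail": ("article", "highlights"),
    "xsum": ("document", "summary"),
    "samsum": ("dialogue", "summary"),
}


def resolve_columns(dataset_name, column_names, text_column=None, summary_column=None):
    """Hypothetical helper: choose the source and target columns."""
    known = summarization_name_mapping.get(dataset_name)
    if text_column is None:
        # Assumed fallback: first column when the dataset is not in the mapping.
        text_column = known[0] if known else column_names[0]
    if summary_column is None:
        # Assumed fallback: second column when the dataset is not in the mapping.
        summary_column = known[1] if known else column_names[1]
    for col in (text_column, summary_column):
        if col not in column_names:
            raise ValueError(f"column {col!r} not found in {column_names}")
    return text_column, summary_column


print(resolve_columns("xsum", ["document", "summary", "id"]))
print(resolve_columns(None, ["body", "abstract"]))
print(resolve_columns(None, ["id", "body", "abstract"], "body", "abstract"))
```

The last call shows why passing both column flags explicitly is the safer choice for custom files whose first two columns are not the text and the summary.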

Code Reference

Source Location

| Property | Value |
|---|---|
| File | examples/NLU/examples/seq2seq/run_summarization.py |
| Lines | 595 |
| Module | run_summarization |
| Entry Point | main() |

Signature/CLI

python run_summarization.py \
    --model_name_or_path MODEL_NAME \
    --dataset_name DATASET_NAME \
    --output_dir OUTPUT_DIR \
    --do_train \
    --do_eval \
    [--do_predict] \
    [--dataset_config_name CONFIG] \
    [--text_column TEXT_COL] \
    [--summary_column SUMMARY_COL] \
    [--source_prefix "summarize: "] \
    [--max_source_length 1024] \
    [--max_target_length 128] \
    [--val_max_target_length 128] \
    [--num_beams 4] \
    [--pad_to_max_length] \
    [--ignore_pad_token_for_loss] \
    [--predict_with_generate]

Import

from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
    set_seed,
)
from datasets import load_dataset, load_metric
import nltk

I/O Contract

Inputs

| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| --model_name_or_path | str | Yes | - | Pretrained seq2seq model (e.g., facebook/bart-large-cnn) |
| --output_dir | str | Yes | - | Directory for checkpoints and results |
| --dataset_name | str | No | None | HuggingFace dataset name (e.g., cnn_dailymail) |
| --text_column | str | No | None | Source text column name (auto-detected from dataset mapping) |
| --summary_column | str | No | None | Summary column name (auto-detected from dataset mapping) |
| --source_prefix | str | No | None | Prefix for source text (e.g., "summarize: " for T5) |
| --max_source_length | int | No | 1024 | Max source tokenized length |
| --max_target_length | int | No | 128 | Max target tokenized length for training |
| --val_max_target_length | int | No | max_target_length | Max target length for eval/predict generation |
| --num_beams | int | No | None | Beam search width for evaluation generation |
| --ignore_pad_token_for_loss | bool | No | True | Replace pad tokens with -100 in labels |
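The effect of --ignore_pad_token_for_loss can be illustrated in plain Python. This is a sketch of the idea, not the script's code: the helper name mask_label_padding and the token IDs are made up for illustration. The value -100 is the default ignore_index of PyTorch's cross-entropy loss, so positions carrying it contribute nothing to the loss.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_label_padding(label_ids, pad_token_id):
    # Replace every padding position in the labels with IGNORE_INDEX so the
    # loss is computed only over real summary tokens.
    return [
        [tok if tok != pad_token_id else IGNORE_INDEX for tok in seq]
        for seq in label_ids
    ]


# Hypothetical label batch padded to length 5 with pad_token_id = 1.
batch = [[42, 17, 2, 1, 1], [8, 2, 1, 1, 1]]
masked = mask_label_padding(batch, pad_token_id=1)
print(masked)  # [[42, 17, 2, -100, -100], [8, 2, -100, -100, -100]]
```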

Outputs

| Output | Location | Description |
|---|---|---|
| Trained model | {output_dir}/ | Saved model, config, and tokenizer |
| Training metrics | {output_dir}/train_results.json | Training loss and throughput |
| Evaluation metrics | {output_dir}/eval_results.json | ROUGE-1, ROUGE-2, ROUGE-L, ROUGE-Lsum, gen_len |
| Test metrics | {output_dir}/test_results.json | ROUGE scores on test set |
| Test generations | {output_dir}/test_generations.txt | Decoded summaries, one per line |
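The reported metric values combine two steps described earlier: ROUGE F-measures are scaled to the 0-100 range, and gen_len is the mean number of non-padding tokens per generated sequence. The sketch below reproduces that arithmetic on made-up inputs; summarize_metrics is a hypothetical helper, the numbers are illustrative rather than measured, and the rounding to four decimals is an assumption.

```python
def summarize_metrics(fmeasures, generated_ids, pad_token_id):
    # Scale each F-measure from [0, 1] to [0, 100].
    result = {name: f * 100 for name, f in fmeasures.items()}
    # gen_len: average count of non-padding tokens per generated sequence.
    lengths = [sum(1 for t in seq if t != pad_token_id) for seq in generated_ids]
    result["gen_len"] = sum(lengths) / len(lengths)
    # Rounding to 4 decimals is an assumption made for readable output.
    return {k: round(v, 4) for k, v in result.items()}


# Illustrative values only, not results from any model.
fmeasures = {"rouge1": 0.4412, "rouge2": 0.2105, "rougeL": 0.3, "rougeLsum": 0.4}
generated = [[0, 5, 6, 2, 1, 1], [0, 5, 2, 1, 1, 1]]  # pad_token_id = 1
metrics = summarize_metrics(fmeasures, generated, pad_token_id=1)
print(metrics["gen_len"])  # 3.5
```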

Usage Examples

Fine-tune BART on CNN/DailyMail

python examples/NLU/examples/seq2seq/run_summarization.py \
    --model_name_or_path facebook/bart-large \
    --dataset_name cnn_dailymail \
    --dataset_config_name "3.0.0" \
    --do_train \
    --do_eval \
    --predict_with_generate \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --learning_rate 3e-5 \
    --num_train_epochs 3 \
    --max_source_length 1024 \
    --max_target_length 128 \
    --num_beams 4 \
    --output_dir /tmp/bart_cnn_output

Fine-tune T5 on XSum with source prefix

python examples/NLU/examples/seq2seq/run_summarization.py \
    --model_name_or_path t5-base \
    --dataset_name xsum \
    --do_train \
    --do_eval \
    --do_predict \
    --source_prefix "summarize: " \
    --predict_with_generate \
    --per_device_train_batch_size 8 \
    --output_dir /tmp/t5_xsum_output
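
What --source_prefix does to the inputs, and when the T5 warning would fire, can be sketched as below. Both helpers are hypothetical, and the set of checkpoint names that triggers the warning is an assumption rather than something stated on this page.

```python
# Assumed set of T5 checkpoints that trigger the missing-prefix warning.
T5_CHECKPOINTS = {"t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"}


def needs_prefix_warning(model_name_or_path, source_prefix):
    # Warn only when a known T5 checkpoint is used without any prefix.
    return source_prefix is None and model_name_or_path in T5_CHECKPOINTS


def add_prefix(texts, source_prefix):
    # A missing prefix behaves like an empty string.
    prefix = source_prefix if source_prefix is not None else ""
    return [prefix + t for t in texts]


print(needs_prefix_warning("t5-base", None))            # True
print(needs_prefix_warning("t5-base", "summarize: "))   # False
print(add_prefix(["The council voted on Monday."], "summarize: "))
```

The trailing space in "summarize: " matters, because the prefix is concatenated directly onto the source text.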
