Implementation:Microsoft LoRA Legacy Seq2Seq Utils

From Leeroopedia


Overview

Comprehensive utility module for seq2seq fine-tuning providing dataset classes, data collators, custom samplers, loss functions, metric computation (ROUGE, BLEU), and parameter freezing helpers.

Description

utils.py is a utility library in the legacy seq2seq examples directory of the Microsoft LoRA NLU repository. It provides the full data processing and evaluation infrastructure for seq2seq fine-tuning of summarization and translation models. The module contains:

  • Dataset classes: AbstractSeq2SeqDataset (base class with sortish sampling and dynamic batching), Seq2SeqDataset (using prepare_seq2seq_batch), and LegacySeq2SeqDataset (using manual tokenization). All datasets read from paired .source/.target line-indexed text files via linecache.
  • Data collator: Seq2SeqDataCollator that handles tokenization, padding, trimming, and decoder input preparation with support for T5-style and BART-style right-shifting of labels.
  • Custom samplers: SortishSampler (sorts by source length with randomness for efficiency) and DistributedSortishSampler (distributed-aware variant).
  • Loss functions: label_smoothed_nll_loss() adapted from fairseq for label smoothing regularization.
  • Metrics: calculate_rouge() using rouge_scorer with bootstrap aggregation, and calculate_bleu() using sacrebleu.
  • Parameter freezing: freeze_params(), freeze_embeds() (model-type-aware for T5/mT5, FSMT, and BART), assert_all_frozen(), assert_not_all_frozen().
  • I/O utilities: save_json(), load_json(), pickle_load(), pickle_save(), write_txt_file(), save_git_info().
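
The line-indexed file layout described for the dataset classes can be sketched with the standard library alone. This is a minimal illustration, not the module's code: `read_pair` and the sample file contents are hypothetical, and only the `{type_path}.source` / `{type_path}.target` naming and the use of `linecache` come from the description above.

```python
import linecache
import tempfile
from pathlib import Path


def read_pair(data_dir, type_path, index):
    """Return the (source, target) lines for a 0-based example index."""
    src_file = Path(data_dir) / f"{type_path}.source"
    tgt_file = Path(data_dir) / f"{type_path}.target"
    # linecache is 1-indexed, so example 0 lives on line 1.
    source = linecache.getline(str(src_file), index + 1).rstrip("\n")
    target = linecache.getline(str(tgt_file), index + 1).rstrip("\n")
    return source, target


# Build a tiny two-example dataset in a temporary directory (hypothetical data).
tmp_dir = tempfile.mkdtemp()
Path(tmp_dir, "train.source").write_text("first article\nsecond article\n")
Path(tmp_dir, "train.target").write_text("first summary\nsecond summary\n")

pair = read_pair(tmp_dir, "train", 1)
print(pair)  # ('second article', 'second summary')
```

Reading by line number keeps memory use low, because only the requested lines are loaded rather than the whole corpus.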

This module originates from the HuggingFace Transformers legacy examples; the Microsoft LoRA repository bundles a copy of it.

⚠️ DEPRECATED: This file resides in the legacy/ directory and is not actively maintained. Prefer modern equivalents where available.

Usage

Use this module's classes and functions when building seq2seq training pipelines for summarization or translation. Import Seq2SeqDataset and Seq2SeqDataCollator for data loading, build_compute_metrics_fn for evaluation metrics, and freeze_params or freeze_embeds to keep selected components fixed during fine-tuning.

Code Reference

Source Location

| Property | Value |
|---|---|
| File path | examples/NLU/examples/legacy/seq2seq/utils.py |
| Lines | 664 |
| Module | utils (within seq2seq directory) |

Key Classes

| Name | Type | Signature / Description |
|---|---|---|
| AbstractSeq2SeqDataset | class | __init__(self, tokenizer, data_dir, max_source_length, max_target_length, type_path="train", n_obs=None, prefix="", **dataset_kwargs) -- base class with sortish/dynamic sampling |
| Seq2SeqDataset | class | Returns {"tgt_texts": str, "src_texts": str, "id": int} per item; uses prepare_seq2seq_batch in collate_fn |
| LegacySeq2SeqDataset | class | Returns {"input_ids": Tensor, "attention_mask": Tensor, "labels": Tensor} per item; manual tokenization with encode_line |
| Seq2SeqDataCollator | class | __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None) -- callable collator handling T5 and BART decoder input preparation |
| SortishSampler | class | __init__(self, data, batch_size, shuffle=True) -- sorts by source length with randomness for efficient batching |
| DistributedSortishSampler | class | __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True) |
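
The idea behind "sortish" sampling, grouping examples of similar length into the same batch while keeping some randomness, can be sketched as follows. This is a simplified standard-library illustration under stated assumptions, not the module's implementation: `sortish_indices`, `chunk_multiplier`, `seed`, and the sample lengths are all hypothetical.

```python
import random


def sortish_indices(lengths, batch_size, chunk_multiplier=50, seed=0):
    """Order indices so each batch holds examples of similar length."""
    rng = random.Random(seed)
    indices = list(range(len(lengths)))
    rng.shuffle(indices)  # randomness: examples land in random chunks

    # Sort by length (longest first) inside each large chunk only.
    chunk_size = batch_size * chunk_multiplier
    sorted_indices = []
    for start in range(0, len(indices), chunk_size):
        chunk = indices[start:start + chunk_size]
        sorted_indices.extend(sorted(chunk, key=lambda i: lengths[i], reverse=True))

    # Cut into batches and shuffle the batch order.
    batches = [sorted_indices[i:i + batch_size]
               for i in range(0, len(sorted_indices), batch_size)]
    rng.shuffle(batches)
    return [i for batch in batches for i in batch]


lengths = [5, 120, 7, 118, 60, 62, 6, 119]  # hypothetical source lengths
order = sortish_indices(lengths, batch_size=2, chunk_multiplier=4)
print(order)
```

Batches built this way need less padding than fully random batches, because neighbouring examples have similar lengths, while the shuffling still varies the order between epochs.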

Key Functions

| Name | Signature | Description |
|---|---|---|
| label_smoothed_nll_loss | label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100) | Label smoothing loss from fairseq; returns (smoothed_loss, nll_loss) |
| calculate_rouge | calculate_rouge(pred_lns, tgt_lns, use_stemmer=True, rouge_keys=ROUGE_KEYS, return_precision_and_recall=False, bootstrap_aggregation=True, newline_sep=True) | Computes ROUGE-1, ROUGE-2, ROUGE-L, ROUGE-Lsum scores |
| calculate_bleu | calculate_bleu(output_lns, refs_lns, **kwargs) | Computes corpus BLEU score using sacrebleu |
| build_compute_metrics_fn | build_compute_metrics_fn(task_name, tokenizer) | Returns a metrics function: ROUGE for summarization, BLEU for translation |
| trim_batch | trim_batch(input_ids, pad_token_id, attention_mask=None) | Removes columns that are entirely padding |
| freeze_params | freeze_params(model: nn.Module) | Sets requires_grad=False for all parameters |
| freeze_embeds | freeze_embeds(model) | Freezes embedding layers based on model type (T5/mT5, FSMT, BART) |
| assert_all_frozen | assert_all_frozen(model) | Asserts no parameters require gradients |
| use_task_specific_params | use_task_specific_params(model, task) | Updates model config with task-specific params from model.config.task_specific_params |
| lmap | lmap(f, x) | Shorthand for list(map(f, x)) |
| save_json | save_json(content, path, indent=4, **json_dump_kwargs) | Writes JSON to file |
| check_output_dir | check_output_dir(args, expected_items=0) | Validates output directory state before training |
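
To make the trim_batch description concrete, here is a small pure-Python sketch of removing columns that contain only padding. It operates on nested lists rather than tensors, and `trim_padding_columns` is a hypothetical helper written for illustration, not the module's function.

```python
def trim_padding_columns(rows, pad_token_id):
    """Drop every column in which all rows hold pad_token_id."""
    if not rows:
        return rows
    # A column is kept if at least one row has a non-pad token there.
    keep = [any(row[col] != pad_token_id for row in rows)
            for col in range(len(rows[0]))]
    return [[tok for tok, k in zip(row, keep) if k] for row in rows]


batch = [
    [11, 12, 13, 0, 0],
    [21, 22, 0, 0, 0],
]
trimmed = trim_padding_columns(batch, pad_token_id=0)
print(trimmed)  # [[11, 12, 13], [21, 22, 0]]
```

Only the last two columns are removed; the third column survives because the first row still holds a real token there.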

Import Usage

from utils import (
    Seq2SeqDataCollator,
    Seq2SeqDataset,
    assert_all_frozen,
    build_compute_metrics_fn,
    check_output_dir,
    freeze_embeds,
    freeze_params,
    label_smoothed_nll_loss,
    calculate_rouge,
    calculate_bleu,
    lmap,
    save_json,
    use_task_specific_params,
    write_txt_file,
)

I/O Contract

Inputs (Seq2SeqDataset)

| Input | Type | Description |
|---|---|---|
| tokenizer | PreTrainedTokenizer | HuggingFace tokenizer for encoding source and target |
| data_dir | str | Directory containing {type_path}.source and {type_path}.target files |
| max_source_length | int | Maximum source tokenization length |
| max_target_length | int | Maximum target tokenization length |
| type_path | str (default "train") | Split name: "train", "val", or "test" |
| n_obs | Optional[int] | Limit number of observations (None for all) |
| prefix | str | Prefix to prepend to source lines (e.g., "summarize: " for T5) |

Inputs (label_smoothed_nll_loss)

| Input | Type | Description |
|---|---|---|
| lprobs | Tensor | Log probabilities from model output, shape (batch, seq_len, vocab_size) |
| target | Tensor | Target token IDs, shape (batch, seq_len) |
| epsilon | float | Label smoothing factor |
| ignore_index | int (default -100) | Index to ignore in loss computation (padding) |

Outputs

| Output | Type | Description |
|---|---|---|
| Seq2SeqDataset[i] | Dict[str, Union[str, int]] | {"tgt_texts": str, "src_texts": str, "id": int} |
| Seq2SeqDataCollator(batch) | Dict[str, Tensor] | {"input_ids", "attention_mask", "decoder_input_ids", "labels"} |
| label_smoothed_nll_loss() | Tuple[Tensor, Tensor] | (smoothed_loss, nll_loss) |
| calculate_rouge() | Dict[str, float] | {"rouge1": float, "rouge2": float, "rougeL": float, "rougeLsum": float} |
| calculate_bleu() | Dict[str, float] | {"bleu": float} |
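
The decoder_input_ids entry returned by the collator is a right-shifted copy of the labels. The sketch below shows that relationship on plain lists; `shift_right` and the token ids are hypothetical, and the choice of start token is an assumption (BART-style models typically pass a dedicated decoder_start_token_id, while T5-style models use the pad token as the start token).

```python
def shift_right(labels, decoder_start_token_id):
    """Prepend the start token and drop the last label of each row."""
    return [[decoder_start_token_id] + row[:-1] for row in labels]


# Hypothetical token ids: 2 marks end-of-sequence, 1 marks padding.
labels = [[45, 46, 47, 2], [51, 52, 2, 1]]
decoder_input_ids = shift_right(labels, decoder_start_token_id=2)
print(decoder_input_ids)  # [[2, 45, 46, 47], [2, 51, 52, 2]]
```

At each position the decoder therefore sees the previous target token and is trained to predict the current one.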

Usage Examples

Creating a Seq2SeqDataset

from transformers import AutoTokenizer
from utils import Seq2SeqDataset

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

dataset = Seq2SeqDataset(
    tokenizer=tokenizer,
    data_dir="/data/cnn_dm/",
    max_source_length=1024,
    max_target_length=128,
    type_path="train",
    prefix="",
)

print(f"Dataset size: {len(dataset)}")
print(f"First example: {dataset[0]}")

Computing ROUGE Metrics

from utils import calculate_rouge

predictions = ["The cat sat on the mat.", "Dogs are loyal animals."]
references = ["A cat was sitting on a mat.", "Dogs are very loyal pets."]

scores = calculate_rouge(predictions, references)
print(scores)
# Prints a dict with the keys "rouge1", "rouge2", "rougeL", and "rougeLsum";
# the score values depend on the inputs.
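
For intuition about what a ROUGE-1 score measures, the following sketch computes a unigram-overlap F-measure by hand. It is a deliberate simplification, assuming lowercase alphanumeric tokenization with no stemming and no bootstrap aggregation, so its numbers will not match calculate_rouge; `rouge1_f1` is a hypothetical helper.

```python
import re
from collections import Counter


def rouge1_f1(prediction, reference):
    """Unigram-overlap F-measure between two strings (simplified ROUGE-1)."""
    pred_tokens = Counter(re.findall(r"[a-z0-9]+", prediction.lower()))
    ref_tokens = Counter(re.findall(r"[a-z0-9]+", reference.lower()))
    overlap = sum((pred_tokens & ref_tokens).values())  # clipped matches
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_tokens.values())
    recall = overlap / sum(ref_tokens.values())
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1("The cat sat on the mat.", "A cat was sitting on a mat.")
print(round(score, 4))  # 0.4615
```

Three unigrams match (cat, on, mat), giving precision 3/6 and recall 3/7, hence an F-measure of 6/13.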

Using Label Smoothing Loss

import torch
from utils import label_smoothed_nll_loss

# lprobs: (batch_size, seq_len, vocab_size)
lprobs = torch.randn(2, 10, 50265).log_softmax(dim=-1)
target = torch.randint(0, 50265, (2, 10))

loss, nll_loss = label_smoothed_nll_loss(lprobs, target, epsilon=0.1)
print(f"Smoothed loss: {loss.item()}, NLL loss: {nll_loss.item()}")
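
The arithmetic behind label smoothing can be followed in a pure-Python sketch. It assumes the older fairseq formulation, in which the smoothing mass epsilon is spread evenly over the whole vocabulary and losses are summed over tokens; `label_smoothed_nll` and the sample values are hypothetical and work on nested lists rather than tensors.

```python
import math


def label_smoothed_nll(log_probs, targets, epsilon, ignore_index=None):
    """Return (smoothed_loss, nll_loss), both summed over non-ignored tokens."""
    nll_loss = 0.0
    smooth_loss = 0.0
    for token_log_probs, target in zip(log_probs, targets):
        if target == ignore_index:
            continue  # skip padding positions
        nll_loss += -token_log_probs[target]
        smooth_loss += -sum(token_log_probs)
    eps_i = epsilon / len(log_probs[0])  # smoothing mass per vocabulary entry
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss


# Two tokens over a 4-word vocabulary with uniform predictions (hypothetical).
uniform = [math.log(0.25)] * 4
loss, nll = label_smoothed_nll([uniform, uniform], targets=[1, 3], epsilon=0.1)
print(loss, nll)
```

With uniform predictions both values equal 2 ln 4, since the uniform distribution already matches the smoothed part of the target; the two losses diverge as soon as the model becomes confident.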

Freezing Embeddings

from transformers import AutoModelForSeq2SeqLM
from utils import freeze_embeds, assert_all_frozen

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")
freeze_embeds(model)
# Verify that the encoder token embeddings are frozen
assert_all_frozen(model.model.encoder.embed_tokens)
