

Implementation:Microsoft LoRA QA Trainer

From Leeroopedia



Overview

The trainer_qa.py module provides a QuestionAnsweringTrainer class that extends the HuggingFace Trainer with custom evaluation and prediction methods tailored for extractive question answering tasks.

Description

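Extractive question answering needs a post-processing step: the model's raw logits (start/end position scores) must be mapped back to text spans in the original context, and because a long context may have been split into several overlapping features, the predictions from all features of one example have to be combined. The standard Trainer.evaluate() passes raw model outputs straight to compute_metrics and has no hook for this two-stage pipeline. QuestionAnsweringTrainer adds one through its constructor and by overriding evaluate() and predict():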

Constructor:

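  • Accepts two keyword-only arguments, eval_examples (the original QA examples before feature creation) and post_process_function (a callable that converts raw predictions back to answer spans), stores them on the instance, and forwards all other arguments to Trainer.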
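The constructor pattern can be sketched without transformers installed. BaseTrainer below is a hypothetical stand-in for the HuggingFace Trainer (it is not its real API) so the sketch runs on its own; the point is that the two QA-specific keyword arguments are captured before everything else is forwarded to the parent.

```python
from typing import Any, Callable, Optional


class BaseTrainer:
    """Hypothetical stand-in for transformers.Trainer (not its real API)."""

    def __init__(self, model: Any = None, eval_dataset: Any = None,
                 compute_metrics: Optional[Callable] = None):
        self.model = model
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics


class QATrainerSketch(BaseTrainer):
    def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
        # The QA extras are keyword-only (they follow *args); all other
        # positional and keyword arguments go to the parent unchanged.
        super().__init__(*args, **kwargs)
        self.eval_examples = eval_examples
        self.post_process_function = post_process_function


trainer = QATrainerSketch(
    model="m",
    eval_dataset=["feature"],
    eval_examples=["example"],
    post_process_function=lambda examples, features, preds: preds,
)
```

Keeping the extras keyword-only means positional Trainer arguments such as model and args keep their usual order.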

evaluate() method:

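  1. Temporarily sets self.compute_metrics to None so that the prediction loop does not compute metrics on raw logits; the original function is restored in a finally block, even if the loop raises.
  2. Runs the standard prediction_loop() on eval_dataset (defaulting to the dataset given at construction) to collect raw start/end logits.
  3. Restores dataset columns that the Trainer may have hidden because the model's forward() does not accept them, using datasets.Dataset.set_format(); post-processing needs these columns.
  4. If both post_process_function and compute_metrics are set, calls post_process_function(eval_examples, eval_dataset, output.predictions) to convert raw logits to answer spans (eval_examples defaults to the constructor value). Otherwise the metrics dict is left empty.
  5. Calls compute_metrics(eval_preds) on the post-processed predictions to calculate QA metrics (typically exact match and F1) and logs the result.
  6. Prints a PyTorch/XLA metrics report via xm.master_print(met.metrics_report()) when TPU metrics debugging is enabled.
  7. Fires the on_evaluate callback and returns the metrics dict.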
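The control flow of evaluate() can be modeled in a dependency-free sketch. EvalFlowSketch, LoopOutput, and the fake prediction_loop are hypothetical stand-ins, not the real Trainer classes; column restoration, the TPU report, and the callback are omitted. What the sketch does show is the save, hide, and restore of compute_metrics around the loop, followed by post-processing only when both hooks are present.

```python
from typing import Any, Callable, Dict, List, NamedTuple, Optional


class LoopOutput(NamedTuple):
    # Stand-in for the object returned by the Trainer's prediction loop.
    predictions: Any
    label_ids: Any
    metrics: Optional[Dict[str, float]]


class EvalFlowSketch:
    """Hypothetical, dependency-free model of the evaluate() control flow."""

    def __init__(self, eval_dataset, eval_examples,
                 post_process_function=None, compute_metrics=None):
        self.eval_dataset = eval_dataset
        self.eval_examples = eval_examples
        self.post_process_function = post_process_function
        self.compute_metrics = compute_metrics
        self.logged: List[Dict[str, float]] = []
        self.metrics_seen_by_loop: List[Optional[Callable]] = []

    def prediction_loop(self, dataset) -> LoopOutput:
        # Fake loop: records what compute_metrics looked like during the loop
        # and returns the dataset items as stand-in "raw logits".
        self.metrics_seen_by_loop.append(self.compute_metrics)
        return LoopOutput(predictions=list(dataset), label_ids=None, metrics=None)

    def evaluate(self, eval_dataset=None, eval_examples=None) -> Dict[str, float]:
        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        eval_examples = self.eval_examples if eval_examples is None else eval_examples

        # Hide compute_metrics so the loop cannot score raw logits.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        try:
            output = self.prediction_loop(eval_dataset)
        finally:
            # Restore it even if the loop raises.
            self.compute_metrics = compute_metrics

        if self.post_process_function is not None and self.compute_metrics is not None:
            eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
            metrics = self.compute_metrics(eval_preds)
            self.logged.append(metrics)
        else:
            metrics = {}
        return metrics


t = EvalFlowSketch(
    eval_dataset=[1, 2],
    eval_examples=["a"],
    post_process_function=lambda ex, feats, preds: sum(preds),
    compute_metrics=lambda p: {"total": float(p)},
)
print(t.evaluate())  # {'total': 3.0}
```

The try/finally matters because a trainer whose compute_metrics silently stayed None after a failed evaluation would return empty metrics on every later call.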

predict() method:

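  1. Follows the same pattern as evaluate() on a test dataset, except that it returns the raw prediction-loop output unchanged when post_process_function or compute_metrics is missing, and it does not fire the on_evaluate callback.
  2. Otherwise returns a PredictionOutput namedtuple whose predictions and label_ids come from the post-processed EvalPrediction (answer spans and references, not raw logits), together with the computed metrics.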
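The branch that decides what predict() returns can be sketched as a plain function that starts from an already computed loop output. PredictionOutput and EvalPrediction here are local NamedTuple stand-ins mirroring the field names above, and predict_sketch is a hypothetical name; the example data are made up.

```python
from typing import Any, Callable, Dict, NamedTuple, Optional


class PredictionOutput(NamedTuple):
    # Local stand-in mirroring the three fields described above.
    predictions: Any
    label_ids: Any
    metrics: Optional[Dict[str, float]]


class EvalPrediction(NamedTuple):
    # Local stand-in for what post_process_function returns.
    predictions: Any
    label_ids: Any


def predict_sketch(raw_output: PredictionOutput, test_dataset, test_examples,
                   post_process_function: Optional[Callable] = None,
                   compute_metrics: Optional[Callable] = None) -> PredictionOutput:
    # Without both hooks there is nothing to decode: hand back the raw logits.
    if post_process_function is None or compute_metrics is None:
        return raw_output
    eval_preds = post_process_function(test_examples, test_dataset, raw_output.predictions)
    metrics = compute_metrics(eval_preds)
    # predictions/label_ids now hold decoded answers and references.
    return PredictionOutput(predictions=eval_preds.predictions,
                            label_ids=eval_preds.label_ids,
                            metrics=metrics)


raw = PredictionOutput(predictions=[[0.1, 0.9]], label_ids=None, metrics=None)
decode = lambda ex, ds, preds: EvalPrediction(
    predictions=[{"id": "q1", "prediction_text": "Paris"}],
    label_ids=[{"id": "q1", "answers": {"text": ["Paris"], "answer_start": [0]}}],
)
out = predict_sketch(raw, ["feature"], ["example"], decode, lambda p: {"exact_match": 100.0})
print(out.predictions[0]["prediction_text"])  # Paris
```

Callers therefore need to check which form they received: raw logits when a hook is missing, decoded answers otherwise.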

Usage

Use this trainer when:

  • Fine-tuning models for extractive question answering (e.g., SQuAD, SQuAD v2).
  • The evaluation pipeline requires mapping model predictions back to original context spans.
  • Using the HuggingFace run_qa.py example script.

Code Reference

Source Location

examples/NLU/examples/question-answering/trainer_qa.py (104 lines)

Signature

class QuestionAnsweringTrainer(Trainer):
    def __init__(
        self,
        *args,
        eval_examples=None,
        post_process_function=None,
        **kwargs
    ): ...

    def evaluate(
        self,
        eval_dataset=None,
        eval_examples=None,
        ignore_keys=None,
    ) -> dict: ...

    def predict(
        self,
        test_dataset,
        test_examples,
        ignore_keys=None,
    ) -> PredictionOutput: ...

Import / CLI Usage

from trainer_qa import QuestionAnsweringTrainer

trainer = QuestionAnsweringTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    eval_examples=eval_examples,
    tokenizer=tokenizer,
    data_collator=data_collator,
    post_process_function=post_processing_function,
    compute_metrics=compute_metrics,
)

I/O Contract

Inputs

| Input | Type | Description |
|---|---|---|
| eval_examples | Dataset | Original QA examples (with context, question, and answer fields) before feature creation. Optional in evaluate(); defaults to the value passed to the constructor. |
| post_process_function | callable | Maps raw predictions back to answer spans. Signature: post_process_function(examples, features, predictions) -> EvalPrediction |
| eval_dataset | Dataset | Tokenized evaluation features. One example can yield several features, so the column structure differs from the examples. Optional in evaluate(). |
| test_dataset | Dataset | Tokenized test features (required by predict()) |
| test_examples | Dataset | Original test examples (required by predict()) |
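A post_process_function satisfying this contract has to group features by the example they came from and pick the best span across all of them. The sketch below is a simplified, hypothetical version: it assumes each feature carries an example_id and an offset_mapping of character (start, end) pairs with None for non-context tokens, that predictions is a (start_logits, end_logits) pair of per-feature lists, and that the output uses SQuAD-style {"id", "prediction_text"} predictions and {"id", "answers"} references. The post-processing used with run_qa.py is more elaborate (it considers only the n-best start/end candidates and can score a null answer for SQuAD v2); this sketch searches every valid span exhaustively.

```python
from collections import defaultdict
from typing import Any, Dict, List, NamedTuple, Sequence, Tuple


class EvalPrediction(NamedTuple):
    # Local stand-in for transformers.EvalPrediction.
    predictions: Any
    label_ids: Any


def post_process_sketch(examples: List[Dict], features: List[Dict],
                        predictions: Tuple[Sequence[Sequence[float]], Sequence[Sequence[float]]],
                        max_answer_length: int = 30) -> EvalPrediction:
    all_start_logits, all_end_logits = predictions

    # Group feature indices by the example they were cut from (overlapping windows).
    features_per_example = defaultdict(list)
    for i, feat in enumerate(features):
        features_per_example[feat["example_id"]].append(i)

    formatted, references = [], []
    for ex in examples:
        best_score, best_text = float("-inf"), ""
        for i in features_per_example[ex["id"]]:
            offsets = features[i]["offset_mapping"]
            starts, ends = all_start_logits[i], all_end_logits[i]
            for s in range(len(starts)):
                # End must not precede start; span length capped in tokens.
                for e in range(s, min(s + max_answer_length, len(ends))):
                    # Skip tokens outside the context (question, special tokens).
                    if offsets[s] is None or offsets[e] is None:
                        continue
                    score = starts[s] + ends[e]
                    if score > best_score:
                        best_score = score
                        best_text = ex["context"][offsets[s][0]:offsets[e][1]]
        formatted.append({"id": ex["id"], "prediction_text": best_text})
        references.append({"id": ex["id"], "answers": ex["answers"]})
    return EvalPrediction(predictions=formatted, label_ids=references)


examples = [{"id": "q1", "context": "Paris is the capital of France.",
             "answers": {"text": ["France"], "answer_start": [24]}}]
# Two overlapping windows over the same context; index 0 is a non-context token.
features = [
    {"example_id": "q1", "offset_mapping": [None, (0, 5), (6, 8), (9, 12), (13, 20)]},
    {"example_id": "q1", "offset_mapping": [None, (13, 20), (21, 23), (24, 30), (30, 31)]},
]
start = [[5.0, 0.1, 0.0, 0.0, 0.2], [5.0, 0.0, 0.0, 2.0, 0.0]]
end = [[5.0, 0.1, 0.0, 0.0, 0.3], [5.0, 0.0, 0.0, 2.5, 0.1]]
out = post_process_sketch(examples, features, (start, end))
print(out.predictions)  # [{'id': 'q1', 'prediction_text': 'France'}]
```

The high logits on the non-context token 0 are ignored, and the winning span comes from the second window, which is why predictions must be combined across features rather than read from one feature alone.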

Outputs

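| Output | Type | Description |
|---|---|---|
| evaluate() return | dict | Metrics dictionary, e.g. {"exact_match": 80.5, "f1": 88.3} (illustrative values); empty if post_process_function or compute_metrics is not set |
| predict() return | PredictionOutput | Namedtuple with predictions, label_ids, and metrics fields; if post_process_function or compute_metrics is not set, the raw prediction-loop output is returned instead |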
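The exact_match and f1 values are SQuAD-style scores. The sketch below re-implements the SQuAD v1.1 scoring logic in plain Python for illustration: answers are normalized (lowercased, punctuation and the articles a/an/the removed, whitespace collapsed), and each prediction takes its best score over all gold answers. The example script loads a packaged SQuAD metric rather than code like this; compute_metrics_sketch and its list-of-dicts inputs are assumptions, and could be wrapped as lambda p: compute_metrics_sketch(p.predictions, p.label_ids).

```python
import re
import string
from collections import Counter
from typing import Dict, List

PUNCTUATION = set(string.punctuation)


def normalize_answer(s: str) -> str:
    # Lowercase, drop punctuation, drop English articles, collapse whitespace.
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCTUATION)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def f1_score(prediction: str, truth: str) -> float:
    # Token-level overlap between normalized prediction and reference.
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(truth).split()
    common = Counter(pred_tokens) & Counter(truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


def compute_metrics_sketch(predictions: List[Dict], references: List[Dict]) -> Dict[str, float]:
    gold = {r["id"]: r["answers"]["text"] for r in references}
    em = f1 = 0.0
    for p in predictions:
        truths = gold[p["id"]]
        # Best score over all acceptable gold answers.
        em += max(float(normalize_answer(p["prediction_text"]) == normalize_answer(t)) for t in truths)
        f1 += max(f1_score(p["prediction_text"], t) for t in truths)
    n = len(predictions)
    return {"exact_match": 100.0 * em / n, "f1": 100.0 * f1 / n}


preds = [{"id": "q1", "prediction_text": "the Eiffel Tower"},
         {"id": "q2", "prediction_text": "Paris France"}]
refs = [{"id": "q1", "answers": {"text": ["Eiffel Tower"], "answer_start": [0]}},
        {"id": "q2", "answers": {"text": ["Paris"], "answer_start": [0]}}]
print(compute_metrics_sketch(preds, refs))
```

Here "the Eiffel Tower" is an exact match after normalization, while "Paris France" earns partial F1 credit (precision 1/2, recall 1) but no exact match.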

Usage Examples

from transformers import AutoModelForQuestionAnswering, AutoTokenizer, TrainingArguments
from trainer_qa import QuestionAnsweringTrainer

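# bert-base-uncased ships without a trained QA head; train first or load a
# QA fine-tuned checkpoint to get meaningful exact_match/f1 scores.
model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")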

training_args = TrainingArguments(
    output_dir="./qa_results",
    per_device_eval_batch_size=64,
    do_eval=True,
)

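# Assumed to be defined elsewhere (e.g. as in run_qa.py), not shown here:
#   tokenized_eval_dataset / raw_eval_examples, tokenized_test_dataset / raw_test_examples,
#   post_processing_function (maps raw start/end logits to text spans),
#   compute_metrics (calculates exact_match and f1)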

trainer = QuestionAnsweringTrainer(
    model=model,
    args=training_args,
    eval_dataset=tokenized_eval_dataset,
    eval_examples=raw_eval_examples,
    tokenizer=tokenizer,
    post_process_function=post_processing_function,
    compute_metrics=compute_metrics,
)

# Evaluate
metrics = trainer.evaluate()
# e.g. {'exact_match': 80.5, 'f1': 88.3} (illustrative; actual values depend on the fine-tuned model)

# Predict on test set
predictions = trainer.predict(
    test_dataset=tokenized_test_dataset,
    test_examples=raw_test_examples,
)
print(predictions.metrics)
