Implementation:Microsoft LoRA QA Trainer
Overview
The trainer_qa.py module provides a QuestionAnsweringTrainer class that extends the HuggingFace Trainer with custom evaluation and prediction methods tailored for extractive question answering tasks.
Description
Question answering tasks require special post-processing: raw model logits (start/end position scores) must be mapped back to text spans in the original context, which may have been split into overlapping features. The standard Trainer.evaluate() method does not support this two-stage pipeline. QuestionAnsweringTrainer addresses this by:
Constructor:
- Accepts `eval_examples` (the original QA examples before feature creation) and `post_process_function` (a callable that converts raw predictions back to answer spans).
evaluate() method:
- Temporarily disables `compute_metrics` to prevent premature metric computation during the prediction loop.
- Runs the standard `prediction_loop()` to get raw predictions.
- Restores dataset columns that may have been hidden before the prediction loop (e.g., feature-to-example mapping columns that the model's `forward()` does not accept), using `datasets.Dataset.set_format()`.
- Calls `post_process_function(eval_examples, eval_dataset, output.predictions)` to convert raw logits to answer spans.
- Calls `compute_metrics(eval_preds)` to calculate QA metrics (typically exact match and F1).
- Handles TPU metrics debugging via `xm.master_print(met.metrics_report())` when enabled.
- Fires the `on_evaluate` callback.
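The `evaluate()` control flow listed above can be sketched without any Hugging Face dependency. In this minimal sketch, `_BaseTrainer`, `LoopOutput`, and `QATrainerSketch` are hypothetical stand-ins (not the real `transformers.Trainer` API), `prediction_loop()` returns canned logits, and the `set_format()`, TPU-debug, and callback steps are omitted; only the ordering (disable metrics, run the loop, restore, post-process, score) is meant to mirror the description.

```python
from collections import namedtuple

# Hypothetical stand-in for the object returned by Trainer.prediction_loop().
LoopOutput = namedtuple("LoopOutput", ["predictions", "label_ids", "metrics"])


class _BaseTrainer:
    """Toy stand-in for transformers.Trainer (an assumption, not the real class)."""

    def __init__(self, eval_dataset=None, compute_metrics=None):
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.metrics_seen_in_loop = []  # what compute_metrics looked like inside the loop

    def prediction_loop(self, dataset):
        # A real loop would call compute_metrics itself if it were set;
        # here we only record its value and return fake per-feature logits.
        self.metrics_seen_in_loop.append(self.compute_metrics)
        return LoopOutput(predictions=[[0.1, 0.9]] * len(dataset), label_ids=None, metrics={})


class QATrainerSketch(_BaseTrainer):
    def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.eval_examples = eval_examples
        self.post_process_function = post_process_function

    def evaluate(self, eval_dataset=None, eval_examples=None):
        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        eval_examples = self.eval_examples if eval_examples is None else eval_examples

        # 1. Temporarily disable metric computation so the loop only gathers logits.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        try:
            output = self.prediction_loop(eval_dataset)
        finally:
            # 2. Always restore it, even if the loop raises.
            self.compute_metrics = compute_metrics

        # 3. Map feature-level logits back to example-level answers, then score.
        if self.post_process_function is not None and self.compute_metrics is not None:
            eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
            return self.compute_metrics(eval_preds)
        return {}


trainer = QATrainerSketch(
    eval_dataset=["feat-0", "feat-1"],  # two features...
    eval_examples=["example-0"],        # ...derived from one example
    post_process_function=lambda ex, feats, preds: (len(ex), len(feats), len(preds)),
    compute_metrics=lambda p: {"n_examples": p[0], "n_features": p[1], "n_logit_rows": p[2]},
)
print(trainer.evaluate())  # {'n_examples': 1, 'n_features': 2, 'n_logit_rows': 2}
```

The `try`/`finally` restores `compute_metrics` even if the loop raises, and `metrics_seen_in_loop` records that the loop itself only ever saw `None`.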
predict() method:
- Follows the same pattern as `evaluate()` but for test datasets.
- Returns a `PredictionOutput` namedtuple with predictions, label IDs, and metrics.
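The post-processing problem from the Description (mapping start/end logits back to a text span when one context was split into overlapping features) can be illustrated with a simplified n-best search. This is a sketch under stated assumptions, not the repository's `post_processing_function`: each feature is assumed to carry an `offset_mapping` of `(char_start, char_end)` pairs with `None` for non-context tokens, the name `best_span` and the `n_best`/`max_answer_length` defaults are illustrative, and SQuAD v2 null answers are not handled.

```python
from typing import List, Optional, Sequence, Tuple

Offset = Optional[Tuple[int, int]]


def best_span(
    context: str,
    features: Sequence[dict],
    n_best: int = 20,
    max_answer_length: int = 30,
) -> str:
    """Return the highest-scoring answer text for one example across all its features."""
    best_score, best_text = float("-inf"), ""
    for feat in features:
        start_logits, end_logits = feat["start_logits"], feat["end_logits"]
        offsets: List[Offset] = feat["offset_mapping"]
        # Consider only the n_best highest-scoring start and end token positions.
        top_starts = sorted(range(len(start_logits)), key=lambda i: start_logits[i], reverse=True)[:n_best]
        top_ends = sorted(range(len(end_logits)), key=lambda i: end_logits[i], reverse=True)[:n_best]
        for s in top_starts:
            for e in top_ends:
                # Skip tokens outside the context (offset None), reversed or overlong spans.
                if offsets[s] is None or offsets[e] is None:
                    continue
                if e < s or e - s + 1 > max_answer_length:
                    continue
                score = start_logits[s] + end_logits[e]
                if score > best_score:
                    best_score = score
                    # Character offsets map the token span back to the original context.
                    best_text = context[offsets[s][0]:offsets[e][1]]
    return best_text


context = "LoRA adapts large models cheaply."
features = [
    {   # first window: [CLS] LoRA adapts large
        "offset_mapping": [None, (0, 4), (5, 11), (12, 17)],
        "start_logits": [5.0, 0.1, 0.2, 2.0],
        "end_logits": [5.0, 0.1, 0.3, 1.0],
    },
    {   # overlapping second window: [CLS] large models cheaply
        "offset_mapping": [None, (12, 17), (18, 24), (25, 32)],
        "start_logits": [0.0, 1.5, 0.5, 0.1],
        "end_logits": [0.0, 0.2, 2.5, 0.4],
    },
]
print(best_span(context, features))  # large models
```

The `[CLS]` position has the highest raw logits in the first window but is skipped because it has no character offsets; the best span wins across windows, which is why all features of an example must be examined together.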
Usage
Use this trainer when:
- Fine-tuning models for extractive question answering (e.g., SQuAD, SQuAD v2).
- The evaluation pipeline requires mapping model predictions back to original context spans.
- Using the HuggingFace `run_qa.py` example script.
Code Reference
Source Location
examples/NLU/examples/question-answering/trainer_qa.py (104 lines)
Signature
```python
from transformers import Trainer
from transformers.trainer_utils import PredictionOutput


class QuestionAnsweringTrainer(Trainer):
    def __init__(
        self,
        *args,
        eval_examples=None,
        post_process_function=None,
        **kwargs
    ): ...

    def evaluate(
        self,
        eval_dataset=None,
        eval_examples=None,
        ignore_keys=None,
    ) -> dict: ...

    def predict(
        self,
        test_dataset,
        test_examples,
        ignore_keys=None,
    ) -> PredictionOutput: ...
```
Import / CLI Usage
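```python
from trainer_qa import QuestionAnsweringTrainer

# model, training_args, the datasets, eval_examples, tokenizer, data_collator,
# post_processing_function and compute_metrics are prepared beforehand (e.g., as in run_qa.py).
trainer = QuestionAnsweringTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    eval_examples=eval_examples,
    tokenizer=tokenizer,
    data_collator=data_collator,
    post_process_function=post_processing_function,
    compute_metrics=compute_metrics,
)
```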
I/O Contract
Inputs
| Input | Type | Description |
|---|---|---|
| `eval_examples` | `Dataset` | Original QA examples (with context, question, and answer fields) before feature creation |
| `post_process_function` | callable | Maps (examples, features, predictions) to answer spans. Signature: `post_process_function(examples, features, predictions) -> EvalPrediction`; `evaluate()` calls it as `post_process_function(eval_examples, eval_dataset, output.predictions)` |
| `eval_dataset` | `Dataset` | Tokenized evaluation features (may have a different column structure than the examples) |
| `test_dataset` | `Dataset` | Tokenized test features |
| `test_examples` | `Dataset` | Original test examples |
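Whatever span-selection logic a `post_process_function` uses, its last step has to package example-level answers as an `EvalPrediction`. The hedged sketch below shows only that formatting step, assuming SQuAD-style records with `id` and `answers` fields and an already-computed `{example_id: answer_text}` mapping; `EvalPredictionLike` and `format_for_squad` are hypothetical names, the former standing in for `transformers.EvalPrediction` so the snippet runs without Transformers installed. The record shapes follow what the Hugging Face `squad` metric expects.

```python
from typing import Dict, List, NamedTuple


class EvalPredictionLike(NamedTuple):
    """Hypothetical stand-in for transformers.EvalPrediction."""
    predictions: List[dict]
    label_ids: List[dict]


def format_for_squad(examples: List[dict], predicted_texts: Dict[str, str]) -> EvalPredictionLike:
    # Predictions: one {"id", "prediction_text"} record per original example;
    # an example with no predicted span gets an empty string (an illustrative choice).
    predictions = [
        {"id": ex["id"], "prediction_text": predicted_texts.get(ex["id"], "")}
        for ex in examples
    ]
    # References: the gold answers, keyed by the same example id.
    references = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return EvalPredictionLike(predictions=predictions, label_ids=references)


examples = [
    {"id": "q1", "answers": {"text": ["large models"], "answer_start": [12]}},
    {"id": "q2", "answers": {"text": ["cheaply"], "answer_start": [25]}},
]
eval_preds = format_for_squad(examples, {"q1": "large models"})
print(eval_preds.predictions)
# [{'id': 'q1', 'prediction_text': 'large models'}, {'id': 'q2', 'prediction_text': ''}]
```

Keying both lists by the example `id`, not the feature index, is what lets `compute_metrics` score one answer per original question regardless of how many features it was split into.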
Outputs
| Output | Type | Description |
|---|---|---|
| `evaluate()` return | `dict` | Metrics dictionary (e.g., `{"exact_match": 80.5, "f1": 88.3}`) |
| `predict()` return | `PredictionOutput` | Namedtuple with `predictions`, `label_ids`, and `metrics` fields |
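For SQuAD, the `exact_match` and `f1` entries in the metrics dictionary are answer-normalized scores, taking the best match over the gold answers, averaged over examples and reported on a 0-100 scale. The example script typically delegates this to a packaged `squad` metric; the self-contained sketch below reimplements the SQuAD v1-style scoring for illustration, assuming the `prediction_text` / `answers["text"]` record format. The function names are hypothetical.

```python
import re
import string
from collections import Counter
from typing import Dict, List

_PUNCT = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """SQuAD-style normalization: lowercase, drop punctuation and articles, squash spaces."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def f1_single(prediction: str, truth: str) -> float:
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(truth).split()
    # Token overlap counts multiplicity via Counter intersection.
    overlap = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


def squad_scores(predictions: List[dict], references: List[dict]) -> Dict[str, float]:
    gold = {ref["id"]: ref["answers"]["text"] for ref in references}
    em_total = f1_total = 0.0
    for pred in predictions:
        truths = gold[pred["id"]]
        # Score against every gold answer and keep the best match.
        em_total += max(float(normalize_answer(pred["prediction_text"]) == normalize_answer(t)) for t in truths)
        f1_total += max(f1_single(pred["prediction_text"], t) for t in truths)
    n = len(predictions)
    return {"exact_match": 100.0 * em_total / n, "f1": 100.0 * f1_total / n}


preds = [
    {"id": "q1", "prediction_text": "The large models"},  # article is ignored -> exact match
    {"id": "q2", "prediction_text": "cheaply adapts"},    # partial overlap -> F1 only
]
refs = [
    {"id": "q1", "answers": {"text": ["large models"], "answer_start": [12]}},
    {"id": "q2", "answers": {"text": ["cheaply"], "answer_start": [25]}},
]
scores = squad_scores(preds, refs)
print(scores)
```

Exact match is all-or-nothing after normalization, while F1 gives partial credit for token overlap, which is why F1 is usually the higher of the two numbers.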
Usage Examples
```python
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, TrainingArguments
from trainer_qa import QuestionAnsweringTrainer

# The QA head on top of bert-base-uncased is newly initialized;
# fine-tune before expecting meaningful scores.
model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

training_args = TrainingArguments(
    output_dir="./qa_results",
    per_device_eval_batch_size=64,
    do_eval=True,
)

# Assumed to be prepared beforehand (e.g., following run_qa.py):
#   raw_eval_examples / raw_test_examples: original examples (context, question, answers)
#   tokenized_eval_dataset / tokenized_test_dataset: tokenized features built from them
#   post_processing_function: maps raw start/end logits to text spans
#   compute_metrics: calculates exact_match and f1
trainer = QuestionAnsweringTrainer(
    model=model,
    args=training_args,
    eval_dataset=tokenized_eval_dataset,
    eval_examples=raw_eval_examples,
    tokenizer=tokenizer,
    post_process_function=post_processing_function,
    compute_metrics=compute_metrics,
)

# Evaluate: returns a dict such as {'exact_match': ..., 'f1': ...}
metrics = trainer.evaluate()

# Predict on the test set: PredictionOutput with predictions, label_ids, metrics
predictions = trainer.predict(
    test_dataset=tokenized_test_dataset,
    test_examples=raw_test_examples,
)
print(predictions.metrics)
```