Implementation:Microsoft LoRA Run QA Beam Search
Overview
run_qa_beam_search.py is a question-answering fine-tuning script specialized for XLNet. It pairs XLNetForQuestionAnswering with the beam-search post-processor postprocess_qa_predictions_with_beam_search.
Description
This script targets XLNet-based question answering, whose prediction head differs from that of standard extractive QA models. Instead of plain start/end logits, XLNetForQuestionAnswering returns (at inference time, when no gold start/end positions are passed) the top-K start log probabilities and their indices, the top-K end log probabilities and their indices (computed conditionally on each candidate start position), and CLS logits used for answerability classification. The beam search post-processor combines these outputs to select the best answer spans.
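The span-combination step can be sketched for a single feature as follows. This is a simplified, hypothetical helper (combine_beam_spans) operating on plain Python lists. It assumes the end beams are stored flattened, so that end candidate j of start candidate i sits at index i * end_n_top + j, and it scores a span by adding its start and end log probabilities. The real post-processor in utils_qa additionally maps token indices back to character offsets, drops tokens outside the context, and uses the CLS logits for the null answer; none of that is modeled here.

```python
from typing import List, Tuple


def combine_beam_spans(
    start_top_log_probs: List[float],  # length start_n_top
    start_top_index: List[int],        # length start_n_top
    end_top_log_probs: List[float],    # length start_n_top * end_n_top (assumed flattened)
    end_top_index: List[int],          # length start_n_top * end_n_top (assumed flattened)
    end_n_top: int,
    max_answer_length: int,
) -> List[Tuple[int, int, float]]:
    """Return (start, end, score) candidates sorted by descending score."""
    candidates = []
    for i, start in enumerate(start_top_index):
        for j in range(end_n_top):
            # Assumed layout: the j-th end beam of the i-th start beam.
            k = i * end_n_top + j
            end = end_top_index[k]
            # Skip inverted spans and spans longer than max_answer_length.
            if end < start or end - start + 1 > max_answer_length:
                continue
            # Log probabilities add, so the joint span score is a sum.
            score = start_top_log_probs[i] + end_top_log_probs[k]
            candidates.append((start, end, score))
    candidates.sort(key=lambda c: c[2], reverse=True)
    return candidates


# Toy beams with start_n_top = 2 and end_n_top = 2 (values are made up).
spans = combine_beam_spans(
    start_top_log_probs=[-0.1, -1.0],
    start_top_index=[5, 9],
    end_top_log_probs=[-0.2, -0.7, -0.3, -0.4],
    end_top_index=[7, 4, 9, 12],
    end_n_top=2,
    max_answer_length=30,
)
print(spans)
```

In the toy input, the candidate (start 5, end 4) is discarded because its end precedes its start, leaving three ranked spans.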
Key differences from run_qa.py:
- Model-specific imports: uses XLNetConfig, XLNetForQuestionAnswering, and XLNetTokenizerFast instead of the Auto classes.
- Training features: include the additional fields is_impossible (float), cls_index (position of the CLS token), and p_mask (a probability mask indicating which tokens can be part of an answer). The p_mask assigns 0.0 to context tokens and 1.0 to non-context and special tokens.
- Validation features: likewise include cls_index and p_mask for beam search decoding.
- Five-element predictions: the model returns (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits).
- Beam search post-processing: uses postprocess_qa_predictions_with_beam_search() from utils_qa, with start_n_top and end_n_top taken from the model config.
- Special token handling: builds sequence_ids from token_type_ids and marks special tokens (according to the special_tokens_mask) with the value 3.
The script also enforces check_min_version("4.4.0") and always tokenizes with padding="max_length" rather than padding dynamically.
Usage
Use this script when you need to:
- Fine-tune XLNet specifically on extractive QA tasks
- Leverage beam search decoding for higher-quality answer extraction
- Handle SQuAD v2 unanswerable questions with XLNet's CLS logits
Code Reference
Source Location
| Property | Value |
|---|---|
| File | examples/NLU/examples/question-answering/run_qa_beam_search.py |
| Lines | 590 |
| Module | run_qa_beam_search |
| Entry Point | main() |
| Dependencies | trainer_qa.QuestionAnsweringTrainer, utils_qa.postprocess_qa_predictions_with_beam_search |
Signature/CLI
```bash
python run_qa_beam_search.py \
  --model_name_or_path MODEL_NAME \
  --dataset_name DATASET_NAME \
  --output_dir OUTPUT_DIR \
  --do_train \
  --do_eval \
  [--dataset_config_name CONFIG] \
  [--train_file TRAIN_FILE] \
  [--validation_file VALIDATION_FILE] \
  [--max_seq_length 384] \
  [--doc_stride 128] \
  [--n_best_size 20] \
  [--max_answer_length 30] \
  [--version_2_with_negative] \
  [--null_score_diff_threshold 0.0]
```
Import
```python
from transformers import (
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    TrainingArguments,
    XLNetConfig,
    XLNetForQuestionAnswering,
    XLNetTokenizerFast,
    default_data_collator,
    set_seed,
)
from trainer_qa import QuestionAnsweringTrainer
from utils_qa import postprocess_qa_predictions_with_beam_search
```
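HfArgumentParser turns dataclass fields into command-line flags. The stdlib-only sketch below imitates that idea with argparse so the data options and their defaults (taken from the signature above) can be inspected without installing transformers. DataArgsSketch and parse_data_args are hypothetical names, not the script's own classes, and the real parser handles many more fields and help strings.

```python
import argparse
from dataclasses import dataclass, fields
from typing import List, Optional


@dataclass
class DataArgsSketch:
    # Defaults mirror the optional flags in the CLI signature.
    max_seq_length: int = 384
    doc_stride: int = 128
    n_best_size: int = 20
    max_answer_length: int = 30
    version_2_with_negative: bool = False
    null_score_diff_threshold: float = 0.0


def parse_data_args(argv: Optional[List[str]] = None) -> DataArgsSketch:
    parser = argparse.ArgumentParser()
    for f in fields(DataArgsSketch):
        flag = f"--{f.name}"
        if f.type is bool:
            # Boolean fields become presence flags, defaulting to False.
            parser.add_argument(flag, action="store_true")
        else:
            parser.add_argument(flag, type=f.type, default=f.default)
    return DataArgsSketch(**vars(parser.parse_args(argv)))


args = parse_data_args(["--doc_stride", "64", "--version_2_with_negative"])
print(args)
```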
I/O Contract
Inputs
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| --model_name_or_path | str | Yes | - | XLNet pretrained model name or path |
| --output_dir | str | Yes | - | Directory for checkpoints and results |
| --dataset_name | str | No | None | Hugging Face dataset name (e.g., squad) |
| --train_file | str | No | None | Custom CSV/JSON training file |
| --validation_file | str | No | None | Custom CSV/JSON validation file |
| --max_seq_length | int | No | 384 | Maximum tokenized sequence length |
| --doc_stride | int | No | 128 | Number of overlapping tokens between consecutive windows when a long context is split |
| --n_best_size | int | No | 20 | Number of n-best predictions for beam search |
| --max_answer_length | int | No | 30 | Maximum answer span length |
| --version_2_with_negative | flag | No | False | Enable unanswerable question support |
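The interaction between --max_seq_length and --doc_stride can be approximated with a short sketch. In Hugging Face fast tokenizers, the stride argument is the number of tokens shared by consecutive overflowing chunks, so each new window advances by the window length minus the stride. Here window_len is a stand-in for the context budget left after the question and special tokens within max_seq_length. context_windows is a hypothetical helper that works on token counts only, and the tokenizer's exact chunk boundaries may differ.

```python
from typing import List, Tuple


def context_windows(num_context_tokens: int, window_len: int, stride: int) -> List[Tuple[int, int]]:
    """Split [0, num_context_tokens) into half-open windows overlapping by `stride` tokens."""
    if not 0 <= stride < window_len:
        raise ValueError("stride must be in [0, window_len)")
    step = window_len - stride  # how far each new window advances
    windows = []
    start = 0
    while True:
        end = min(start + window_len, num_context_tokens)
        windows.append((start, end))
        if end >= num_context_tokens:
            break
        start += step
    return windows


print(context_windows(10, 4, 2))
```

Each window becomes a separate feature, which is why one example can yield several features that the post-processor later merges.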
Outputs
| Output | Location | Description |
|---|---|---|
| Trained model | {output_dir}/ | Saved XLNet model, config, and tokenizer |
| Predictions | {output_dir}/predictions.json | Mapping from example ID to answer text |
| N-best predictions | {output_dir}/nbest_predictions.json | Top-N predictions with beam search scores |
| Null odds | {output_dir}/null_odds.json | CLS logit scores for unanswerable classification |
| Evaluation metrics | {output_dir}/eval_results.json | Exact match and F1 scores |
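The exact match and F1 values in eval_results.json presumably come from the standard SQuAD (or SQuAD v2) metric. As an illustration only, the sketch below reimplements the widely used SQuAD v1.1 answer normalization (lowercasing, stripping punctuation and the articles a/an/the, collapsing whitespace) and its token-overlap F1. SQuAD v2 additionally scores empty predictions for unanswerable questions, which this sketch does not model; all function names are hypothetical.

```python
import re
import string
from collections import Counter
from typing import List

PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, gold_answers: List[str]) -> float:
    return float(any(normalize_answer(prediction) == normalize_answer(g) for g in gold_answers))


def token_f1(prediction: str, gold: str) -> float:
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(gold).split()
    # Multiset intersection counts tokens shared by prediction and gold.
    num_same = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def best_f1(prediction: str, gold_answers: List[str]) -> float:
    # Score against every reference answer and keep the best match.
    return max(token_f1(prediction, g) for g in gold_answers)


print(exact_match("The Cat!", ["cat"]), best_f1("the cat sat", ["a cat sat down", "dog"]))
```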
Usage Examples
Fine-tune XLNet on SQuAD
```bash
python examples/NLU/examples/question-answering/run_qa_beam_search.py \
  --model_name_or_path xlnet-base-cased \
  --dataset_name squad \
  --do_train \
  --do_eval \
  --per_device_train_batch_size 8 \
  --learning_rate 3e-5 \
  --num_train_epochs 2 \
  --max_seq_length 384 \
  --doc_stride 128 \
  --output_dir /tmp/xlnet_squad_output
```
Fine-tune XLNet on SQuAD v2
```bash
python examples/NLU/examples/question-answering/run_qa_beam_search.py \
  --model_name_or_path xlnet-large-cased \
  --dataset_name squad_v2 \
  --do_train \
  --do_eval \
  --version_2_with_negative \
  --per_device_train_batch_size 4 \
  --learning_rate 2e-5 \
  --num_train_epochs 3 \
  --output_dir /tmp/xlnet_squad_v2_output
```
Related Pages
- Environment:Microsoft_LoRA_NLU_Conda_Environment
- Implementation:Microsoft_LoRA_Run_QA - General extractive QA with AutoModel
- Implementation:Microsoft_LoRA_Utils_QA - Post-processing utilities including beam search decoding