Implementation:Microsoft LoRA Run QA Beam Search

From Leeroopedia



Overview

run_qa_beam_search.py is a specialized QA fine-tuning script for XLNet that uses beam search decoding with XLNetForQuestionAnswering and the postprocess_qa_predictions_with_beam_search post-processor.

Description

This script is designed for XLNet-based question answering, whose output format differs from that of standard extractive QA models. Instead of plain start/end logits, XLNetForQuestionAnswering returns at inference time the top-K start log probabilities and their indices, the top-K end log probabilities and indices (conditioned on each candidate start position), and CLS logits for answerability classification. Beam search post-processing combines these outputs to select the best answer span.
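The combination step described above can be sketched in plain Python. This is a minimal illustration, not the code in utils_qa: best_span is a hypothetical helper, the numbers are made up for a single feature, and the flattened layout of the end candidates (end_n_top entries per start candidate) is an assumption about how the model output is arranged. The real post-processor additionally maps token indices back to character offsets and aggregates over all features of an example.

```python
from typing import List, Optional, Tuple


def best_span(
    start_top_log_probs: List[float],
    start_top_index: List[int],
    end_top_log_probs: List[float],
    end_top_index: List[int],
    max_answer_length: int = 30,
) -> Optional[Tuple[int, int, float]]:
    """Return (start, end, score) of the best valid span, or None.

    Assumption: the end candidates for the i-th start candidate sit at
    positions i * end_n_top ... (i + 1) * end_n_top - 1 of the end lists.
    """
    start_n_top = len(start_top_index)
    end_n_top = len(end_top_index) // start_n_top
    best = None
    for i in range(start_n_top):
        for j in range(end_n_top):
            flat = i * end_n_top + j
            start, end = start_top_index[i], end_top_index[flat]
            # Skip spans that end before they start or that are too long.
            if end < start or end - start + 1 > max_answer_length:
                continue
            # Joint score: start log prob plus end log prob given that start.
            score = start_top_log_probs[i] + end_top_log_probs[flat]
            if best is None or score > best[2]:
                best = (start, end, score)
    return best


# Toy input: 2 start candidates, each with 2 end candidates.
span = best_span(
    start_top_log_probs=[-0.2, -1.9],
    start_top_index=[5, 12],
    end_top_log_probs=[-0.3, -1.5, -0.1, -2.5],
    end_top_index=[7, 3, 14, 13],
)
```

In the toy input, the pair (5, 3) is discarded because the end precedes the start, and the span from token 5 to token 7 wins because its summed log probability is the highest among the remaining candidates.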

Key differences from run_qa.py:

  • Model-specific imports: Uses XLNetConfig, XLNetForQuestionAnswering, and XLNetTokenizerFast instead of Auto classes.
  • Training features: Includes additional fields: is_impossible (float), cls_index (CLS token position), and p_mask (a mask marking tokens that cannot be part of an answer). The p_mask assigns 0.0 to context tokens and 1.0 to non-context and special tokens.
  • Validation features: Similarly includes cls_index and p_mask for beam search decoding.
  • Five-element predictions: The model returns (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits).
  • Beam search post-processing: Uses postprocess_qa_predictions_with_beam_search() from utils_qa with configurable start_n_top and end_n_top parameters from the model config.
  • Special token handling: Builds sequence_ids from token_type_ids, marking special tokens with value 3 based on the special_tokens_mask.
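The p_mask and special-token handling from the list above can be illustrated with a short sketch. build_p_mask is a hypothetical helper and the token layout in the example is invented; the sketch also assumes that the CLS position is left unmasked (0.0) so it can be selected for unanswerable questions, and that context tokens carry token type 1, neither of which is stated above.

```python
from typing import List


def build_p_mask(
    token_type_ids: List[int],
    special_tokens_mask: List[int],
    cls_index: int,
    context_idx: int = 1,
) -> List[float]:
    """Return 0.0 for tokens that may be in an answer and 1.0 otherwise."""
    # Start from the token type ids and overwrite special tokens with 3,
    # so they can never be mistaken for context tokens.
    sequence_ids = list(token_type_ids)
    for k, is_special in enumerate(special_tokens_mask):
        if is_special:
            sequence_ids[k] = 3
    # Context tokens (and, by assumption, the CLS token) stay unmasked.
    return [
        0.0 if s == context_idx or k == cls_index else 1.0
        for k, s in enumerate(sequence_ids)
    ]


# Invented layout: 2 question tokens, <sep>, 2 context tokens, <sep>, <cls>.
p_mask = build_p_mask(
    token_type_ids=[0, 0, 0, 1, 1, 1, 2],
    special_tokens_mask=[0, 0, 1, 0, 0, 1, 1],
    cls_index=6,
)
```

Only the two context tokens and the CLS position end up with 0.0; the question tokens and both separators are masked with 1.0.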

The script also enforces check_min_version("4.4.0") and always pads to the maximum length (padding="max_length") rather than padding dynamically per batch.

Usage

Use this script when you need to:

  • Fine-tune XLNet specifically on extractive QA tasks
  • Extract answers with beam search over the top-K start and end candidates
  • Handle SQuAD v2 unanswerable questions with XLNet's CLS logits

Code Reference

Source Location

Property | Value
File | examples/NLU/examples/question-answering/run_qa_beam_search.py
Lines | 590
Module | run_qa_beam_search
Entry Point | main()
Dependencies | trainer_qa.QuestionAnsweringTrainer, utils_qa.postprocess_qa_predictions_with_beam_search

Signature/CLI

python run_qa_beam_search.py \
    --model_name_or_path MODEL_NAME \
    --dataset_name DATASET_NAME \
    --output_dir OUTPUT_DIR \
    --do_train \
    --do_eval \
    [--dataset_config_name CONFIG] \
    [--train_file TRAIN_FILE] \
    [--validation_file VALIDATION_FILE] \
    [--max_seq_length 384] \
    [--doc_stride 128] \
    [--n_best_size 20] \
    [--max_answer_length 30] \
    [--version_2_with_negative] \
    [--null_score_diff_threshold 0.0]

Import

from transformers import (
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    TrainingArguments,
    XLNetConfig,
    XLNetForQuestionAnswering,
    XLNetTokenizerFast,
    default_data_collator,
    set_seed,
)
from trainer_qa import QuestionAnsweringTrainer
from utils_qa import postprocess_qa_predictions_with_beam_search

I/O Contract

Inputs

Parameter | Type | Required | Default | Description
--model_name_or_path | str | Yes | - | XLNet pretrained model name or path
--output_dir | str | Yes | - | Directory for checkpoints and results
--dataset_name | str | No | None | HuggingFace dataset name (e.g., squad)
--train_file | str | No | None | Custom CSV/JSON training file
--validation_file | str | No | None | Custom CSV/JSON validation file
--max_seq_length | int | No | 384 | Max tokenized sequence length
--doc_stride | int | No | 128 | Stride for sliding window over long documents
--n_best_size | int | No | 20 | Number of n-best predictions for beam search
--max_answer_length | int | No | 30 | Maximum answer span length
--version_2_with_negative | flag | No | False | Enable unanswerable question support
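To make the interaction of --max_seq_length and --doc_stride concrete, the following sketch computes the token windows for a long document. window_spans is a hypothetical helper, not part of the script. It assumes that doc_stride is the number of tokens shared by consecutive windows (as with the stride argument of Hugging Face fast tokenizers) and it ignores the question and special tokens, which in practice take up part of each window.

```python
from typing import List, Tuple


def window_spans(n_tokens: int, max_len: int, doc_stride: int) -> List[Tuple[int, int]]:
    """Return [start, end) token spans; neighbours overlap by doc_stride tokens."""
    if not 0 <= doc_stride < max_len:
        raise ValueError("doc_stride must be non-negative and smaller than max_len")
    spans = []
    start = 0
    while True:
        end = min(start + max_len, n_tokens)
        spans.append((start, end))
        if end == n_tokens:
            break
        # Advance so that the next window re-reads the last doc_stride tokens.
        start += max_len - doc_stride
    return spans


# A 1000-token context with the default values from the table above.
spans = window_spans(n_tokens=1000, max_len=384, doc_stride=128)
```

Under these assumptions the 1000-token context is covered by four windows, each starting 256 tokens after the previous one, so an answer near a window boundary still appears whole in at least one window.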

Outputs

Output | Location | Description
Trained model | {output_dir}/ | Saved XLNet model, config, and tokenizer
Predictions | {output_dir}/predictions.json | Example ID to answer text mapping
N-best predictions | {output_dir}/nbest_predictions.json | Top-N predictions with beam search scores
Null odds | {output_dir}/null_odds.json | CLS logit scores for unanswerable classification (written only with --version_2_with_negative)
Evaluation metrics | {output_dir}/eval_results.json | Exact match and F1 scores
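The JSON outputs can be inspected with a few lines of Python. load_qa_outputs is a hypothetical helper written for this page; it assumes that predictions.json maps example IDs to answer strings, as described above, and that null_odds.json, when present, maps example IDs to a numeric score.

```python
import json
from pathlib import Path
from typing import Dict, Optional, Tuple


def load_qa_outputs(output_dir: str) -> Tuple[Dict[str, str], Optional[Dict[str, float]]]:
    """Load predictions and, if the file exists, the null odds."""
    out = Path(output_dir)
    with open(out / "predictions.json", encoding="utf-8") as f:
        predictions = json.load(f)
    null_odds = None
    null_odds_path = out / "null_odds.json"
    # The null odds file may be absent, e.g. for a SQuAD v1 style run.
    if null_odds_path.exists():
        with open(null_odds_path, encoding="utf-8") as f:
            null_odds = json.load(f)
    return predictions, null_odds
```

Calling load_qa_outputs on the directory passed as --output_dir returns the predictions dictionary together with either the null odds dictionary or None.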

Usage Examples

Fine-tune XLNet on SQuAD

python examples/NLU/examples/question-answering/run_qa_beam_search.py \
    --model_name_or_path xlnet-base-cased \
    --dataset_name squad \
    --do_train \
    --do_eval \
    --per_device_train_batch_size 8 \
    --learning_rate 3e-5 \
    --num_train_epochs 2 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --output_dir /tmp/xlnet_squad_output

Fine-tune XLNet on SQuAD v2

python examples/NLU/examples/question-answering/run_qa_beam_search.py \
    --model_name_or_path xlnet-large-cased \
    --dataset_name squad_v2 \
    --do_train \
    --do_eval \
    --version_2_with_negative \
    --per_device_train_batch_size 4 \
    --learning_rate 2e-5 \
    --num_train_epochs 3 \
    --output_dir /tmp/xlnet_squad_v2_output
