
Implementation:Microsoft LoRA GPT2 Beam Search

From Leeroopedia


Overview

GPT2_Beam_Search implements beam search decoding for text generation using a LoRA-augmented GPT-2 model. The script loads a fine-tuned LoRA checkpoint, processes test data in batches across multiple GPUs, and writes JSONL output containing predicted token ID sequences for each input.

Type

API Doc

Source

  • examples/NLG/src/gpt2_beam.py (lines 206-392)

CLI Signature

python -m torch.distributed.launch --nproc_per_node=<N> src/gpt2_beam.py \
    --data <test_data> --init_checkpoint <lora_ckpt> \
    --lora_dim <r> --lora_alpha <alpha> \
    --beam <size> --length_penalty <lp> \
    --no_repeat_ngram_size <n> --eval_len <max_len> \
    --model_card gpt2.md

Argument reference:

  • --data (str, default ../data/wikitext-103): path to the BPE-encoded test JSONL
  • --batch_size (int, default 10): batch size per GPU
  • --seq_len (int, default 512): input sequence length (context)
  • --eval_len (int, default 256): maximum generation length
  • --min_length (int, default 0): minimum generation length
  • --model_card (str, default gpt2.sm): model size: gpt2.sm, gpt2.md, or gpt2.lg
  • --init_checkpoint (str, default None): path to the LoRA checkpoint (.pt)
  • --lora_dim (int, default 0): LoRA rank (must match training)
  • --lora_alpha (int, default 128): LoRA scaling alpha
  • --beam (int, default 1): beam search width
  • --length_penalty (float, default 1.0): length normalization exponent
  • --no_repeat_ngram_size (int, default 4): ban repeated n-grams of this size
  • --repetition_penalty (float, default 1.0): token-level repetition penalty
  • --eos_token_id (int, repeatable, default [50256]): EOS token ID(s)
  • --output_file (str, default beam_prediction.jsonl): output filename
  • --work_dir (str, default gpt2_model): working directory for output
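Among these, --length_penalty is the easiest to misread: it is an exponent applied to sequence length when normalizing a beam's score, not an additive penalty. A minimal sketch of that normalization (illustrative only; the script's exact scoring code may differ in detail):

```python
def length_penalized_score(sum_log_prob: float, seq_len: int,
                           length_penalty: float = 1.0) -> float:
    """Normalize a hypothesis's summed log-probability by len ** length_penalty.

    length_penalty = 1.0 gives per-token averaging; values above 1.0 favor
    longer sequences, values below 1.0 favor shorter ones, and 0.0 leaves
    the raw sum unchanged. (Sketch, not the repo's exact formula.)
    """
    return sum_log_prob / (seq_len ** length_penalty)

raw = length_penalized_score(-12.0, 4, 0.0)   # -12.0 (no normalization)
norm = length_penalized_score(-12.0, 4, 1.0)  # -3.0 (mean log-prob per token)
```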

Key Internal Function

beam(model, data_iter, args)

def beam(model, data_iter, args):

The core beam search loop. For each batch:

  1. Extracts _query (context token IDs) and _query_len from the dataset.
  2. Replicates each sample num_beams times for parallel beam processing.
  3. Runs the initial forward pass on the full context to obtain the first token logits and KV-cache (past).
  4. For each subsequent step (up to eval_len):
    • Computes logits for the last generated token using cached KV states.
    • Applies post-processing: no-repeat n-gram blocking, repetition penalty, minimum length enforcement.
    • Computes softmax and log-probabilities.
    • Selects top-B candidates per beam, reorders the KV cache using _reorder_cache.
    • Tracks completed beams via _add_beam_candidate.
  5. After all steps, selects the best sequence per batch item based on length-penalized score.
  6. Gathers results across GPUs via distributed_gather.
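The replicate-score-select cycle above can be sketched in simplified form. The toy logprob_fn below is a stand-in for the GPT-2 forward pass (no KV cache, no distributed gather, no score post-processing); it only illustrates how beams are expanded, ranked, and pruned each step:

```python
import math
from typing import Callable, List, Tuple

def simple_beam_search(
    logprob_fn: Callable[[List[int]], List[float]],
    context: List[int],
    num_beams: int,
    eval_len: int,
    eos_id: int,
) -> List[int]:
    """Minimal beam search sketch: each beam is a (tokens, score) pair."""
    beams: List[Tuple[List[int], float]] = [(list(context), 0.0)]
    for _ in range(eval_len):
        candidates: List[Tuple[List[int], float]] = []
        for tokens, score in beams:
            if tokens[-1] == eos_id:          # finished beam carries over as-is
                candidates.append((tokens, score))
                continue
            logprobs = logprob_fn(tokens)     # stand-in for the model forward
            # Expand: keep the top-num_beams continuations of this beam.
            top = sorted(range(len(logprobs)),
                         key=lambda t: logprobs[t], reverse=True)[:num_beams]
            for t in top:
                candidates.append((tokens + [t], score + logprobs[t]))
        # Prune: keep the num_beams best hypotheses overall (the reorder step).
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
        if all(t[-1] == eos_id for t, _ in beams):
            break
    return beams[0][0]

# Toy 4-token vocabulary: the "model" prefers token 1 early, then EOS (3).
def toy_model(tokens: List[int]) -> List[float]:
    if len(tokens) < 3:
        return [math.log(0.1), math.log(0.7), math.log(0.1), math.log(0.1)]
    return [math.log(0.1), math.log(0.1), math.log(0.1), math.log(0.7)]

print(simple_beam_search(toy_model, [0], num_beams=2, eval_len=5, eos_id=3))
```

The real script differs mainly in that it batches all beams through one forward pass, reuses the KV cache via _reorder_cache instead of re-encoding, and applies the score post-processing described above before selection.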

Supporting Functions

  • _reorder_cache(past, beam_idx) -- Reorders the KV cache tensors to match beam reassignments at each step.
  • _calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len) -- Computes which tokens are banned for each hypothesis to prevent n-gram repetition.
  • _enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty) -- Applies token-level repetition penalty to log-probabilities.
  • _postprocess_next_token_scores(scores, history, cur_len, ...) -- Orchestrates all score modifications (repetition penalty, n-gram blocking, minimum length).
  • _add_beam_candidate(best_score, best_sequence, ...) -- Updates the best hypothesis for each batch item when a beam terminates (hits EOS or reaches max length).
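The n-gram blocking step can be illustrated with a small stand-alone function in the spirit of _calc_banned_ngram_tokens (a sketch for a single hypothesis, not the repo's exact code): any token that would complete an n-gram already present in the history is banned.

```python
from typing import List

def banned_ngram_tokens(prev_tokens: List[int],
                        no_repeat_ngram_size: int) -> List[int]:
    """Return token IDs that would complete an n-gram already in prev_tokens."""
    n = no_repeat_ngram_size
    if n <= 0 or len(prev_tokens) < n:
        return []
    # Map each (n-1)-gram prefix to the set of tokens that followed it.
    seen = {}
    for i in range(len(prev_tokens) - n + 1):
        prefix = tuple(prev_tokens[i : i + n - 1])
        seen.setdefault(prefix, set()).add(prev_tokens[i + n - 1])
    # The current (n-1)-token suffix determines which next tokens are banned.
    suffix = tuple(prev_tokens[len(prev_tokens) - n + 1 :])
    return sorted(seen.get(suffix, set()))

# With no_repeat_ngram_size=2, the bigram "7 8" already occurred,
# so after a trailing 7 the token 8 is banned.
print(banned_ngram_tokens([7, 8, 9, 7], 2))  # [8]
```

In the actual script this runs per hypothesis per step, and the banned IDs have their scores set to -inf before the top-k selection.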

Input / Output

Input:
  • BPE-encoded test JSONL (via FT_Dataset)
  • LoRA checkpoint file (model.<step>.pt)

Output:
  • JSONL file where each line is {"id": <int>, "predict": [<token_ids>...]}
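A downstream consumer can parse this output with standard JSON tooling; a minimal sketch (field names taken from the format above):

```python
import json

def load_predictions(path: str) -> dict:
    """Map each example id to its predicted token-ID list from the beam output JSONL."""
    preds = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rec = json.loads(line)
            preds[rec["id"]] = rec["predict"]
    return preds
```

The token IDs are BPE indices, so rendering text requires decoding them with the same GPT-2 tokenizer used to encode the data.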

Metadata

  • Source: microsoft/LoRA
  • Type: API Doc
  • Last Updated: 2026-02-10
