Implementation: Microsoft LoRA GPT-2 Beam Search
Overview
GPT2_Beam_Search implements beam search decoding for text generation using a LoRA-augmented GPT-2 model. The script loads a fine-tuned LoRA checkpoint, processes test data in batches across multiple GPUs, and writes JSONL output containing predicted token ID sequences for each input.
Type
API Doc
Source
examples/NLG/src/gpt2_beam.py (lines 206-392)
CLI Signature
python -m torch.distributed.launch --nproc_per_node=<N> src/gpt2_beam.py \
--data <test_data> --init_checkpoint <lora_ckpt> \
--lora_dim <r> --lora_alpha <alpha> \
--beam <size> --length_penalty <lp> \
--no_repeat_ngram_size <n> --eval_len <max_len> \
--model_card gpt2.md
Argument reference:

| Argument | Type | Default | Description |
|---|---|---|---|
| `--data` | str | `../data/wikitext-103` | Path to BPE-encoded test JSONL |
| `--batch_size` | int | 10 | Batch size per GPU |
| `--seq_len` | int | 512 | Input sequence length (context) |
| `--eval_len` | int | 256 | Maximum generation length |
| `--min_length` | int | 0 | Minimum generation length |
| `--model_card` | str | `gpt2.sm` | Model size: `gpt2.sm`, `gpt2.md`, `gpt2.lg` |
| `--init_checkpoint` | str | None | Path to LoRA checkpoint (`.pt`) |
| `--lora_dim` | int | 0 | LoRA rank (must match training) |
| `--lora_alpha` | int | 128 | LoRA scaling alpha |
| `--beam` | int | 1 | Beam search width |
| `--length_penalty` | float | 1.0 | Length normalization exponent |
| `--no_repeat_ngram_size` | int | 4 | Ban repeated n-grams of this size |
| `--repetition_penalty` | float | 1.0 | Token-level repetition penalty |
| `--eos_token_id` | int (repeatable) | `[50256]` | EOS token ID(s) |
| `--output_file` | str | `beam_prediction.jsonl` | Output filename |
| `--work_dir` | str | `gpt2_model` | Working directory for output |
Key Internal Function

beam(model, data_iter, args)

The core beam search loop. For each batch:

- Extracts `_query` (context token IDs) and `_query_len` from the dataset.
- Replicates each sample `num_beams` times for parallel beam processing.
- Runs the initial forward pass on the full context to obtain the first-token logits and KV cache (`past`).
- For each subsequent step (up to `eval_len`):
  - Computes logits for the last generated token using cached KV states.
  - Applies post-processing: no-repeat n-gram blocking, repetition penalty, minimum length enforcement.
  - Computes softmax and log-probabilities.
  - Selects top-B candidates per beam and reorders the KV cache using `_reorder_cache`.
  - Tracks completed beams via `_add_beam_candidate`.
- After all steps, selects the best sequence per batch item based on its length-penalized score.
- Gathers results across GPUs via `distributed_gather`.
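The select-top-B-then-extend pattern above can be sketched in miniature. This toy replaces the GPT-2 forward pass with a fixed per-step log-prob table (everything below, including the table values and function name, is illustrative, not the script's code), but the candidate expansion, top-B pruning, and EOS handling mirror the loop's structure:

```python
import math

def beam_search(step_log_probs, num_beams, eos_id, max_len):
    """Toy beam search over a fixed per-step log-prob table.

    step_log_probs: one dict of token_id -> log-prob per step, standing
    in for the model logits the real script computes with a KV cache.
    """
    beams = [([], 0.0, False)]  # (tokens, cumulative log-prob, finished)
    for step in range(max_len):
        candidates = []
        for tokens, score, done in beams:
            if done:
                candidates.append((tokens, score, True))
                continue
            for tok, lp in step_log_probs[step].items():
                candidates.append((tokens + [tok], score + lp, tok == eos_id))
        # Keep the top-B by cumulative log-prob; the real script also
        # reorders the KV cache here to match the surviving beams.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]
    return beams[0][0]

# Two decoding steps over a made-up vocabulary {1, 2, 50256 (EOS)}:
table = [
    {1: math.log(0.6), 2: math.log(0.4)},
    {1: math.log(0.3), 50256: math.log(0.7)},
]
print(beam_search(table, num_beams=2, eos_id=50256, max_len=2))  # [1, 50256]
```

Note the real implementation scores finished hypotheses with the length penalty before comparing them; this sketch ranks by raw cumulative log-prob only.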
Supporting Functions
- _reorder_cache(past, beam_idx) -- Reorders the KV cache tensors to match beam reassignments at each step.
- _calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len) -- Computes which tokens are banned for each hypothesis to prevent n-gram repetition.
- _enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty) -- Applies token-level repetition penalty to log-probabilities.
- _postprocess_next_token_scores(scores, history, cur_len, ...) -- Orchestrates all score modifications (repetition penalty, n-gram blocking, minimum length).
- _add_beam_candidate(best_score, best_sequence, ...) -- Updates the best hypothesis for each batch item when a beam terminates (hits EOS or reaches max length).
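The n-gram blocking idea behind `_calc_banned_ngram_tokens` can be shown for a single hypothesis. This is a simplified stand-alone sketch (the function name and shape below are illustrative; the real helper operates over all hypotheses at once):

```python
def banned_ngram_tokens(prev_tokens, ngram_size, cur_len):
    """Tokens whose generation would repeat an already-seen n-gram."""
    if cur_len + 1 < ngram_size:
        return []  # not enough history to form a full n-gram yet
    # Index every seen n-gram by its (n-1)-token prefix.
    generated = {}
    for i in range(cur_len - ngram_size + 1):
        prefix = tuple(prev_tokens[i:i + ngram_size - 1])
        generated.setdefault(prefix, set()).add(prev_tokens[i + ngram_size - 1])
    # The (n-1)-gram the next token would extend:
    key = tuple(prev_tokens[cur_len - ngram_size + 1:cur_len])
    return sorted(generated.get(key, set()))

# Generating after [5, 6, 7, 5, 6] with trigram blocking: token 7 is
# banned, because emitting it would repeat the trigram (5, 6, 7).
print(banned_ngram_tokens([5, 6, 7, 5, 6], 3, 5))  # [7]
```

In `_postprocess_next_token_scores`, banned tokens have their logits set to negative infinity so the subsequent softmax assigns them zero probability.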
Input / Output

| Direction | Description |
|---|---|
| Input | BPE-encoded test JSONL (via `--data`), one sample per line |
| Output | JSONL file where each line is `{"id": <int>, "predict": [<token_ids>...]}` |
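Because the output is plain JSONL, each line can be parsed with the standard `json` module (the token values below are made up for illustration):

```python
import json

# One hypothetical output line in the documented format:
line = '{"id": 0, "predict": [50256, 318, 257]}'
record = json.loads(line)

# Token IDs can then be decoded back to text with the GPT-2 BPE
# vocabulary, e.g. via a GPT-2 tokenizer (not shown here).
print(record["id"], len(record["predict"]))  # 0 3
```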
Metadata
| Field | Value |
|---|---|
| Source | microsoft/LoRA |
| Type | API Doc |
| Last Updated | 2026-02-10 |