
Implementation:Microsoft LoRA GPT2 Finetuning Script



Overview

GPT2_Finetuning_Script is the main training entry point for fine-tuning GPT-2 with LoRA on NLG datasets. It orchestrates distributed data loading, model construction, optimizer setup, training loops with gradient accumulation, periodic evaluation, and LoRA-only checkpoint saving.

Type

API Doc

Source

  • examples/NLG/src/gpt2_ft.py (lines 171-361)
  • examples/NLG/src/data_utils.py (lines 198-269) -- FT_Dataset
  • examples/NLG/src/optimizer.py (lines 316-328) -- create_adam_optimizer_from_args

CLI Signature

python -m torch.distributed.launch --nproc_per_node=<N> src/gpt2_ft.py \
    --train_data <path> --valid_data <path> \
    --lora_dim <r> --lora_alpha <alpha> \
    --init_checkpoint <ckpt> --model_card gpt2.md \
    --lr <lr> --train_batch_size <bs> --seq_len <len> \
    --max_epoch <epochs> --scheduler linear \
    --warmup_step <steps> --label_smooth 0.1
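
For a concrete starting point, the invocation below adapts the repository's E2E NLG Challenge recipe for GPT-2 medium. Treat the paths and hyperparameter values as illustrative; the repo README documents the exact recipe.

python -m torch.distributed.launch --nproc_per_node=1 src/gpt2_ft.py \
    --train_data ./data/e2e/train.jsonl --valid_data ./data/e2e/valid.jsonl \
    --train_batch_size 8 --valid_batch_size 4 --grad_acc 1 \
    --seq_len 512 --model_card gpt2.md \
    --init_checkpoint ./pretrained_checkpoints/gpt2-medium-pytorch_model.bin \
    --lr 0.0002 --weight_decay 0.01 --scheduler linear --warmup_step 500 \
    --max_epoch 5 --save_interval 1000 \
    --lora_dim 4 --lora_alpha 32 --lora_dropout 0.1 \
    --label_smooth 0.1 --work_dir ./trained_models/GPT2_M/e2e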

Full argument reference:

Argument             Type   Default     Description
--train_data         str    required    Path to BPE-encoded training JSONL
--valid_data         str    required    Path to BPE-encoded validation JSONL
--train_batch_size   int    8           Training batch size per GPU
--valid_batch_size   int    4           Validation batch size per GPU
--grad_acc           int    1           Gradient accumulation steps
--clip               float  0.0         Gradient clipping norm (0 = disabled)
--seq_len            int    512         Maximum sequence length
--model_card         str    gpt2.md     Model size preset: gpt2.sm, gpt2.md, gpt2.lg
--init_checkpoint    str    None        Path to pretrained checkpoint
--fp16               flag   False       Enable FP16 mixed precision via Apex
--lora_dim           int    0           LoRA rank r (0 = no LoRA)
--lora_alpha         int    128         LoRA scaling alpha
--lora_dropout       float  0.0         LoRA dropout probability
--label_smooth       float  0.0         Label smoothing coefficient
--lr                 float  0.00001     Learning rate
--weight_decay       float  0.01        AdamW weight decay
--scheduler          str    linear      LR scheduler: linear, cosine, cycle, constant
--warmup_step        int    0           Number of warmup steps
--max_epoch          int    None        Maximum training epochs
--max_step           int    None        Maximum training steps (auto-computed if None)
--save_interval      int    500         Save checkpoint every N steps
--eval_interval      int    2000        Evaluate every N steps
--obj                str    clm         Training objective: clm (causal LM) or jlm (joint LM)
--work_dir           str    gpt2_model  Output directory for checkpoints and logs
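
The --lora_dim, --lora_alpha, and --lora_dropout flags are threaded into the model config and consumed by the fused attention projection, which the repo implements with loralib's MergedLinear. The sketch below shows that wiring using loralib's public API; the surrounding values and the exact keyword choices (notably fan_in_fan_out) should be checked against the repo's model.py rather than taken as definitive.

import loralib as lora

n_embd = 1024                                    # GPT-2 medium hidden size
lora_dim, lora_alpha, lora_dropout = 4, 32, 0.1  # from --lora_dim / --lora_alpha / --lora_dropout

# Fused QKV projection as a LoRA layer. enable_lora=[True, False, True]
# adapts only the query and value slices of the 3 * n_embd output;
# r = 0 falls back to a plain (frozen) linear projection.
c_attn = lora.MergedLinear(
    n_embd, 3 * n_embd,
    r=lora_dim,
    lora_alpha=lora_alpha,       # LoRA update is scaled by lora_alpha / r
    lora_dropout=lora_dropout,
    enable_lora=[True, False, True],
    fan_in_fan_out=True,         # GPT-2 checkpoints store Conv1D-style weights
    merge_weights=False,
)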

Key Internal Components

FT_Dataset (data_utils.py:198-269)

class FT_Dataset(Dataset):
    def __init__(self, ft_file, batch_size, max_seq_length,
                 max_eval_length=0, joint_lm=False, prefix_len=0, infix_len=0,
                 prefix_cursor=1000000, infix_cursor=2000000):

Reads BPE-encoded JSONL, pads context and completion to max_seq_length, and constructs input/target/mask tensors. Under the CLM objective, the mask is zero over the context region and one over the completion region, so the loss is computed only on the completion.
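
The padding and masking logic amounts to roughly the following. This is a simplified sketch for one BPE-encoded (context, completion) pair; the names and exact padding details are illustrative, not the verbatim data_utils.py code.

import torch

def build_example(context, completion, max_seq_length, pad_id=0):
    # Next-token prediction: inputs are tokens[:-1], targets are tokens[1:].
    tokens = (context + completion)[:max_seq_length]
    input_ids, target_ids = tokens[:-1], tokens[1:]
    # CLM loss mask: 0 over the context region, 1 over the completion region.
    mask = [0.0] * (len(context) - 1) + [1.0] * len(completion)
    mask = mask[:len(target_ids)]
    # Right-pad everything to the fixed length.
    pad = max_seq_length - len(input_ids)
    input_ids = input_ids + [pad_id] * pad
    target_ids = target_ids + [pad_id] * pad
    mask = mask + [0.0] * pad
    return (torch.tensor(input_ids), torch.tensor(target_ids),
            torch.tensor(mask, dtype=torch.float))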

create_adam_optimizer_from_args (optimizer.py:316-328)

def create_adam_optimizer_from_args(model, args, grouped_parameters=None):

Creates an AdamW optimizer configured from the command-line arguments (learning rate, weight decay, and related flags). When --no_decay_bias is set, biases and layer-norm weights are placed in a parameter group with weight decay disabled.
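
The grouping behaves roughly like the sketch below, shown here with torch.optim.AdamW rather than the repository's bundled implementation; the name-matching heuristics are illustrative.

import torch

def create_adamw(model, lr, weight_decay, no_decay_bias=False):
    # Only trainable parameters matter: with LoRA enabled, base weights are frozen.
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    if no_decay_bias:
        # Biases and layer-norm weights (ln_1, ln_2, ln_f in GPT-2 naming)
        # go into a zero-weight-decay group.
        is_no_decay = lambda n: n.endswith('bias') or 'ln_' in n
        groups = [
            {'params': [p for n, p in trainable if not is_no_decay(n)],
             'weight_decay': weight_decay},
            {'params': [p for n, p in trainable if is_no_decay(n)],
             'weight_decay': 0.0},
        ]
        return torch.optim.AdamW(groups, lr=lr)
    return torch.optim.AdamW([p for _, p in trainable],
                             lr=lr, weight_decay=weight_decay)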

train_validate (gpt2_ft.py:171-258)

def train_validate(model, optimizer, scheduler, train_loader, valid_loader,
                   args, train_step=0, epoch=0):

The core training loop. For each batch:

  1. Computes forward pass with label smoothing.
  2. Accumulates gradients (divides loss by args.grad_acc).
  3. Performs optimizer step when accumulation count is reached.
  4. Logs training metrics at log_interval.
  5. Saves LoRA-only checkpoint at save_interval.
  6. Runs validation at eval_interval.
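
The accumulation arithmetic in steps 2 and 3 behaves like the following schematic loop. The model's forward signature and device placement are close approximations, not the verbatim gpt2_ft.py code.

import torch

def training_loop(model, optimizer, scheduler, train_loader, args):
    model.train()
    for step, (input_ids, targets, mask) in enumerate(train_loader, start=1):
        # Forward pass with label smoothing; the repo's LM forward returns
        # (logits, loss) when labels are provided.
        _, loss = model(input_ids.cuda(), lm_labels=targets.cuda(),
                        lm_mask=mask.cuda(), label_smooth=args.label_smooth)
        # Scale so gradients average over the accumulated micro-batches.
        (loss / args.grad_acc).backward()
        if step % args.grad_acc == 0:
            if args.clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()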

Input / Output

Input:

  • BPE-encoded JSONL files for training and validation
  • Pretrained GPT-2 checkpoint (pytorch_model.bin)

Output:

  • LoRA-only checkpoints at each save interval: torch.save({'model_state_dict': lora.lora_state_dict(model)}, model_path)
  • Final epoch checkpoint with full model state dict
  • Checkpoint naming: model.<step>.pt in the work directory
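
Because the interval checkpoints hold only LoRA weights, reloading one means restoring the pretrained checkpoint first and overlaying the adapter second, both non-strictly. A sketch of that pattern follows; build_gpt2_model and the file paths are hypothetical placeholders.

import torch

# Hypothetical constructor: rebuild the model with the training-time LoRA config.
model = build_gpt2_model()

# Base GPT-2 weights; strict=False because the LoRA parameters are absent here.
model.load_state_dict(torch.load('pytorch_model.bin'), strict=False)

# Overlay a LoRA-only checkpoint saved by the training script (model.<step>.pt).
ckpt = torch.load('gpt2_model/model.500.pt')
model.load_state_dict(ckpt['model_state_dict'], strict=False)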

Metadata

  • Source: microsoft/LoRA
  • Type: API Doc
  • Last Updated: 2026-02-10
