
Implementation:Microsoft LoRA GPT2LMModel Config



Overview

This page documents the API for constructing and initializing a GPT-2 language model with LoRA-augmented attention layers. The two key classes are GPT2Config (which holds all architecture and LoRA hyperparameters) and GPT2LMModel (which builds the transformer, injects LoRA via lora.MergedLinear, and loads pretrained weights).

Type

API Doc

Source

  • examples/NLG/src/model.py (lines 297-449)

Signatures

GPT2Config

class GPT2Config(object):
    def __init__(
        self,
        vocab_size_or_config_json_file=50257,
        n_positions=1024,
        n_ctx=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        lora_attn_dim=0,
        lora_attn_alpha=128,
        lora_dropout=0.0,
        lora_r_dropout=0.0,
        fix_dropout=0.0,
    ):

Parameters:

Parameter                        Type    Default  Description
vocab_size_or_config_json_file   int     50257    GPT-2 BPE vocabulary size
n_positions                      int     1024     Maximum sequence length for positional embeddings
n_ctx                            int     1024     Context window for the causal attention mask
n_embd                           int     768      Hidden embedding dimension
n_layer                          int     12       Number of transformer blocks
n_head                           int     12       Number of attention heads
layer_norm_epsilon               float   1e-5     Layer norm epsilon
initializer_range                float   0.02     Std dev for weight initialization
lora_attn_dim                    int     0        LoRA rank r (0 = no LoRA)
lora_attn_alpha                  int     128      LoRA scaling alpha
lora_dropout                     float   0.0      LoRA input dropout
lora_r_dropout                   float   0.0      Additional LoRA dropout
fix_dropout                      float   0.0      Fixed dropout rate
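
For illustration, a minimal sketch of building two configurations, one plain GPT-2 Small and one with LoRA enabled (argument values are taken from the table above; the import path assumes the examples/NLG/src layout):

from model import GPT2Config

# Plain GPT-2 Small: lora_attn_dim=0 leaves the attention projections unchanged
base_config = GPT2Config()

# Same architecture with LoRA rank 4 on attention; loralib scales the LoRA update
# by lora_attn_alpha / lora_attn_dim
lora_config = GPT2Config(
    lora_attn_dim=4,
    lora_attn_alpha=32,
    lora_dropout=0.1,
)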

GPT2LMModel

class GPT2LMModel(nn.Module):
    def __init__(self, config):

The constructor creates (a simplified sketch follows this list):

  • self.transformer = GPT2Model(config) -- The transformer backbone with LoRA-injected attention.
  • self.lm_head = GPT2LMHead(self.transformer.wte.weight, config) -- Tied LM head.
  • Applies self._init_weights to all modules (normal distribution with std=0.02).
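
The tying and initialization pattern can be sketched as follows (a simplified paraphrase, not the exact source; GPT2Model and GPT2LMHead come from the same model.py):

import torch.nn as nn
from model import GPT2Model, GPT2LMHead

class GPT2LMModelSketch(nn.Module):
    """Illustrative skeleton of the constructor; see GPT2LMModel in model.py for the real code."""
    def __init__(self, config):
        super().__init__()
        self.transformer = GPT2Model(config)  # backbone with LoRA-injected c_attn projections
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)  # shares the wte weight matrix
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Linear and Embedding weights drawn from N(0, 0.02^2); Linear biases zeroed
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()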

Key method -- load_weight:

def load_weight(self, state_dict):

Loads a pretrained checkpoint into the model, handling key remapping for compatibility (a condensed sketch follows this list):

  • Keys ending in .g are renamed to .weight.
  • Keys ending in .b are renamed to .bias.
  • Keys ending in .w are renamed to .weight.
  • Keys prefixed with module.transformer. have the prefix stripped.
  • Any parameters in the model not found in the checkpoint (e.g., LoRA parameters) are initialized from the model's current state.
  • After loading, calls self.set_tied() to ensure embedding weight tying.
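
A condensed sketch of that remapping (illustrative only; the real method is in model.py and also handles a few checkpoint-format details not shown here):

def load_weight_sketch(model, state_dict):
    """Hedged paraphrase of GPT2LMModel.load_weight; consult the source for the exact logic."""
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        # Strip DataParallel-style prefixes
        if new_key.startswith("module.transformer."):
            new_key = new_key[len("module.transformer."):]
        # Legacy suffixes: .g / .w -> .weight, .b -> .bias
        if new_key.endswith(".g") or new_key.endswith(".w"):
            new_key = new_key[:-2] + ".weight"
        elif new_key.endswith(".b"):
            new_key = new_key[:-2] + ".bias"
        remapped[new_key] = value

    # Parameters missing from the checkpoint (e.g. lora_A / lora_B) keep the model's
    # freshly initialized values
    for name, param in model.transformer.state_dict().items():
        remapped.setdefault(name, param)

    model.transformer.load_state_dict(remapped, strict=False)
    model.set_tied()  # re-tie lm_head to the token embedding after loading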

Key method -- forward:

def forward(
    self, input_ids, lm_labels=None, lm_mask=None,
    past=None, len_past=None, label_smooth=0.0,
    is_report_accuracy=False
):

Returns (lm_logits, loss) when lm_labels is provided, or (lm_logits, presents) for generation. Supports label smoothing via the label_smooth parameter.
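
A hedged usage sketch of the two calling modes (shapes are illustrative; check the source for the exact return handling):

import torch
from model import GPT2Config, GPT2LMModel

config = GPT2Config(lora_attn_dim=4, lora_attn_alpha=32)
model = GPT2LMModel(config)
input_ids = torch.randint(0, 50257, (4, 128))  # (batch, seq_len) of token ids

# Training mode: passing lm_labels returns (lm_logits, loss);
# label_smooth > 0 applies label smoothing inside the loss
lm_logits, loss = model(input_ids, lm_labels=input_ids.clone(), label_smooth=0.1)

# Generation mode: without lm_labels the model returns (lm_logits, presents),
# where presents caches key/value states for reuse via past / len_past
lm_logits, presents = model(input_ids)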

Input / Output

Direction Description
Input Configuration parameters (via GPT2Config) + pretrained checkpoint file (pytorch_model.bin)
Output A GPT2LMModel instance with lora.MergedLinear modules in each attention layer's c_attn projection

Example

from model import GPT2Config, GPT2LMModel
import torch

# Configure GPT-2 Medium with LoRA rank 4
config = GPT2Config(
    n_embd=1024, n_layer=24, n_head=16,
    lora_attn_dim=4, lora_attn_alpha=32
)

# Build model and load pretrained weights
model = GPT2LMModel(config)
model.load_weight(torch.load('pretrained_checkpoints/gpt2-medium-pytorch_model.bin'))

# Move to GPU
model = model.cuda()

After construction, each of the 24 transformer blocks contains an Attention module whose c_attn is a lora.MergedLinear with enable_lora=[True, False, True], applying LoRA to the Q and V projections but not K.
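
In a typical fine-tuning run only the injected LoRA matrices are updated. A short sketch using loralib's helper (assuming loralib is installed, as in the repository's setup):

import loralib as lora

# Freeze everything except the lora_A / lora_B parameters before building the optimizer
lora.mark_only_lora_as_trainable(model)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable:,} / {total:,}")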

Metadata

Field Value
Source microsoft/LoRA
Type API Doc
Last Updated 2026-02-10
