
Implementation:Lm sys FastChat Peft Get Peft Model

From Leeroopedia


Knowledge Sources
Domains NLP, Training, Parameter-Efficient Fine-Tuning
Last Updated 2026-02-07 14:00 GMT

Overview

Wrapper around LoraConfig, prepare_model_for_kbit_training(), and get_peft_model() from the PEFT library, as used in FastChat's LoRA training script to inject LoRA adapters into a pretrained causal language model.

Description

In fastchat/train/train_lora.py, LoRA adapter injection occurs in three stages:

  1. A LoraConfig is constructed from the LoraArguments dataclass, specifying rank, alpha, target modules, dropout, bias handling, and task type.
  2. If q_lora=True, prepare_model_for_kbit_training() is called to freeze the base model, cast layer norms to float32, and optionally enable gradient checkpointing. For multi-GPU setups without DDP, the model is also marked as parallelizable.
  3. get_peft_model() wraps the base model with LoRA adapter layers, creating a PeftModel that routes gradients only to the adapter parameters.

After injection, when flash_attn is enabled, norm layers and embedding layers are explicitly cast to the compute dtype for compatibility. When gradient checkpointing is active, the script also calls model.enable_input_require_grads() so that gradients flow through the frozen base model to the LoRA adapters.
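The adapter routing in step 3 reduces to adding a low-rank update to each targeted linear layer. A minimal pure-Python sketch of the LoRA forward pass (toy dimensions, no PEFT or PyTorch involved):

```python
# Toy LoRA forward pass: y = W @ x + (alpha / r) * B @ (A @ x)
# W is the frozen base weight; only A and B would be trained.
def matvec(m, v):
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]

d, r, alpha = 4, 2, 16          # hidden size, LoRA rank, LoRA alpha
scaling = alpha / r             # PEFT's effective scaling factor

W = [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]  # frozen base (identity here)
A = [[0.1] * d for _ in range(r)]   # r x d down-projection (trainable)
B = [[0.0] * r for _ in range(d)]   # d x r up-projection (initialized to zero)

x = [1.0, 2.0, 3.0, 4.0]
y = [wx + scaling * bax
     for wx, bax in zip(matvec(W, x), matvec(B, matvec(A, x)))]

# Because B starts at zero, the wrapped layer initially matches the base layer.
assert y == matvec(W, x)
```

This also shows why LoRA injection is loss-preserving at initialization: with B zeroed, the adapter contributes nothing until training updates it.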

Usage

Use this pattern when configuring and applying LoRA adapters to a pretrained model in the FastChat training pipeline.

Code Reference

Source Location

  • Repository: FastChat
  • File: fastchat/train/train_lora.py (lines 147-177, LoRA config and model wrapping)
  • File: fastchat/train/train_lora.py (lines 55-65, LoraArguments dataclass)

Signature

# LoraArguments dataclass (lines 55-65)
import typing
from dataclasses import dataclass, field

@dataclass
class LoraArguments:
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: typing.List[str] = field(
        default_factory=lambda: ["q_proj", "v_proj"]
    )
    lora_weight_path: str = ""
    lora_bias: str = "none"
    q_lora: bool = False


# LoRA injection logic (lines 147-177)
lora_config = LoraConfig(
    r=lora_args.lora_r,
    lora_alpha=lora_args.lora_alpha,
    target_modules=lora_args.lora_target_modules,
    lora_dropout=lora_args.lora_dropout,
    bias=lora_args.lora_bias,
    task_type="CAUSAL_LM",
)

if lora_args.q_lora:
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=training_args.gradient_checkpointing
    )
    if not ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True

model = get_peft_model(model, lora_config)

if training_args.flash_attn:
    for name, module in model.named_modules():
        if "norm" in name:
            module = module.to(compute_dtype)
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module = module.to(compute_dtype)

if training_args.gradient_checkpointing:
    model.enable_input_require_grads()

Import

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

I/O Contract

Inputs

Name Type Required Description
model PreTrainedModel Yes Base causal language model (optionally quantized) to wrap with LoRA adapters
lora_args.lora_r int No Rank of the LoRA decomposition matrices; default: 8
lora_args.lora_alpha int No Scaling factor for LoRA; effective scaling is alpha / r; default: 16
lora_args.lora_target_modules List[str] No Names of linear modules to apply LoRA to; default: ["q_proj", "v_proj"]
lora_args.lora_dropout float No Dropout probability for LoRA layers; default: 0.05
lora_args.lora_bias str No Bias training strategy: "none", "all", or "lora_only"; default: "none"
lora_args.q_lora bool No Whether to prepare the model for k-bit training; default: False
training_args.gradient_checkpointing bool No Enable gradient checkpointing for memory savings; default: False
training_args.flash_attn bool No Enable FlashAttention monkey patch; default: False
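The alpha / r scaling and the adapter size can be checked with quick arithmetic. The sketch below assumes Llama-2-7B shapes (q_proj and v_proj are 4096x4096 across 32 decoder layers); those dimensions are an assumption about the base model, not part of the FastChat script:

```python
# LoRA adds an r x d_in down-projection (A) and a d_out x r up-projection (B)
# per targeted module, so parameters per module = r * (d_in + d_out).
r, alpha = 8, 16
d_in = d_out = 4096                 # q_proj / v_proj shape in Llama-2-7B (assumed)
layers, targets = 32, 2             # 32 decoder layers, ["q_proj", "v_proj"]

params_per_module = r * (d_in + d_out)            # 8 * 8192 = 65,536
total_lora_params = params_per_module * layers * targets
print(total_lora_params)                          # 4,194,304 trainable parameters

effective_scaling = alpha / r                     # 2.0: the LoRA update is scaled by alpha / r
```

Doubling both r and alpha (as in the CLI example below with --lora_r 16 --lora_alpha 32) doubles the adapter size but leaves the effective scaling at 2.0.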

Outputs

Name Type Description
model PeftModel Model wrapped with LoRA adapters; only adapter parameters are trainable

Usage Examples

Basic LoRA Injection

from peft import LoraConfig, get_peft_model

# `model` is assumed to be a pretrained causal LM loaded beforehand,
# e.g. via transformers.AutoModelForCausalLM.from_pretrained(...)
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Example output (Llama-2-7B): trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: ~0.06

QLoRA Injection (with k-bit Preparation)

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# `model` is assumed to be a k-bit quantized causal LM (e.g. loaded with a
# bitsandbytes quantization config). Prepare it for training:
model = prepare_model_for_kbit_training(
    model, use_gradient_checkpointing=True
)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.enable_input_require_grads()
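For intuition, the freezing and norm-casting that prepare_model_for_kbit_training() performs can be approximated with a toy parameter table. This is a simplification; the real PEFT implementation also handles embedding dtypes, gradient-checkpointing hooks, and other cases:

```python
# Toy stand-in for a model's parameter registry: name -> metadata.
params = {
    "layers.0.self_attn.q_proj.weight": {"dtype": "int4",    "requires_grad": True},
    "layers.0.input_layernorm.weight":  {"dtype": "float16", "requires_grad": True},
    "lm_head.weight":                   {"dtype": "float16", "requires_grad": True},
}

for name, p in params.items():
    p["requires_grad"] = False          # freeze every base-model parameter
    if "norm" in name:
        p["dtype"] = "float32"          # upcast norms for numerical stability

assert not any(p["requires_grad"] for p in params.values())
assert params["layers.0.input_layernorm.weight"]["dtype"] == "float32"
```

After this preparation, get_peft_model() reintroduces the only trainable parameters: the LoRA adapter matrices.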

CLI Invocation with Custom LoRA Parameters

deepspeed fastchat/train/train_lora.py \
    --model_name_or_path meta-llama/Llama-2-13b-hf \
    --data_path data/dummy_conversation.json \
    --bf16 True \
    --output_dir output_lora \
    --lora_r 16 \
    --lora_alpha 32 \
    --lora_dropout 0.1 \
    --deepspeed playground/deepspeed_config_s2.json

Related Pages

Implements Principle
