Implementation:WAInjectBench try_wrap_lora

From Leeroopedia
Knowledge Sources
Domains Deep_Learning, Parameter_Efficient_Finetuning
Last Updated 2026-02-14 16:00 GMT

Overview

A concrete tool for injecting LoRA adapters into the LLaVA model's sub-modules, provided by WAInjectBench's train/llava-ft module via the PEFT library.

Description

The try_wrap_lora function in train/llava-ft.py wraps model.model (the inner LlavaForConditionalGeneration) with PEFT's get_peft_model. It applies a LoraConfig with the specified rank, alpha, and dropout to a comprehensive list of 13 target modules. After wrapping, it freezes all base parameters and enables only lora_* parameters for gradient computation. The function includes a graceful fallback: if PEFT is not installed, it prints a warning and returns the unmodified model.
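What the injected adapter computes can be sketched in a few lines. This is a hypothetical standalone illustration (not WAInjectBench code): a LoRA-wrapped linear layer adds a low-rank update `(alpha/r) * B @ A` to the frozen base weight `W`, and only `A` and `B` receive gradients.

```python
# Minimal sketch of the LoRA update applied by get_peft_model to each
# target module. Names and shapes here are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 16, 8, 4, 32

W = rng.normal(size=(d_out, d_in))      # frozen base weight
A = rng.normal(size=(r, d_in)) * 0.01   # trainable down-projection (r x d_in)
B = np.zeros((d_out, r))                # trainable up-projection, zero-initialized

def lora_linear(x):
    # Base path plus low-rank update scaled by alpha / r
    return x @ W.T + (alpha / r) * (x @ A.T) @ B.T

x = rng.normal(size=(2, d_in))
# With B zero-initialized, the adapter starts as an exact no-op,
# so the wrapped model initially matches the base model.
assert np.allclose(lora_linear(x), x @ W.T)
```

Zero-initializing `B` is the standard LoRA convention: training starts from the base model's behavior and the adapter learns only the delta.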

Usage

Called when --use_lora flag is set in the LLaVA fine-tuning CLI. Default hyperparameters: lora_r=8, lora_alpha=32, lora_dropout=0.05.
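A hedged sketch of how the CLI flags above might be wired with `argparse`; only `--use_lora` and the three default values are stated in this page, so the exact parser layout is an assumption.

```python
# Hypothetical CLI wiring for the LLaVA fine-tuning script; flag names
# beyond --use_lora and the documented defaults are assumptions.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_lora", action="store_true",
                    help="Wrap model.model with PEFT LoRA adapters")
parser.add_argument("--lora_r", type=int, default=8)
parser.add_argument("--lora_alpha", type=int, default=32)
parser.add_argument("--lora_dropout", type=float, default=0.05)

args = parser.parse_args(["--use_lora"])
# try_wrap_lora is only called when the flag is set:
# if args.use_lora:
#     model = try_wrap_lora(model, args.lora_r, args.lora_alpha, args.lora_dropout)
```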

Code Reference

Source Location

Signature

def try_wrap_lora(model: nn.Module, lora_r: int, lora_alpha: int, lora_dropout: float):
    try:
        from peft import LoraConfig, get_peft_model, TaskType
    except Exception as e:
        print(f"[WARN] peft import failed: {e}. Training without LoRA.")
        return model

    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "fc1", "fc2", "Wqkv", "out_proj", "proj", "dense"
    ]

    cfg = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=target_modules,
    )
    model.model = get_peft_model(model.model, cfg)

    # Freeze all, then enable only lora_* parameters
    for name, param in model.model.named_parameters():
        param.requires_grad = False
    for name, param in model.model.named_parameters():
        if "lora_" in name.lower():
            param.requires_grad = True

    if hasattr(model.model, "enable_input_require_grads"):
        model.model.enable_input_require_grads()

    return model
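The freeze-then-enable pass above can be illustrated with a self-contained mock (the parameter names below are hypothetical, modeled on typical PEFT naming): every parameter is frozen, then only names containing `lora_` are re-enabled.

```python
# Hypothetical illustration of the two-pass freeze/enable logic in
# try_wrap_lora, using a plain dict in place of named_parameters().
params = [
    "base_model.model.layers.0.q_proj.weight",
    "base_model.model.layers.0.q_proj.lora_A.weight",
    "base_model.model.layers.0.q_proj.lora_B.weight",
]

# Pass 1: freeze everything
requires_grad = {name: False for name in params}
# Pass 2: re-enable only LoRA adapter parameters
for name in requires_grad:
    if "lora_" in name.lower():
        requires_grad[name] = True

trainable = [n for n, g in requires_grad.items() if g]
```

The base `q_proj.weight` stays frozen while both adapter matrices become trainable, matching the "only lora_* params trainable" contract below.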

Import

from peft import LoraConfig, get_peft_model, TaskType

I/O Contract

Inputs

Name Type Required Description
model nn.Module Yes LlavaYesnoToken model instance
lora_r int Yes LoRA rank (default 8)
lora_alpha int Yes LoRA scaling factor (default 32)
lora_dropout float Yes LoRA dropout rate (default 0.05)

Outputs

Name Type Description
model nn.Module Model with LoRA adapters injected into model.model; only lora_* params trainable

Usage Examples

Injecting LoRA Adapters

import torch

from llava_yesno_token import LlavaYesnoToken

model = LlavaYesnoToken("llava-hf/llava-1.5-7b-hf", dtype=torch.bfloat16)

# Inject LoRA with default settings
model = try_wrap_lora(model, lora_r=8, lora_alpha=32, lora_dropout=0.05)

# Check trainable parameters
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable} / {total} ({100*trainable/total:.4f}%)")

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
