Implementation:WAInjectBench AdamW Training Loop

From Leeroopedia
Knowledge Sources
Domains Deep_Learning, Optimization
Last Updated 2026-02-14 16:00 GMT

Overview

Concrete training loop implementation using AdamW, CrossEntropyLoss, GradScaler, and cosine LR scheduling, provided by the WAInjectBench train/llava-ft module.

Description

The training loop in train/llava-ft.py (L284-371) combines several PyTorch components:

  • Optimizer: torch.optim.AdamW with configurable LR (default 2e-5) and weight decay (default 0.0)
  • Loss: nn.CrossEntropyLoss() on [B, 2] logits vs long labels
  • AMP: torch.amp.GradScaler + torch.amp.autocast for bf16/fp16
  • LR Schedule: LambdaLR with linear warmup (default 3% of steps) then cosine decay
  • Gradient Clipping: clip_grad_norm_ with max_norm=1.0
  • NaN Fallback: maybe_fallback_to_fp32 disables AMP and applies LR backoff on NaN/Inf detection
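The body of the NaN fallback is not shown in this excerpt; the following is a minimal sketch consistent with the description above, assuming `state` carries `use_amp` and `scaler` fields (the `AmpState` name and its fields are hypothetical, inferred from how `state` is used in the loop below):

```python
from dataclasses import dataclass

@dataclass
class AmpState:
    """Hypothetical stand-in for the loop's AMP state (field names assumed)."""
    use_amp: bool = True
    scaler: object = None

def maybe_fallback_to_fp32(model, optim, state, lr_backoff):
    """Sketch: on NaN/Inf loss, disable AMP and shrink every group's LR."""
    if state.use_amp:
        state.use_amp = False  # subsequent batches run in full fp32
        state.scaler = None    # drop the GradScaler entirely
    for group in optim.param_groups:
        group["lr"] *= lr_backoff  # e.g. lr_backoff=0.5 halves the LR
```

Because the loop `continue`s after calling the fallback, the offending batch is skipped entirely and training resumes in fp32 at the reduced learning rate.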

Usage

Executed as the main training loop after model initialization, LoRA injection, and device placement.

Code Reference

Source Location

train/llava-ft.py (L284-371)

Signature

# Optimizer setup (L284)
optim = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

# LR scheduler with cosine warmup (L302-308)
def lr_lambda(current_step):
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda)
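The schedule above can be sanity-checked numerically without a model. The sketch below wraps the same `lr_lambda` in a factory (`make_lr_lambda` is a convenience name, not from the source) and evaluates it at the boundary steps for a 1000-step run with the default 3% warmup:

```python
import math

def make_lr_lambda(warmup_steps, total_steps):
    # Same shape as the scheduler above: linear warmup, then cosine decay to 0.
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return lr_lambda

f = make_lr_lambda(warmup_steps=30, total_steps=1000)  # 3% of 1000 steps
print(f(0))     # 0.0 -- LR starts at zero
print(f(15))    # 0.5 -- halfway through warmup
print(f(30))    # 1.0 -- peak LR at the end of warmup
print(f(1000))  # 0.0 -- fully decayed
```

The multiplier returned at each step is applied to the base LR (default 2e-5), so the actual peak learning rate equals `args.lr`.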

# Training loop (L334-373)
for epoch in range(1, args.epochs + 1):
    model.train()
    for imgs, labels in train_loader:
        labels = labels.to(run_device)
        optim.zero_grad(set_to_none=True)
        with get_autocast_context(state):
            logits = model(imgs, sys_prompt=SYSTEM_PROMPT)
            loss = criterion(logits.to(labels.device), labels)
        # Handle NaN/Inf
        if torch.isnan(loss) or torch.isinf(loss):
            maybe_fallback_to_fp32(model, optim, state, args.lr_backoff)
            continue
        # AMP-aware backward
        if state.use_amp and state.scaler is not None:
            state.scaler.scale(loss).backward()
            state.scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            state.scaler.step(optim)
            state.scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optim.step()
        lr_scheduler.step()

Import

import torch
import torch.nn as nn
import math

I/O Contract

Inputs

Name Type Required Description
model nn.Module Yes LoRA-wrapped LlavaYesnoToken model on device
train_loader DataLoader Yes Training data batches of (List[PIL.Image], Tensor[long])
lr float No Learning rate (default 2e-5)
weight_decay float No Weight decay (default 0.0)
grad_clip float No Gradient norm clipping (default 1.0)
epochs int No Number of epochs (default 3)
warmup_ratio float No Fraction of steps for warmup (default 0.03)
amp_dtype str No Mixed precision type: "bf16", "fp16", or "fp32" (default "bf16")
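The defaults in the table can be wired up as CLI flags. This is a sketch assuming plain argparse; the flag names are inferred from the `args.*` references in the code above and are not confirmed by the source:

```python
import argparse

# Hypothetical CLI mirroring the input table's defaults (flag names assumed).
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=2e-5)
parser.add_argument("--weight_decay", type=float, default=0.0)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--warmup_ratio", type=float, default=0.03)
parser.add_argument("--amp_dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
args = parser.parse_args([])  # empty argv -> all defaults
print(args.lr, args.epochs, args.amp_dtype)
```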

Outputs

Name Type Description
Updated model nn.Module Model with LoRA weights updated by training
avg_train_loss float Average training loss per epoch
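The `avg_train_loss` output is not computed in the loop excerpt above; a minimal sketch of the usual per-epoch accumulation follows (the batch losses are stand-in values for `loss.item()`, and batches skipped by the NaN fallback should be excluded from the count):

```python
# Per-epoch running average over batch losses.
batch_losses = [0.70, 0.55, 0.42]  # one entry per non-skipped batch
running, n_batches = 0.0, 0
for batch_loss in batch_losses:
    running += batch_loss
    n_batches += 1
avg_train_loss = running / max(1, n_batches)  # max(1, .) guards an all-skipped epoch
print(round(avg_train_loss, 4))  # 0.5567
```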

Usage Examples

Running the Training Loop

# Setup (simplified)
optim = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()
# Loss scaling is only needed for fp16; bf16 shares fp32's exponent range,
# so the scaler can be disabled (its methods then pass values through unchanged).
scaler = torch.amp.GradScaler("cuda", enabled=False)

for epoch in range(1, 4):
    model.train()
    for imgs, labels in train_loader:
        labels = labels.to(device)
        optim.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            logits = model(imgs, sys_prompt="Detect prompt injection.")
            loss = criterion(logits.to(labels.device), labels)
        scaler.scale(loss).backward()
        scaler.unscale_(optim)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optim)
        scaler.update()

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
