Implementation: WAInjectBench AdamW Training Loop
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Optimization |
| Last Updated | 2026-02-14 16:00 GMT |
Overview
Concrete training loop implementation using AdamW, CrossEntropyLoss, GradScaler, and cosine LR scheduling, provided by the WAInjectBench train/llava-ft module.
Description
The training loop in train/llava-ft.py (L284-371) combines several PyTorch components:
- Optimizer: `torch.optim.AdamW` with configurable LR (default 2e-5) and weight decay (default 0.0)
- Loss: `nn.CrossEntropyLoss()` on [B, 2] logits vs long labels
- AMP: `torch.amp.GradScaler` + `torch.amp.autocast` for bf16/fp16
- LR Schedule: `LambdaLR` with linear warmup (default 3% of steps) followed by cosine decay
- Gradient Clipping: `clip_grad_norm_` with max_norm=1.0
- NaN Fallback: `maybe_fallback_to_fp32` disables AMP and applies LR backoff on NaN/Inf detection
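The exact fallback logic lives in the repository; the following is only a rough sketch of the pattern. The `AmpState` fields and the `lr_backoff` default are illustrative assumptions, not the actual WAInjectBench code:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class AmpState:
    # Illustrative stand-in for the training state object;
    # field names are assumptions, not the repository's API.
    use_amp: bool = True
    scaler: Optional["torch.amp.GradScaler"] = None


def maybe_fallback_to_fp32(model, optim, state, lr_backoff=0.5):
    """Disable AMP and back off the learning rate after a NaN/Inf loss."""
    if state.use_amp:
        state.use_amp = False
        state.scaler = None
        model.float()  # cast parameters back to fp32
    for group in optim.param_groups:
        group["lr"] *= lr_backoff
```

Once tripped, later NaN checks still apply the LR backoff, so repeated instabilities keep shrinking the step size.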
Usage
Executed as the main training loop after model initialization, LoRA injection, and device placement.
Code Reference
Source Location
- Repository: WAInjectBench
- File: train/llava-ft.py (L284-373)
Signature
```python
# Optimizer setup (L284)
optim = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

# LR scheduler with cosine warmup (L302-308)
def lr_lambda(current_step):
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda)
```
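The shape of `lr_lambda` can be sanity-checked in plain Python; the step counts below (warmup_steps=30, total_steps=1000) are illustrative values, not the repository's defaults:

```python
import math

warmup_steps, total_steps = 30, 1000  # illustrative values

def lr_lambda(current_step):
    # Linear warmup from 0 to 1, then cosine decay from 1 to 0.
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

print(lr_lambda(0))     # 0.0  -> LR starts at zero
print(lr_lambda(30))    # 1.0  -> full LR at end of warmup
print(lr_lambda(1000))  # 0.0  -> decayed to zero at end of training
```

The multiplier is 0 at step 0, rises linearly to 1 over the warmup window, and follows a half-cosine back toward 0 at the final step.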
```python
# Training loop (L334-373)
for epoch in range(1, args.epochs + 1):
    model.train()
    for imgs, labels in train_loader:
        labels = labels.to(run_device)
        optim.zero_grad(set_to_none=True)
        with get_autocast_context(state):
            logits = model(imgs, sys_prompt=SYSTEM_PROMPT)
            loss = criterion(logits.to(labels.device), labels)
        # Handle NaN/Inf
        if torch.isnan(loss) or torch.isinf(loss):
            maybe_fallback_to_fp32(model, optim, state, args.lr_backoff)
            continue
        # AMP-aware backward
        if state.use_amp and state.scaler is not None:
            state.scaler.scale(loss).backward()
            state.scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            state.scaler.step(optim)
            state.scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optim.step()
        lr_scheduler.step()
```
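`get_autocast_context` is referenced in the loop but not shown here. A minimal sketch of such a helper, assuming illustrative `state` fields (`use_amp`, `amp_dtype`, `device_type`) that may not match the repository's actual attributes:

```python
import contextlib

import torch


def get_autocast_context(state):
    """Return an autocast context for the configured precision.

    Falls back to a no-op context when AMP is disabled or fp32 is requested,
    so the training loop can use `with get_autocast_context(state):`
    unconditionally.
    """
    if not state.use_amp or state.amp_dtype == "fp32":
        return contextlib.nullcontext()
    dtype = torch.bfloat16 if state.amp_dtype == "bf16" else torch.float16
    return torch.amp.autocast(state.device_type, dtype=dtype)
```

Returning `contextlib.nullcontext()` for the fp32 path keeps the call site free of precision-specific branching.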
Import
```python
import torch
import torch.nn as nn
import math
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | LoRA-wrapped LlavaYesnoToken model on device |
| train_loader | DataLoader | Yes | Training data batches of (List[PIL.Image], Tensor[long]) |
| lr | float | No | Learning rate (default 2e-5) |
| weight_decay | float | No | Weight decay (default 0.0) |
| grad_clip | float | No | Gradient norm clipping (default 1.0) |
| epochs | int | No | Number of epochs (default 3) |
| warmup_ratio | float | No | Fraction of steps for warmup (default 0.03) |
| amp_dtype | str | No | Mixed precision type: "bf16", "fp16", or "fp32" (default "bf16") |
Outputs
| Name | Type | Description |
|---|---|---|
| Updated model | nn.Module | Model with LoRA weights updated by training |
| avg_train_loss | float | Average training loss per epoch |
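The avg_train_loss output implies per-epoch loss averaging. One hedged sketch of batch-size-weighted accounting (not necessarily how the repository computes it, and skipped batches from the NaN path would simply not be recorded):

```python
def epoch_average(batch_losses, batch_sizes):
    # Weight each batch's loss by its batch size so a ragged final
    # batch does not skew the epoch average.
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    count = sum(batch_sizes)
    return total / max(1, count)
```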
Usage Examples
Running the Training Loop
```python
# Setup (simplified; model, train_loader, and device are assumed to exist)
optim = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()
# GradScaler chiefly matters for fp16; with bf16 it can be constructed with enabled=False
scaler = torch.amp.GradScaler("cuda", enabled=True)

for epoch in range(1, 4):
    model.train()
    for imgs, labels in train_loader:
        labels = labels.to(device)
        optim.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            logits = model(imgs, sys_prompt="Detect prompt injection.")
            loss = criterion(logits.to(labels.device), labels)
        scaler.scale(loss).backward()
        scaler.unscale_(optim)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optim)
        scaler.update()
```
Related Pages
Implements Principle
Requires Environment
Uses Heuristic