Implementation:Norrrrrrr lyn WAInjectBench JsonlImageDataset
| Knowledge Sources | |
|---|---|
| Domains | Data_Engineering, Computer_Vision, Deep_Learning |
| Last Updated | 2026-02-14 16:00 GMT |
Overview
Concrete PyTorch Dataset class for loading image-label pairs from JSONL files for LLaVA fine-tuning, provided by the WAInjectBench train/llava-ft module.
Description
The JsonlImageDataset class in train/llava-ft.py reads every non-blank line of a JSONL file at initialization and stores the parsed records as a list of dictionaries. Images are loaded lazily: each __getitem__ call opens the file via PIL.Image.open(path).convert("RGB"), applies the optional transform if one was given, and returns an (image, int) tuple. The image is a PIL.Image unless the transform changes its type. The companion collate function converts a batch into (List[PIL.Image], Tensor[long]) format.
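Because each image is opened only inside __getitem__, a missing file or malformed record would surface partway through an epoch rather than at construction. The sketch below shows one way to check a manifest up front. The helper name find_bad_records is hypothetical and not part of WAInjectBench; it only assumes the "path" and "label" keys described on this page.

```python
import json
import os
import tempfile


def find_bad_records(jsonl_path):
    """Return (line_number, reason) pairs for records likely to fail at load time.

    Hypothetical helper for illustration; not part of WAInjectBench.
    """
    problems = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # blank lines are skipped, as in the dataset's __init__
            try:
                rec = json.loads(line)
                path = rec["path"]
                int(rec["label"])  # same cast the dataset applies in __getitem__
            except (KeyError, TypeError, ValueError) as e:
                # json.JSONDecodeError is a ValueError subclass, so it lands here too
                problems.append((lineno, f"{type(e).__name__}: {e}"))
                continue
            if not isinstance(path, str) or not os.path.isfile(path):
                problems.append((lineno, f"missing file: {path}"))
    return problems


# Small demonstration with a throwaway manifest (all file names are made up).
with tempfile.TemporaryDirectory() as tmp:
    existing = os.path.join(tmp, "ok.png")
    open(existing, "wb").close()  # empty placeholder; only existence is checked
    manifest = os.path.join(tmp, "demo.jsonl")
    with open(manifest, "w", encoding="utf-8") as f:
        f.write(json.dumps({"path": existing, "label": 1}) + "\n")
        f.write(json.dumps({"path": os.path.join(tmp, "gone.png"), "label": 0}) + "\n")
        f.write(json.dumps({"path": existing}) + "\n")  # no "label" key
    problems = find_bad_records(manifest)

for lineno, reason in problems:
    print(f"line {lineno}: {reason}")
```

This check only confirms that the file exists; it does not verify that PIL can decode it.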
Usage
The LLaVA fine-tuning script instantiates this class twice: once for the file passed as --train_jsonl and once for the file passed as --val_jsonl.
Code Reference
Source Location
- Repository: WAInjectBench
- File: train/llava-ft.py (L25-44)
Signature
class JsonlImageDataset(Dataset):
    def __init__(self, jsonl_path, transform=None):
        with open(jsonl_path, "r", encoding="utf-8") as f:
            self.items = [json.loads(l) for l in f if l.strip()]
        self.transform = transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        p = self.items[i]["path"]
        y = int(self.items[i]["label"])
        img = Image.open(p).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, y

def collate(batch):
    imgs, labels = zip(*batch)
    return list(imgs), torch.tensor(labels, dtype=torch.long)
Import
import json
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| jsonl_path | str | Yes | Path to a JSONL file with one {"path": str, "label": int} object per non-blank line; the label is cast with int() on access |
| transform | callable | No | Optional image transform (default None) |
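As an illustration of the expected input, the following sketch writes a small manifest in the one-object-per-line layout and reads it back the same way __init__ does. The record paths and labels are made-up placeholders, and write_manifest and read_manifest are hypothetical helper names rather than functions from the repository.

```python
import json
import os
import tempfile

# Hypothetical sample records; paths and label values are placeholders.
records = [
    {"path": "images/sample_000.png", "label": 0},
    {"path": "images/sample_001.png", "label": 1},
]


def write_manifest(records, jsonl_path):
    """Write one JSON object per line (the JSONL layout the dataset reads)."""
    with open(jsonl_path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")


def read_manifest(jsonl_path):
    """Mirror the parsing in __init__: skip blank lines, json.loads the rest."""
    with open(jsonl_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    manifest = os.path.join(tmp, "train_data.jsonl")
    write_manifest(records, manifest)
    loaded = read_manifest(manifest)

print(loaded)
```

Relative paths like these are resolved against the working directory when the image is opened, so absolute paths may be the safer choice in a manifest.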
Outputs
| Name | Type | Description |
|---|---|---|
| __getitem__ returns | Tuple[PIL.Image, int] | (image, label) pair |
| collate returns | Tuple[List[PIL.Image], Tensor] | Batched (images, labels) for model forward pass |
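The transposition that collate performs can be seen without PyTorch or PIL. In this dependency-free sketch, strings stand in for the PIL images and a plain list stands in for the label tensor; both substitutions are assumptions made purely for illustration.

```python
# Stand-in batch: strings replace PIL images so the sketch has no dependencies.
batch = [("img_a", 0), ("img_b", 1), ("img_c", 0)]

# zip(*batch) transposes a list of (image, label) pairs into two tuples.
imgs, labels = zip(*batch)

imgs = list(imgs)      # collate also returns the images as a plain Python list
labels = list(labels)  # the real collate builds torch.tensor(labels, dtype=torch.long) here

print(imgs)
print(labels)
```

The images stay in a list rather than a stacked tensor because PIL images are not tensors and may differ in size, so they cannot be stacked directly.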
Usage Examples
Creating DataLoaders
from torch.utils.data import DataLoader

# Assumes JsonlImageDataset and collate (defined in train/llava-ft.py) are in scope.
train_set = JsonlImageDataset("train_data.jsonl")
val_set = JsonlImageDataset("val_data.jsonl")

# Shuffle only the training data; collate keeps images as a list of PIL images.
train_loader = DataLoader(
    train_set, batch_size=8, shuffle=True,
    num_workers=4, collate_fn=collate,
    pin_memory=True, drop_last=False
)
val_loader = DataLoader(
    val_set, batch_size=8, shuffle=False,
    num_workers=2, collate_fn=collate,
    pin_memory=True, drop_last=False
)

# Inspect a single batch to confirm the output format.
for imgs, labels in train_loader:
    print(f"Batch: {len(imgs)} images, labels shape: {labels.shape}")
    break