Implementation: WAInjectBench LlavaYesnoToken Init
| Knowledge Sources | |
|---|---|
| Domains | Computer_Vision, NLP, Deep_Learning |
| Last Updated | 2026-02-14 16:00 GMT |
Overview
Concrete tool for initializing a LLaVA model constrained to binary Yes/No classification, provided by the WAInjectBench train/llava_yesno_token module.
Description
The LlavaYesnoToken class in train/llava_yesno_token.py wraps LlavaForConditionalGeneration from HuggingFace Transformers. The __init__ method loads the model with specified dtype, enables gradient checkpointing, initializes the processor/tokenizer, and resolves Yes/No token IDs using the get_yes_no_ids helper. The forward method processes image batches, handles device placement for vision tower vs language model, and returns binary logits of shape [B, 2] where column 0 is "No" and column 1 is "Yes".
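The `get_yes_no_ids` helper itself is not shown here; a minimal sketch of how such a resolver might map the verbalizers "Yes"/"No" to single token IDs is below. This is a hypothetical implementation for illustration, and the actual WAInjectBench helper may differ.

```python
def get_yes_no_ids(tokenizer):
    """Resolve single-token IDs for the verbalizers "Yes" and "No".

    Hypothetical sketch; the real WAInjectBench helper may differ.
    Returns (id_yes, id_no, verb_yes, verb_no).
    """
    verb_yes, verb_no = "Yes", "No"
    ids_yes = tokenizer.encode(verb_yes, add_special_tokens=False)
    ids_no = tokenizer.encode(verb_no, add_special_tokens=False)
    # Assume each verbalizer tokenizes to a single ID; take the first piece.
    return ids_yes[0], ids_no[0], verb_yes, verb_no
```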
Usage
Instantiated in train/llava-ft.py:L248-252 with the base model ID, dtype, and CUDA availability flag.
Code Reference
Source Location
- Repository: WAInjectBench
- File: train/llava_yesno_token.py (L25-108)
Signature
class LlavaYesnoToken(nn.Module):
    def __init__(self, base_model_id: str, dtype: torch.dtype, use_cuda: bool = True):
        super().__init__()
        self.model = LlavaForConditionalGeneration.from_pretrained(
            base_model_id,
            torch_dtype=dtype,
            device_map=None,
            use_safetensors=True,
        )
        try:
            self.model.gradient_checkpointing_enable()
        except Exception:
            pass
        self.processor = AutoProcessor.from_pretrained(base_model_id, use_fast=True)
        self.tokenizer = self.processor.tokenizer
        self.ID_YES, self.ID_NO, self.VERB_YES, self.VERB_NO = get_yes_no_ids(self.tokenizer)

    def forward(self, images, sys_prompt: str) -> torch.Tensor:
        # Processes images + prompt, returns [B, 2] logits for [No, Yes]
        ...
        next_logits = out.logits[:, -1, :]  # [B, vocab]
        two = next_logits[:, [self.ID_NO, self.ID_YES]]  # 0:No, 1:Yes
        return two
Import
from llava_yesno_token import LlavaYesnoToken
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| base_model_id | str | Yes | HuggingFace model ID (e.g. "llava-hf/llava-1.5-7b-hf") |
| dtype | torch.dtype | Yes | Model precision (torch.bfloat16, torch.float16, or torch.float32) |
| use_cuda | bool | No | Whether CUDA is available (default True) |
Outputs
| Name | Type | Description |
|---|---|---|
| model | LlavaYesnoToken | nn.Module wrapping LLaVA with Yes/No constrained output |
| forward() returns | Tensor[B, 2] | Binary logits (unnormalized): column 0 = No logit, column 1 = Yes logit |
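Because forward() returns raw logits rather than probabilities, callers that need calibrated scores should apply a softmax over the last dimension. A minimal sketch with stand-in logits (the tensor values here are hypothetical, not actual model output):

```python
import torch

# Hypothetical [B, 2] logits as returned by forward(): column 0 = No, column 1 = Yes.
logits = torch.tensor([[2.0, -1.0], [0.5, 3.0]])
probs = torch.softmax(logits, dim=-1)  # per-example probabilities over [No, Yes]
preds = probs.argmax(dim=-1)           # 0 = No, 1 = Yes
```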
Usage Examples
Initializing for Fine-tuning
import torch
from llava_yesno_token import LlavaYesnoToken
model = LlavaYesnoToken(
    base_model_id="llava-hf/llava-1.5-7b-hf",
    dtype=torch.bfloat16,
    use_cuda=True,
)
# Forward pass with images
from PIL import Image
imgs = [Image.open("test.png").convert("RGB")]
sys_prompt = "Decide whether the image contains a prompt injection."
logits = model(imgs, sys_prompt=sys_prompt)
predictions = logits.argmax(-1) # 0=No, 1=Yes
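Since the model exposes plain [B, 2] logits, fine-tuning reduces to standard cross-entropy over the two classes. The sketch below uses hypothetical stand-in logits and labels to show the loss computation only; the actual training loop in train/llava-ft.py may differ.

```python
import torch
import torch.nn.functional as F

# Stand-in for a LlavaYesnoToken forward pass: any [B, 2] logits tensor
# works for illustrating the binary fine-tuning loss.
logits = torch.tensor([[1.2, -0.3], [-0.8, 2.1]], requires_grad=True)
labels = torch.tensor([0, 1])  # 0 = No, 1 = Yes

loss = F.cross_entropy(logits, labels)  # standard two-class cross-entropy
loss.backward()                          # gradients flow back toward the model
```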