Implementation:Mlfoundations Open flamingo EvalModel
Overview
Concrete implementation of the OpenFlamingo evaluation model wrapper, providing KV-cache-based classification and task-specific prompt formatting for the OpenFlamingo evaluation module.
Description
EvalModel extends BaseEvalModel to wrap a Flamingo model for evaluation. Key features:
- Loads model via create_model_and_transforms + checkpoint
- get_outputs() generates text from batched image-text inputs
- get_rank_classifications() uses KV-cache to efficiently score multiple class names
- Provides prompt formatting methods: get_vqa_prompt(), get_caption_prompt(), get_imagenet_prompt(), get_hateful_memes_prompt()
- Supports DDP wrapping for distributed evaluation
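The prompt formatting methods listed above wrap raw inputs in fixed text templates containing the `<image>` placeholder. A minimal standalone sketch of two of them, assuming the template strings below (the authoritative versions live in `open_flamingo/eval/models/open_flamingo.py`):

```python
# Hedged sketch of the prompt templates; the exact strings are an
# assumption and should be checked against the source file.
def get_vqa_prompt(question, answer=None):
    # The answer (plus an end-of-chunk marker) is appended only when
    # provided, e.g. for in-context demonstration examples; at inference
    # time the model completes the prompt after "Short answer:".
    return (f"<image>Question:{question} Short answer:"
            f"{answer if answer is not None else ''}"
            f"{'<|endofchunk|>' if answer is not None else ''}")

def get_caption_prompt(caption=None):
    return (f"<image>Output:"
            f"{caption if caption is not None else ''}"
            f"{'<|endofchunk|>' if caption is not None else ''}")

print(get_vqa_prompt("What color is the car?"))
# -> <image>Question:What color is the car? Short answer:
```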
Usage
Instantiate with a model configuration dictionary; the resulting wrapper is consumed by all evaluate_* functions in the evaluation suite.
Code Reference
- Source
- Repository: https://github.com/mlfoundations/open_flamingo
- File: open_flamingo/eval/models/open_flamingo.py, lines 13-334
- Base class: open_flamingo/eval/eval_model.py, lines 8-89
- Signature
class EvalModel(BaseEvalModel):
    def __init__(self, model_args: Dict[str, str]):
        """
        Required model_args keys:
            vision_encoder_path, vision_encoder_pretrained, lm_path,
            lm_tokenizer_path, checkpoint_path, cross_attn_every_n_layers,
            precision, device (optional)
        """
    def get_outputs(self, batch_text, batch_images, min_generation_length,
                    max_generation_length, num_beams, length_penalty) -> List[str]: ...
    def get_rank_classifications(self, batch_text, batch_images, all_class_names,
                                 use_cache, normalize_length) -> torch.Tensor: ...
    def get_vqa_prompt(self, question: str, answer: str = None) -> str: ...
    def get_caption_prompt(self, caption: str = None) -> str: ...
    def get_imagenet_prompt(self, label: str = None) -> str: ...
    def get_hateful_memes_prompt(self, text: str, label: str = None) -> str: ...
- Import
from open_flamingo.eval.models.open_flamingo import EvalModel
I/O Contract
Constructor
| Parameter | Type | Required | Description |
|---|---|---|---|
| model_args | Dict[str, str] | Yes | Model configuration dictionary containing vision_encoder_path, lm_path, checkpoint_path, etc. |
get_outputs
Inputs:
| Parameter | Type | Required | Description |
|---|---|---|---|
| batch_text | List[str] | Yes | Texts with <image> placeholders |
| batch_images | List[List[PIL.Image]] | Yes | Batch of image lists |
| min_generation_length | int | No | Minimum tokens to generate |
| max_generation_length | int | No | Maximum tokens to generate |
| num_beams | int | No | Beam search width |
| length_penalty | float | No | Length penalty for generation |
Outputs: List[str] -- decoded generated text
get_rank_classifications
Inputs:
| Parameter | Type | Required | Description |
|---|---|---|---|
| batch_text | List[str] | Yes | Texts with <image> placeholders |
| batch_images | List[List[PIL.Image]] | Yes | Batch of image lists |
| all_class_names | List[str] | Yes | Class names to score |
| use_cache | bool | No | Whether to use KV-cache optimization |
| normalize_length | bool | No | Whether to normalize log-probs by token length |
Outputs: torch.Tensor shape (B, num_classes) with log-probabilities
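The normalize_length option divides each class name's summed log-probability by its token count, so multi-token class names are not systematically penalized against short ones. A minimal sketch of that scoring rule, using illustrative per-token log-probs in place of a real model:

```python
def score_classes(per_token_logprobs, normalize_length=True):
    """per_token_logprobs maps class name -> list of per-token log-probs.

    Returns class name -> score; higher is better. This is an
    illustrative helper, not part of the OpenFlamingo API.
    """
    scores = {}
    for name, lps in per_token_logprobs.items():
        total = sum(lps)
        # Normalizing by token count compares classes by average
        # per-token log-prob instead of raw sum.
        scores[name] = total / len(lps) if normalize_length else total
    return scores

# "sports car" spans two tokens; without normalization its summed
# log-prob is dragged down simply by being longer.
lp = {"dog": [-1.2], "sports car": [-0.5, -0.9]}
print(score_classes(lp, normalize_length=False))  # dog scores higher
print(score_classes(lp, normalize_length=True))   # sports car scores higher
```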
Usage Examples
from open_flamingo.eval.models.open_flamingo import EvalModel

# Configure the model
model_args = {
    "vision_encoder_path": "ViT-L-14",
    "vision_encoder_pretrained": "openai",
    "lm_path": "anas-awadalla/mpt-1b-redpajama-200b",
    "lm_tokenizer_path": "anas-awadalla/mpt-1b-redpajama-200b",
    "checkpoint_path": "/path/to/checkpoint.pt",
    "cross_attn_every_n_layers": 1,
    "precision": "float16",
}

# Create the evaluation model
eval_model = EvalModel(model_args)

# Generate captions for a batch of images
batch_text = ["<image>Output:", "<image>Output:"]
batch_images = [[image1], [image2]]
captions = eval_model.get_outputs(
    batch_text=batch_text,
    batch_images=batch_images,
    min_generation_length=0,
    max_generation_length=20,
    num_beams=3,
    length_penalty=-2.0,
)

# Format a VQA prompt
vqa_prompt = eval_model.get_vqa_prompt(
    question="What color is the car?",
    answer="red",
)
Related Pages
Principle:Mlfoundations_Open_flamingo_Evaluation_Model_Abstraction