
Implementation:Mlfoundations Open flamingo EvalModel

From Leeroopedia



Overview

Concrete tool in the OpenFlamingo evaluation module: a wrapper around a Flamingo model for evaluation, providing KV-cache-based rank classification and task-specific prompt formatting.

Description

EvalModel extends BaseEvalModel to wrap a Flamingo model for evaluation. Key features:

  1. Loads model via create_model_and_transforms + checkpoint
  2. get_outputs() generates text from batched image-text inputs
  3. get_rank_classifications() uses KV-cache to efficiently score multiple class names
  4. Provides prompt formatting methods: get_vqa_prompt(), get_caption_prompt(), get_imagenet_prompt(), get_hateful_memes_prompt()
  5. Supports DDP wrapping for distributed evaluation
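The scoring idea behind feature 3 can be sketched in miniature: the shared image-text context is encoded once (its key/value cache reused across candidates), then each class name is scored by the sum of its token log-probabilities, optionally averaged over token count so multi-token names are not penalized. The numbers and tokenization below are hypothetical, purely for illustration; the real method runs the Flamingo language model.

```python
# Toy per-token log-probabilities, standing in for the language model's
# predictions after the cached image-text context (hypothetical values).
TOKEN_LOGPROBS = {
    "gold": -0.5, "##en": -0.3, "re": -0.4, "##triever": -0.2, "cat": -1.2,
}

def score_class(tokens, normalize_length=True):
    """Sum the per-token log-probs of one class name; optionally average
    over token count so longer names compete fairly with short ones."""
    total = sum(TOKEN_LOGPROBS[t] for t in tokens)
    return total / len(tokens) if normalize_length else total

# The shared context (batch text + images) would be encoded once and its
# KV-cache reused while scoring every candidate below.
candidates = {
    "golden retriever": ["gold", "##en", "re", "##triever"],
    "cat": ["cat"],
}
scores = {name: score_class(toks) for name, toks in candidates.items()}
best = max(scores, key=scores.get)
print(best)  # with normalization the multi-token class wins
```

Note how the ranking flips without length normalization: the summed score of the four-token name (-1.4) falls below the one-token name (-1.2), which is why `normalize_length` exists.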

Usage

Instantiate with a model-configuration dictionary (see model_args below); all evaluate_* functions in the evaluation harness consume an EvalModel instance.

Code Reference

Source
Repository https://github.com/mlfoundations/open_flamingo, File: open_flamingo/eval/models/open_flamingo.py, lines 13-334
Base class
open_flamingo/eval/eval_model.py, lines 8-89
Signature
class EvalModel(BaseEvalModel):
    def __init__(self, model_args: Dict[str, str]):
        """
        Required model_args keys:
            vision_encoder_path, vision_encoder_pretrained, lm_path,
            lm_tokenizer_path, checkpoint_path, cross_attn_every_n_layers,
            precision, device (optional)
        """

    def get_outputs(self, batch_text, batch_images, min_generation_length,
                    max_generation_length, num_beams, length_penalty) -> List[str]: ...

    def get_rank_classifications(self, batch_text, batch_images, all_class_names,
                                  use_cache, normalize_length) -> torch.Tensor: ...

    def get_vqa_prompt(self, question: str, answer: str = None) -> str: ...
    def get_caption_prompt(self, caption: str = None) -> str: ...
    def get_imagenet_prompt(self, label: str = None) -> str: ...
    def get_hateful_memes_prompt(self, text: str, label: str = None) -> str: ...
Import
from open_flamingo.eval.models.open_flamingo import EvalModel
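The prompt-formatting methods all follow the same pattern: an <image> placeholder, a task-specific template, and a trailing answer plus end-of-chunk marker that is appended only for in-context examples. A minimal sketch of get_vqa_prompt's behavior, where the exact template string is an assumption based on the repository and may vary by version:

```python
def get_vqa_prompt(question, answer=None):
    """Few-shot VQA prompt: the answer and <|endofchunk|> marker are
    appended only for in-context examples, not for the query itself.
    (Template string assumed from the repository; may vary by version.)"""
    suffix = f"{answer}<|endofchunk|>" if answer is not None else ""
    return f"<image>Question:{question} Short answer:{suffix}"

# In-context example form vs. query form:
print(get_vqa_prompt("What color is the car?", "red"))
print(get_vqa_prompt("What animal is this?"))
```

Ending the query prompt at "Short answer:" leaves the model to complete exactly the answer span, which is what get_outputs() then decodes.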

I/O Contract

Constructor

Parameter Type Required Description
model_args Dict[str, str] Yes Model configuration dictionary containing vision_encoder_path, lm_path, checkpoint_path, etc.

get_outputs

Inputs:

Parameter Type Required Description
batch_text List[str] Yes Texts with <image> placeholders
batch_images List[List[PIL.Image]] Yes Batch of image lists
min_generation_length int No Minimum tokens to generate
max_generation_length int No Maximum tokens to generate
num_beams int No Beam search width
length_penalty float No Length penalty for generation

Outputs: List[str] -- decoded generated text

get_rank_classifications

Inputs:

Parameter Type Required Description
batch_text List[str] Yes Texts with <image> placeholders
batch_images List[List[PIL.Image]] Yes Batch of image lists
all_class_names List[str] Yes Class names to score
use_cache bool No Whether to use KV-cache optimization
normalize_length bool No Whether to normalize log-probs by token length

Outputs: torch.Tensor shape (B, num_classes) with log-probabilities
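To turn that (B, num_classes) tensor into predictions, callers take the argmax over the class axis. A pure-Python stand-in (illustrative values in place of the real torch.Tensor):

```python
# Stand-in for the (B, num_classes) log-probability matrix that
# get_rank_classifications returns (illustrative values).
class_names = ["cat", "dog", "bird"]
logprobs = [
    [-2.1, -0.4, -3.0],  # sample 0: "dog" scores highest
    [-0.2, -1.8, -2.5],  # sample 1: "cat" scores highest
]

def predict(logprobs, class_names):
    """Pick the class with the highest log-probability for each sample."""
    return [class_names[max(range(len(row)), key=row.__getitem__)]
            for row in logprobs]

print(predict(logprobs, class_names))  # ['dog', 'cat']
```

With the real output this is a single torch.argmax(logprobs, dim=-1) indexed into all_class_names.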

Usage Examples

from open_flamingo.eval.models.open_flamingo import EvalModel

# Configure the model
model_args = {
    "vision_encoder_path": "ViT-L-14",
    "vision_encoder_pretrained": "openai",
    "lm_path": "anas-awadalla/mpt-1b-redpajama-200b",
    "lm_tokenizer_path": "anas-awadalla/mpt-1b-redpajama-200b",
    "checkpoint_path": "/path/to/checkpoint.pt",
    "cross_attn_every_n_layers": 1,
    "precision": "float16",
}

# Create the evaluation model
eval_model = EvalModel(model_args)

# Generate captions for a batch of images
batch_text = ["<image>Output:", "<image>Output:"]
batch_images = [[image1], [image2]]
captions = eval_model.get_outputs(
    batch_text=batch_text,
    batch_images=batch_images,
    min_generation_length=0,
    max_generation_length=20,
    num_beams=3,
    length_penalty=-2.0,
)

# Format a VQA prompt
vqa_prompt = eval_model.get_vqa_prompt(
    question="What color is the car?",
    answer="red",
)

Related Pages

Principle:Mlfoundations_Open_flamingo_Evaluation_Model_Abstraction
