
Implementation:Mlfoundations Open flamingo RICES


Overview

RICES (Retrieval-based In-Context Example Selection) is a concrete tool, provided by the OpenFlamingo evaluation module, for retrieving visually similar in-context examples using CLIP image features.

Description

The RICES class pre-computes CLIP image features for all training examples (or loads cached features). The find() method computes CLIP features for a batch of query images, performs a cosine-similarity search against the training set, and returns the top-k most similar examples per query, sorted by similarity (most similar last). For efficiency, features can be pre-computed and cached to disk with the cache_rices_features.py script.
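
The core retrieval step is straightforward to reproduce. The following is a minimal sketch of the same idea written directly against open_clip; it is illustrative rather than the library implementation, and the function names (clip_features, retrieve) are invented for this example.

import torch
import open_clip

# Minimal sketch of the RICES retrieval idea (illustrative, not the
# library code): encode images with CLIP, L2-normalize, then rank the
# training pool by cosine similarity to each query.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
model.eval()

@torch.no_grad()
def clip_features(images):
    # images: list of PIL images -> (N, D) L2-normalized CLIP features
    pixels = torch.stack([preprocess(img) for img in images])
    feats = model.encode_image(pixels)
    return feats / feats.norm(dim=-1, keepdim=True)

def retrieve(query_images, train_images, train_items, k=2):
    # For each query, return the k most similar training items,
    # most similar last (matching RICES.find() ordering).
    sims = clip_features(query_images) @ clip_features(train_images).T
    topk = sims.topk(k, dim=-1).indices.flip(-1)
    return [[train_items[j] for j in row.tolist()] for row in topk]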

Usage

Initialize RICES with a training dataset before evaluation; call find() for each test batch to retrieve in-context demonstrations.

Code Reference

Source: repository https://github.com/mlfoundations/open_flamingo, file open_flamingo/eval/rices.py, lines 1-95

Signature:

class RICES:
    def __init__(
        self,
        dataset,
        device,
        batch_size,
        vision_encoder_path: str = "ViT-B-32",
        vision_encoder_pretrained: str = "openai",
        cached_features=None,
    ):
        """
        Args:
            dataset: training dataset to select examples from
            device: device for vision encoder
            batch_size: batch size for feature computation
            vision_encoder_path: CLIP model path (default "ViT-B-32")
            vision_encoder_pretrained: CLIP pretrained dataset (default "openai")
            cached_features: pre-computed feature tensor (optional, skips encoding)
        """

    def find(self, batch: List[PIL.Image], num_examples: int) -> List[List[dict]]:
        """
        Retrieve top num_examples most similar examples for each query image.

        Args:
            batch: list of query PIL images
            num_examples: number of examples to retrieve per query
        Returns:
            List of lists of dataset items, sorted by similarity (most similar last)
        """

Import:

from open_flamingo.eval.rices import RICES

I/O Contract

Constructor Inputs

Parameter | Type | Required | Description
dataset | Dataset | Yes | Training dataset for example pool
device | torch.device | Yes | Device for CLIP encoder
batch_size | int | Yes | Batch size for feature computation
vision_encoder_path | str | No | CLIP model path (default "ViT-B-32")
vision_encoder_pretrained | str | No | CLIP pretrained dataset (default "openai")
cached_features | Tensor | No | Pre-computed features (skips encoding if provided)
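
Because the constructor encodes the entire training set, caching pays off across repeated evaluation runs. Below is a sketch of the caching round trip; it assumes the constructor exposes its computed features as a features attribute (this is what the cache_rices_features.py script saves), and train_dataset is a placeholder for your own dataset.

import torch
from open_flamingo.eval.rices import RICES

# One-time pass: encode the training set and save the features to disk.
# Assumes the computed features live on a `features` attribute, as used
# by cache_rices_features.py.
rices = RICES(
    dataset=train_dataset,  # placeholder: your training dataset
    device=torch.device("cuda"),
    batch_size=64,
)
torch.save(rices.features, "cached_rices_features.pt")

# Subsequent runs load the tensor and skip re-encoding the dataset.
rices = RICES(
    dataset=train_dataset,
    device=torch.device("cuda"),
    batch_size=64,
    cached_features=torch.load("cached_rices_features.pt"),
)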

find() Inputs

Parameter | Type | Required | Description
batch | List[PIL.Image] | Yes | Query images
num_examples | int | Yes | Number of examples to retrieve per query

Outputs

  • Constructor: RICES object with pre-computed CLIP features for the training dataset.
  • find(): List[List[dict]] where each inner list contains num_examples dataset items sorted by cosine similarity (most similar last).

Usage Examples

import torch
from open_flamingo.eval.rices import RICES

# Initialize RICES with cached features for efficiency
cached_features = torch.load("cached_rices_features.pt")
rices = RICES(
    dataset=train_dataset,
    device=torch.device("cuda"),
    batch_size=64,
    vision_encoder_path="ViT-B-32",
    vision_encoder_pretrained="openai",
    cached_features=cached_features,
)

# Retrieve top-2 similar examples for a batch of query images
query_images = [test_dataset[i]["image"] for i in range(8)]
demonstrations = rices.find(batch=query_images, num_examples=2)

# demonstrations[0] contains the 2 most similar training examples
# for query_images[0], sorted with most similar last
for i, demo_list in enumerate(demonstrations):
    print(f"Query {i}: retrieved {len(demo_list)} demonstrations")
