Implementation: mlfoundations/open_flamingo RICES
Overview
Concrete tool for retrieving visually similar in-context examples using CLIP features provided by the OpenFlamingo evaluation module.
Description
The RICES class pre-computes CLIP image features for all training examples (or loads cached features). The find() method computes CLIP features for query images, performs cosine similarity search against the training set, and returns the top-k most similar examples sorted by similarity (most similar last). Features can be pre-computed and cached to disk using the cache_rices_features.py script for efficiency.
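The core of find() is a cosine-similarity search: normalize both feature sets, take a dot product, and keep the top-k indices with the most similar example last. A minimal standalone sketch of that logic (the function name and tensor layout are illustrative, not the actual rices.py code):

```python
import torch

def topk_similar(query_feats: torch.Tensor,
                 train_feats: torch.Tensor,
                 k: int) -> torch.Tensor:
    """Return indices of the k most similar training examples per query,
    sorted so the most similar index comes LAST (matching RICES ordering)."""
    # Normalize rows so the dot product equals cosine similarity.
    q = query_feats / query_feats.norm(dim=-1, keepdim=True)
    t = train_feats / train_feats.norm(dim=-1, keepdim=True)
    sims = q @ t.T                        # (num_queries, num_train)
    top = sims.topk(k, dim=-1).indices    # most similar first
    return top.flip(-1)                   # reverse: most similar last

# Tiny synthetic check with three training vectors and one query.
train = torch.tensor([[1.0, 0.0], [0.7, 0.7], [0.0, 1.0]])
query = torch.tensor([[1.0, 0.0]])
print(topk_similar(query, train, 2)[0].tolist())  # → [1, 0]
```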
Usage
Initialize RICES with a training dataset before evaluation; call find() for each test batch.
Code Reference
Source: Repository https://github.com/mlfoundations/open_flamingo, File: open_flamingo/eval/rices.py, Lines 1-95
Signature:
class RICES:
    def __init__(
        self,
        dataset,
        device,
        batch_size,
        vision_encoder_path: str = "ViT-B-32",
        vision_encoder_pretrained: str = "openai",
        cached_features=None,
    ):
        """
        Args:
            dataset: training dataset to select examples from
            device: device for vision encoder
            batch_size: batch size for feature computation
            vision_encoder_path: CLIP model path (default "ViT-B-32")
            vision_encoder_pretrained: CLIP pretrained dataset (default "openai")
            cached_features: pre-computed feature tensor (optional, skips encoding)
        """

    def find(self, batch: List[PIL.Image], num_examples: int) -> List[List[dict]]:
        """
        Retrieve the top num_examples most similar examples for each query image.

        Args:
            batch: list of query PIL images
            num_examples: number of examples to retrieve per query
        Returns:
            List of lists of dataset items, sorted by similarity (most similar last)
        """
Import:
from open_flamingo.eval.rices import RICES
I/O Contract
Constructor Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| dataset | Dataset | Yes | Training dataset for example pool |
| device | torch.device | Yes | Device for CLIP encoder |
| batch_size | int | Yes | Batch size for feature computation |
| vision_encoder_path | str | No | CLIP model path (default "ViT-B-32") |
| vision_encoder_pretrained | str | No | CLIP pretrained dataset (default "openai") |
| cached_features | Tensor | No | Pre-computed features (skips encoding if provided) |
find() Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| batch | List[PIL.Image] | Yes | Query images |
| num_examples | int | Yes | Number of examples to retrieve per query |
Outputs
- Constructor: RICES object with pre-computed CLIP features for the training dataset.
- find(): List[List[dict]] where each inner list contains num_examples dataset items sorted by cosine similarity (most similar last).
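The cached_features tensor passed to the constructor can be produced ahead of time, as the cache_rices_features.py script does. A hedged sketch of such a caching pass, with the encoder passed in as a callable (cache_features and its signature are illustrative, not the script's actual interface):

```python
import torch

def cache_features(images, encode_batch, batch_size, path):
    """Encode `images` in batches with `encode_batch` (e.g. a CLIP visual
    encoder) and save the stacked feature tensor to `path`.
    Illustrative sketch only; the real script handles CLIP model loading,
    preprocessing, and device placement."""
    feats = []
    for i in range(0, len(images), batch_size):
        with torch.no_grad():  # no gradients needed for feature extraction
            feats.append(encode_batch(images[i:i + batch_size]))
    features = torch.cat(feats)
    torch.save(features, path)  # later: RICES(..., cached_features=torch.load(path))
    return features
```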
Usage Examples
import torch
from open_flamingo.eval.rices import RICES

# Initialize RICES with cached features for efficiency
cached_features = torch.load("cached_rices_features.pt")
rices = RICES(
    dataset=train_dataset,
    device=torch.device("cuda"),
    batch_size=64,
    vision_encoder_path="ViT-B-32",
    vision_encoder_pretrained="openai",
    cached_features=cached_features,
)

# Retrieve top-2 similar examples for a batch of query images
query_images = [test_dataset[i]["image"] for i in range(8)]
demonstrations = rices.find(batch=query_images, num_examples=2)

# demonstrations[0] contains the 2 most similar training examples
# for query_images[0], sorted with most similar last
for i, demo_list in enumerate(demonstrations):
    print(f"Query {i}: retrieved {len(demo_list)} demonstrations")
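Downstream, the retrieved items are typically interleaved into a few-shot prompt. A hedged sketch, assuming dataset items expose a "caption" key and that "<image>" marks image slots (both are assumptions about the downstream dataset and model, not part of RICES itself):

```python
def build_context(demo_list, image_token="<image>"):
    """Build a few-shot text context from one query's retrieved examples.
    The "caption" key and image_token are assumptions; real keys depend
    on the evaluation dataset and model prompt format."""
    parts = []
    for item in demo_list:  # most similar last, i.e. closest to the query slot
        parts.append(f"{image_token}{item['caption']}")
    parts.append(image_token)  # final slot for the query image itself
    return "".join(parts)

# Example with two hypothetical retrieved items.
print(build_context([{"caption": "a cat"}, {"caption": "a dog"}]))
# → <image>a cat<image>a dog<image>
```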