Implementation:FlagOpen FlagEmbedding MMRet CLIP Model
| Knowledge Sources | |
|---|---|
| Domains | Computer Vision, Multi-Modal Learning, Vision-Language Models, Neural Networks |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
An extended CLIP model implementation for multi-modal retrieval that supports image-only, text-only, and combined image-text encoding for vision-language tasks.
Description
This implementation adapts the CLIP architecture from Hugging Face Transformers and extends it for multi-modal retrieval. The model keeps the standard CLIP design (a vision transformer and a text transformer trained with a contrastive objective) and adds encoding methods for flexible input handling: image-only encoding for visual search, text-only encoding for semantic search, and multi-modal encoding that combines image and text representations for composed queries.
The implementation includes the complete CLIP architecture: vision and text encoders, projection layers into a shared embedding space, support for multiple attention implementations (eager, SDPA, Flash Attention 2), gradient checkpointing for memory efficiency, and contrastive learning with temperature scaling. Additional features include a data-processing helper that detects the input format automatically, L2 normalization of embeddings (so dot products equal cosine similarities), integration with CLIPProcessor for preprocessing, and support for both training and inference.
The model supports three key use cases: single-modality retrieval (image→image or text→text), cross-modal retrieval (image→text or text→image), and composed retrieval (image+text→image), making it suitable for diverse vision-language applications.
Usage
Use this model for multi-modal retrieval tasks where you need to encode images, text, or combinations of both into a shared embedding space for similarity search or cross-modal matching.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/BGE_VL/modeling_MMRet_CLIP.py
- Lines: 1-1678
Signature
```python
class CLIPModel(CLIPPreTrainedModel):
    def __init__(self, config: CLIPConfig): ...

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor: ...

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor: ...

    def encode_image(self, images) -> torch.Tensor: ...
    def encode_text(self, text) -> torch.Tensor: ...
    def encode_multimodal(self, images, text) -> torch.Tensor: ...
    def encode(self, images=None, text=None) -> torch.Tensor: ...
    def data_process(self, images=None, text=None) -> Tuple: ...

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]: ...
```
Import
```python
from modeling_MMRet_CLIP import CLIPModel
from transformers import CLIPProcessor
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | No | Text token IDs (batch_size, seq_len) |
| pixel_values | torch.FloatTensor | No | Image tensors (batch_size, 3, height, width) |
| attention_mask | torch.Tensor | No | Text attention mask (batch_size, seq_len); 1 for real tokens, 0 for padding |
| position_ids | torch.LongTensor | No | Position IDs for text tokens |
| images | str/list/PIL.Image | No | Image path(s) or PIL Image(s) for encoding |
| text | str/list | No | Text string(s) for encoding |
| return_loss | bool | No | Compute contrastive loss (default: False) |
| output_attentions | bool | No | Return attention weights |
| output_hidden_states | bool | No | Return hidden states |
| return_dict | bool | No | Return ModelOutput object |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Optional[torch.FloatTensor] | Contrastive loss (if return_loss=True) |
| logits_per_image | torch.FloatTensor | Image-text similarity scores (batch_size_image, batch_size_text) |
| logits_per_text | torch.FloatTensor | Text-image similarity scores (batch_size_text, batch_size_image) |
| text_embeds | torch.FloatTensor | Text embeddings (batch_size, projection_dim) |
| image_embeds | torch.FloatTensor | Image embeddings (batch_size, projection_dim) |
| embeddings | torch.Tensor | L2-normalized embeddings (batch_size, projection_dim) returned by encode() and the encode_* methods |
Architecture Components
Vision Encoder
Architecture:
- Vision Transformer (ViT)
- Patch embedding: Image → patches → embeddings
- Class token + positional embeddings
- Transformer encoder layers
- Post-layer normalization
Configuration (typical):
- Patch size: 32x32 (ViT-B/32), 16x16 (ViT-B/16), or 14x14 (ViT-L/14)
- Hidden size: 768 (base) or 1024 (large)
- Number of layers: 12 (base) or 24 (large)
- Attention heads: 12 (base) or 16 (large)
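The vision sequence length follows directly from the patch size: a square image of side S cut into p×p patches yields (S/p)² patch tokens, plus one class token. A minimal sketch in plain Python (the helper `num_vision_tokens` is hypothetical, not part of the source file):

```python
def num_vision_tokens(image_size: int, patch_size: int) -> int:
    """Patch tokens plus the single class token for a square input image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + 1  # +1 for the class token


# Common CLIP variants at the default 224x224 resolution
for name, patch in [("ViT-B/32", 32), ("ViT-B/16", 16), ("ViT-L/14", 14)]:
    print(name, num_vision_tokens(224, patch))
```

At 224x224 this gives 50 tokens for ViT-B/32, 197 for ViT-B/16, and 257 for ViT-L/14.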
Text Encoder
Architecture:
- Transformer with a causal attention mask (GPT-style: each token attends only to itself and earlier tokens)
- Token + position embeddings
- Transformer encoder layers (run with the causal mask)
- Final layer normalization
- EOS token pooling
Configuration (typical):
- Vocab size: 49408
- Hidden size: 512 (base) or 768 (large)
- Number of layers: 12
- Attention heads: 8 (base) or 12 (large)
- Max position embeddings: 77
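The causal mask and EOS pooling listed above can be sketched in a few lines of PyTorch. This toy example uses random hidden states and made-up token IDs; it assumes the original OpenAI vocabulary, in which the end-of-text token has the largest ID (49407), so `argmax` over `input_ids` finds its position. That mirrors how the Hugging Face CLIP text transformer pools for the original configs, but the helper names and the padding ID here are hypothetical.

```python
import torch


def causal_mask(seq_len: int) -> torch.Tensor:
    """Additive mask: 0 on and below the diagonal, -inf above it (no attending to later tokens)."""
    mask = torch.full((seq_len, seq_len), float("-inf"))
    return torch.triu(mask, diagonal=1)


def eos_pool(last_hidden_state: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Pick the hidden state at each sequence's EOS position.

    Assumes EOS has the largest token ID (49407 in the original CLIP vocabulary),
    so argmax over the IDs returns its position.
    """
    eos_positions = input_ids.argmax(dim=-1)
    batch_indices = torch.arange(last_hidden_state.shape[0])
    return last_hidden_state[batch_indices, eos_positions]


BOS, EOS = 49406, 49407
PAD = 0  # hypothetical padding ID for this toy example
input_ids = torch.tensor([
    [BOS, 320, 2368, EOS, PAD],  # toy IDs: EOS at position 3
    [BOS, 320, 1929, 539, EOS],  # toy IDs: EOS at position 4
])
hidden = torch.randn(2, 5, 512)  # (batch, seq_len, text hidden size of the base model)
pooled = eos_pool(hidden, input_ids)  # (2, 512)
print(causal_mask(3))
```

Because of the causal mask, tokens after EOS (such as padding) cannot influence the pooled EOS representation.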
Projection Layers
- visual_projection: Projects vision_embed_dim → projection_dim
- text_projection: Projects text_embed_dim → projection_dim
- projection_dim: Typically 512 or 768
- No bias: Projections are linear without bias terms
Logit Scale
- Learnable temperature parameter
- Initialized to log(1/0.07) ≈ 2.66
- exp(logit_scale) multiplies the cosine similarities before the softmax (≈14.3 at initialization)
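Cosine similarities lie in [-1, 1], so a softmax over raw cosines is nearly uniform; the temperature is what makes the distribution sharp. The sketch below is plain Python with illustrative similarity values (not model outputs). It computes the initial scale and compares a softmax with and without scaling; the factor 100 assumes a trained OpenAI checkpoint, whose exp(logit_scale) sits at that cap.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


init_logit_scale = math.log(1 / 0.07)  # ≈ 2.6593: the stored, log-space parameter
init_scale_factor = math.exp(init_logit_scale)  # ≈ 14.29: the factor applied to cosines

cosines = [0.30, 0.22, 0.18]  # illustrative cosine similarities, not model outputs
p_raw = softmax(cosines)  # nearly uniform
p_scaled = softmax([c * 100.0 for c in cosines])  # 100 = exp(logit_scale) of OpenAI checkpoints
print([round(p, 3) for p in p_raw])
print([round(p, 3) for p in p_scaled])
```

With the raw cosines the top entry gets only about 36% of the probability mass; after scaling by 100 it gets over 99%.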
Encoding Methods
Image Encoding
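```python
# Method excerpt from CLIPModel; F is torch.nn.functional
def encode_image(self, images):
    embeddings = self.get_image_features(images)  # images: preprocessed pixel_values
    embeddings = F.normalize(embeddings, dim=-1)
    return embeddings
```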
Process:
1. Vision transformer encoding
2. CLS token extraction
3. Visual projection
4. L2 normalization
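The four steps can be traced with toy tensors. This sketch uses a randomly initialized `nn.LayerNorm` and a bias-free `nn.Linear` with ViT-B/32 sizes (768 → 512) as stand-ins for the model's post-layer norm and `visual_projection`; it follows the shape flow only and says nothing about trained behaviour.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
batch, num_tokens, vision_hidden, projection_dim = 2, 50, 768, 512

# Stand-in for the vision transformer's output: class token + 49 patch tokens (step 1)
last_hidden_state = torch.randn(batch, num_tokens, vision_hidden)

post_layernorm = nn.LayerNorm(vision_hidden)  # stand-in for the final layer norm
visual_projection = nn.Linear(vision_hidden, projection_dim, bias=False)

cls_token = last_hidden_state[:, 0, :]  # step 2: take the class (CLS) token
pooled = post_layernorm(cls_token)  # post-layer normalization of the pooled token
image_features = visual_projection(pooled)  # step 3: project into the shared space
image_embeds = F.normalize(image_features, dim=-1)  # step 4: unit L2 norm
```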
Text Encoding
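```python
# Method excerpt from CLIPModel; F is torch.nn.functional
def encode_text(self, text):
    embeddings = self.get_text_features(**text)  # text: tokenizer output (input_ids, attention_mask)
    embeddings = F.normalize(embeddings, dim=-1)
    return embeddings
```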
Process:
1. Text transformer encoding
2. EOS token pooling
3. Text projection
4. L2 normalization
Multi-Modal Encoding
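```python
# Method excerpt from CLIPModel; F is torch.nn.functional
def encode_multimodal(self, images, text):
    text_embeddings = self.get_text_features(**text)  # projected text features
    image_embeddings = self.get_image_features(images)  # projected image features
    embeddings = text_embeddings + image_embeddings  # fuse by element-wise sum
    embeddings = F.normalize(embeddings, dim=-1)  # then L2-normalize
    return embeddings
```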
Process:
1. Encode text and image separately (get_text_features and get_image_features already apply the projection layers)
2. Sum the projected embeddings
3. L2-normalize the sum

This enables composed queries (e.g., a reference image plus a text modification).
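Because the two feature vectors are added before normalization, their relative magnitudes matter: the larger vector dominates the direction of the fused embedding. The toy sketch below uses random vectors as stand-ins for the projected features (not real model outputs); it checks that the fused query is unit-norm and illustrates this weighting effect. It makes no claim about how the trained model balances the two modalities in practice.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
projection_dim = 512

# Stand-ins for get_text_features / get_image_features outputs (already projected)
text_features = torch.randn(1, projection_dim)
image_features = torch.randn(1, projection_dim)

# Fusion as in encode_multimodal: sum first, then L2-normalize
fused = F.normalize(text_features + image_features, dim=-1)

# If one modality's features are much larger, the fused direction follows it
dominant_image = F.normalize(text_features + 10 * image_features, dim=-1)
cos_to_image = F.cosine_similarity(dominant_image, image_features).item()
cos_to_text = F.cosine_similarity(dominant_image, text_features).item()
print(f"cos to image: {cos_to_image:.3f}, cos to text: {cos_to_text:.3f}")
```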
Unified Encoding Interface
```python
def encode(self, images=None, text=None):
    images, text, data_type = self.data_process(images, text)
    if data_type == "images":
        return self.encode_image(images)
    elif data_type == "text":
        return self.encode_text(text)
    elif data_type == "multimodal":
        return self.encode_multimodal(images, text)
```
Automatically determines encoding mode based on inputs.
Data Processing
Input Format Handling
The data_process method handles various input formats:
Image Inputs:
- Single path string: "/path/to/image.jpg"
- List of paths: ["/path/1.jpg", "/path/2.jpg"]
- PIL.Image object(s)
- Automatically converts to pixel_values tensor
Text Inputs:
- Single string: "a photo of a cat"
- List of strings: ["cat", "dog", "bird"]
- Automatically tokenizes and pads
Multi-Modal Inputs:
- Images and text must be the same type (both str or both list)
- Lists must have the same length
- Each image-text pair is encoded jointly as one composed query
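The input rules above can be summarized as a small dispatch function. This is a hypothetical, dependency-free sketch (`infer_data_type` is not a function in the source file) that reproduces only the checks described here; the real data_process also loads images and runs the CLIP processor. Raising an error when neither input is given is an assumption of this sketch.

```python
from typing import Any


def infer_data_type(images: Any = None, text: Any = None) -> str:
    """Hypothetical helper: classify an encode() call as "images", "text", or "multimodal"."""
    if images is None and text is None:
        # Assumption of this sketch: at least one input is required
        raise ValueError("at least one of images or text must be provided")
    if text is None:
        return "images"  # str path, list of paths, or PIL image(s)
    if images is None:
        return "text"  # single string or list of strings
    # Composed query: both inputs must pair up one-to-one
    if type(images) is not type(text):
        raise TypeError("images and text must be the same type (both str or both list)")
    if isinstance(images, list) and len(images) != len(text):
        raise ValueError("images and text lists must have the same length")
    return "multimodal"


print(infer_data_type(images=["a.jpg", "b.jpg"], text=["make it red", "add stripes"]))
```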
Preprocessing Pipeline
```python
# Set processor
model.set_processor("openai/clip-vit-base-patch32")
# Automatic preprocessing
images, text, data_type = model.data_process(
    images="/path/to/image.jpg",
    text="a photo of a cat",
)
# Returns:
# - images: preprocessed pixel_values
# - text: tokenized input_ids + attention_mask
# - data_type: "multimodal"
```
Contrastive Learning
Loss Function
```python
import torch
import torch.nn.functional as F


def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    # Row i's positive is column i
    targets = torch.arange(len(similarity), device=similarity.device)
    caption_loss = F.cross_entropy(similarity, targets)
    image_loss = F.cross_entropy(similarity.t(), targets)
    return (caption_loss + image_loss) / 2.0
```
Symmetric cross-entropy loss:
- Image-to-text: Predict correct text for each image
- Text-to-image: Predict correct image for each text
- Diagonal elements are positive pairs
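A quick sanity check of this loss on toy similarity matrices: when every row and column strongly prefers its diagonal entry the loss approaches zero, and when all logits are equal each cross-entropy term reduces to log(N). The function repeats the `clip_loss` shown above so the snippet runs on its own; the matrices are synthetic.

```python
import math

import torch
import torch.nn.functional as F


def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    """Symmetric cross-entropy; row i's positive is column i."""
    targets = torch.arange(len(similarity), device=similarity.device)
    caption_loss = F.cross_entropy(similarity, targets)
    image_loss = F.cross_entropy(similarity.t(), targets)
    return (caption_loss + image_loss) / 2.0


n = 4
uniform_logits = torch.zeros(n, n)  # no preference among pairs
confident_logits = 20.0 * torch.eye(n)  # strong diagonal: correct pairs win

uniform_loss = clip_loss(uniform_logits).item()  # equals log(4) ≈ 1.386
confident_loss = clip_loss(confident_logits).item()  # close to 0
```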
Similarity Computation
```python
# Inside CLIPModel.forward: L2-normalize the projected embeddings
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# Scaled cosine similarity
logit_scale = self.logit_scale.exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.t()
```
Attention Mechanisms
Eager Attention (Default)
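Reference implementation that computes the Q, K, V projections and then the attention weights explicitly with matrix multiplications and a softmax, materializing the full attention matrix.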
SDPA (Scaled Dot-Product Attention)
```python
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states,
    attn_mask=attn_mask,
    dropout_p=self.dropout if self.training else 0.0,  # no dropout at inference
    scale=self.scale,
)
```
Dispatches to PyTorch's fused scaled-dot-product attention kernels; usually faster and more memory-efficient than eager mode.
Flash Attention 2
```python
attn_output = _flash_attention_forward(
    query_states, key_states, value_states,
    attention_mask,
    q_len,
    dropout=dropout_rate,  # self.dropout during training, 0.0 in eval mode
    is_causal=causal_attention_mask is not None,  # causal only for the text encoder
)
```
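Memory-efficient exact attention that never materializes the full attention matrix. It requires the flash-attn package, a supported NVIDIA GPU, and fp16 or bf16 inputs.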
Usage Examples
Basic Image-Text Matching
```python
import torch
from PIL import Image
from modeling_MMRet_CLIP import CLIPModel

# Load model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.set_processor("openai/clip-vit-base-patch32")
model.eval()
with torch.no_grad():
    # Encode image
    image = Image.open("cat.jpg").convert("RGB")
    image_emb = model.encode(images=image)  # (1, 512)
    # Encode text
    texts = ["a cat", "a dog", "a bird"]
    text_emb = model.encode(text=texts)  # (3, 512)
    # Cosine similarities, scaled by the learned temperature before the softmax
    similarity = (image_emb @ text_emb.T).squeeze(0)  # (3,)
    probs = (similarity * model.logit_scale.exp()).softmax(dim=0)
print(f"Probabilities: {probs}")
```
Multi-Modal Query
```python
import torch

# Reference image + text modification
ref_image = "blue_shirt.jpg"
modification = "make it red"
with torch.no_grad():
    # Encode composed query
    query_emb = model.encode(images=ref_image, text=modification)  # (1, 512)
    # Encode candidate images
    candidates = ["red_shirt.jpg", "blue_pants.jpg", "green_shirt.jpg"]
    candidate_embs = model.encode(images=candidates)  # (3, 512)
# Find best match
similarity = (query_emb @ candidate_embs.T).squeeze(0)
best_idx = similarity.argmax().item()
print(f"Best match: {candidates[best_idx]}")
```
Batch Processing
```python
import torch

# Batch encode images
image_paths = [f"image_{i}.jpg" for i in range(100)]
batch_size = 32
all_embeddings = []
with torch.no_grad():
    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i:i + batch_size]
        batch_emb = model.encode(images=batch_paths)
        all_embeddings.append(batch_emb)
all_embeddings = torch.cat(all_embeddings, dim=0)  # (100, 512)
```
Zero-Shot Classification
```python
import torch

# Define classes
classes = ["cat", "dog", "bird", "car", "tree"]
templates = [f"a photo of a {c}" for c in classes]
with torch.no_grad():
    # Encode class descriptions
    class_embs = model.encode(text=templates)  # (5, 512)
    # Classify image
    test_image = "test.jpg"
    image_emb = model.encode(images=test_image)  # (1, 512)
    # Scale cosine similarities by the learned temperature (exp(logit_scale) is 100 for OpenAI checkpoints)
    logits = (image_emb @ class_embs.T).squeeze(0) * model.logit_scale.exp()
probs = logits.softmax(dim=0)
# Get prediction
pred_idx = probs.argmax().item()
print(f"Prediction: {classes[pred_idx]} (confidence: {probs[pred_idx]:.2%})")
```
Training with Contrastive Loss
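```python
import torch
from transformers import CLIPProcessor

model.train()
# Illustrative optimizer and learning rate (not taken from the source)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Prepare batch (random values stand in for preprocessed pixel_values)
images = torch.randn(32, 3, 224, 224)  # 32 images
captions = ["caption for image 1", "caption for image 2", ...]  # 32 captions, one per image
# Tokenize captions
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
text_inputs = processor(text=captions, return_tensors="pt", padding=True)
# Forward pass
outputs = model(
    pixel_values=images,
    input_ids=text_inputs["input_ids"],
    attention_mask=text_inputs["attention_mask"],
    return_loss=True,
)
loss = outputs.loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
```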
Cross-Modal Retrieval
```python
import torch

# Build image database
image_database = ["img1.jpg", "img2.jpg", ...]  # 10000 images
image_embs = []
with torch.no_grad():
    # One image per call keeps the example simple; batching is much faster
    for img in image_database:
        emb = model.encode(images=img)
        image_embs.append(emb)
    image_embs = torch.cat(image_embs, dim=0)  # (10000, 512)
    # Text query
    query = "a sunset over the ocean"
    query_emb = model.encode(text=query)  # (1, 512)
# Retrieve top-k images
k = 5
similarity = (query_emb @ image_embs.T).squeeze(0)  # (10000,)
top_k_indices = similarity.topk(k).indices
print(f"Top {k} matches:")
for idx in top_k_indices:
    print(f"  {image_database[idx]}: {similarity[idx]:.4f}")
```
Model Variants
The implementation includes several CLIP-based model classes:
CLIPTextModel
Text encoder only, for text-only tasks.
CLIPVisionModel
Vision encoder only, for image-only tasks.
CLIPTextModelWithProjection
Text encoder with projection layer.
CLIPVisionModelWithProjection
Vision encoder with projection layer.
CLIPForImageClassification
CLIP vision encoder with classification head.
CLIPModel
Full CLIP model with both encoders (main model).
Configuration
Key Parameters
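- projection_dim: Shared embedding space dimension (512 for the base models)
- text_embed_dim: Text encoder hidden size, taken from text_config.hidden_size (512 for the base models)
- vision_embed_dim: Vision encoder hidden size, taken from vision_config.hidden_size (768 for the base models)
- logit_scale_init_value: Initial value of the log-space temperature parameter (2.6592)
- _attn_implementation: "eager", "sdpa", or "flash_attention_2"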
Model Sizes
CLIP-ViT-B/32:
- Vision: ViT-B/32 (86M params)
- Text: Transformer (63M params)
- Total: ~149M params
CLIP-ViT-L/14:
- Vision: ViT-L/14 (304M params)
- Text: Transformer (123M params)
- Total: ~427M params
Performance Optimization
Gradient Checkpointing
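```python
model.gradient_checkpointing_enable()
```
Reduces activation memory during training by discarding intermediate activations and recomputing them during the backward pass, at the cost of extra compute (each checkpointed layer's forward pass runs twice).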
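Checkpointing changes only when activations are computed, not the gradients themselves. This toy sketch applies `torch.utils.checkpoint.checkpoint` to a small MLP (a stand-in, not the CLIP model) and confirms that the gradients match a normal backward pass; in the Transformers CLIP encoder, gradient_checkpointing_enable() applies the same idea to each encoder layer. `use_reentrant=False` is chosen here because it also works when the input tensor does not require gradients.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
x = torch.randn(8, 16)

# Standard forward/backward: intermediate activations are kept in memory
block.zero_grad()
block(x).sum().backward()
grads_plain = [p.grad.clone() for p in block.parameters()]

# Checkpointed forward: activations inside `block` are recomputed during backward
block.zero_grad()
checkpoint(block, x, use_reentrant=False).sum().backward()
grads_ckpt = [p.grad.clone() for p in block.parameters()]
```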
Mixed Precision Training
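```python
import torch

# bfloat16 autocast needs no gradient scaler; with float16, also use a gradient scaler to avoid underflow
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = model(pixel_values=images, input_ids=input_ids, return_loss=True)
    loss = outputs.loss
loss.backward()
```
Speeds up training and reduces memory usage on GPUs with fast half-precision support.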
Flash Attention
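```python
import torch
from modeling_MMRet_CLIP import CLIPModel

# Load pretrained weights with Flash Attention 2 enabled
# (needs the flash-attn package, a supported NVIDIA GPU, and fp16/bf16 weights)
model = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda")
```
Can reduce attention time and memory on supported GPUs. Because CLIP sequences are short (77 text tokens; 50 to 257 image tokens at 224x224), the end-to-end speedup may be modest.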
Limitations
- Fixed input resolution for vision encoder (typically 224x224)
- Maximum text length of 77 tokens
- Requires paired image-text data for training
- May not generalize well to out-of-distribution data
- Biases inherited from the web-collected image-text pairs used for pretraining and fine-tuning