Implementation:FlagOpen FlagEmbedding MMRet CLIP Model

Knowledge Sources
Domains Computer Vision, Multi-Modal Learning, Vision-Language Models, Neural Networks
Last Updated 2026-02-09 00:00 GMT

Overview

An extended CLIP model implementation for multi-modal retrieval that supports image, text, and image-text combined encoding for vision-language tasks.

Description

This implementation provides a comprehensive CLIP-based architecture adapted from HuggingFace Transformers with extensions for multi-modal retrieval. The model inherits the standard CLIP architecture (vision transformer + text transformer with contrastive learning) and adds specialized encoding methods for flexible input handling: image-only encoding for visual search, text-only encoding for semantic search, and multi-modal encoding that combines image and text representations for composed queries.

The implementation includes the complete CLIP architecture: vision and text encoders, projection layers into a shared embedding space, selectable attention backends (eager, SDPA, Flash Attention 2), gradient checkpointing for memory efficiency, and a contrastive loss with a learnable temperature. On top of this it adds a data_process helper that detects whether images, text, or both were passed, L2 normalization of the output embeddings (so dot products equal cosine similarities), CLIPProcessor-based preprocessing, and support for both training and inference.

The model supports three key use cases: single-modality retrieval (image→image or text→text), cross-modal retrieval (image→text or text→image), and composed retrieval (image+text→image), making it suitable for diverse vision-language applications.

Usage

Use this model for multi-modal retrieval tasks where you need to encode images, text, or combinations of both into a shared embedding space for similarity search or cross-modal matching.

Code Reference

Source Location

Signature

class CLIPModel(CLIPPreTrainedModel):
    def __init__(self, config: CLIPConfig)

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor

    def encode_image(self, images) -> torch.Tensor
    def encode_text(self, text) -> torch.Tensor
    def encode_multimodal(self, images, text) -> torch.Tensor
    def encode(self, images=None, text=None) -> torch.Tensor
    def data_process(self, images=None, text=None) -> Tuple

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]

Import

from modeling_MMRet_CLIP import CLIPModel
from transformers import CLIPProcessor

I/O Contract

Inputs

Name Type Required Description
input_ids torch.LongTensor No Text token IDs (batch_size, seq_len)
pixel_values torch.FloatTensor No Image tensors (batch_size, 3, height, width)
attention_mask torch.Tensor No Text attention mask
position_ids torch.LongTensor No Position IDs for text tokens
images str/list/PIL.Image No Image path(s) or PIL Image(s) for encoding
text str/list No Text string(s) for encoding
return_loss bool No Compute contrastive loss (default: False)
output_attentions bool No Return attention weights
output_hidden_states bool No Return hidden states
return_dict bool No Return ModelOutput object

Outputs

Name Type Description
loss Optional[torch.FloatTensor] Contrastive loss (if return_loss=True)
logits_per_image torch.FloatTensor Image-text similarity scores (batch_size_image, batch_size_text)
logits_per_text torch.FloatTensor Text-image similarity scores (batch_size_text, batch_size_image)
text_embeds torch.FloatTensor Text embeddings (batch_size, projection_dim)
image_embeds torch.FloatTensor Image embeddings (batch_size, projection_dim)
embeddings torch.Tensor Normalized embeddings from encode() methods

Architecture Components

Vision Encoder

Architecture:

  • Vision Transformer (ViT)
  • Patch embedding: Image → patches → embeddings
  • Class token + positional embeddings
  • Transformer encoder layers
  • LayerNorm on the embeddings before the encoder, and a post-LayerNorm on the pooled class token

Configuration (typical):

  • Patch size: 14x14, 16x16, or 32x32 (e.g., ViT-B/32)
  • Hidden size: 768 (base) or 1024 (large)
  • Number of layers: 12 (base) or 24 (large)
  • Attention heads: 12 (base) or 16 (large)
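The patch size fixes how many tokens the vision transformer processes. As a quick check of the arithmetic, the hypothetical helper below (not part of the source file) counts tokens for a square input: one per patch plus the class token.

```python
def vit_num_tokens(image_size: int, patch_size: int) -> int:
    """Number of tokens the vision transformer sees: one per patch plus the class token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1  # +1 for the class token


for patch in (32, 16, 14):
    print(f"224x224 image, {patch}x{patch} patches -> {vit_num_tokens(224, patch)} tokens")
# 224x224 image, 32x32 patches -> 50 tokens
# 224x224 image, 16x16 patches -> 197 tokens
# 224x224 image, 14x14 patches -> 257 tokens
```

Smaller patches give longer sequences, which is why ViT-L/14 is considerably more expensive per image than ViT-B/32.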

Text Encoder

Architecture:

  • Causal self-attention mask (each token attends only to itself and earlier tokens)
  • Token + position embeddings
  • Transformer encoder layers
  • Final layer normalization
  • EOS token pooling

Configuration (typical):

  • Vocab size: 49408
  • Hidden size: 512 (base) or 768 (large)
  • Number of layers: 12
  • Attention heads: 8 (base) or 12 (large)
  • Max position embeddings: 77
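EOS pooling means the text representation is the final hidden state at the end-of-text position. Under the causal mask, that is the token that has attended to the whole caption. The sketch below is an illustration, not the file's code: `eos_pool` is a hypothetical helper, 49406/49407 are the start/end-of-text ids of the OpenAI CLIP vocabulary, and the other ids are placeholders. It takes the first EOS occurrence so that padding with the end-of-text token (if the tokenizer pads that way) does not shift the pooled position.

```python
import torch


def eos_pool(last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    """Return the hidden state at the first end-of-text token of each sequence."""
    # Boolean mask -> int; argmax returns the index of the first maximal value
    eos_positions = (input_ids == eos_token_id).int().argmax(dim=-1)
    batch_index = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
    return last_hidden_state[batch_index, eos_positions]  # (batch, hidden)


EOS = 49407  # <|endoftext|> in the OpenAI CLIP vocabulary
input_ids = torch.tensor([
    [49406, 320, 1125, EOS, EOS],  # shorter caption, padded with EOS
    [49406, 320, 1125, 539, EOS],
])
hidden = torch.randn(2, 5, 512)  # stand-in for last_hidden_state (base text width)
pooled = eos_pool(hidden, input_ids, EOS)
print(pooled.shape)  # torch.Size([2, 512])
```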

Projection Layers

  • visual_projection: Projects vision_embed_dim → projection_dim
  • text_projection: Projects text_embed_dim → projection_dim
  • projection_dim: Typically 512 or 768
  • No bias: Projections are linear without bias terms
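The two projections are plain bias-free linear maps from each encoder's width to the shared dimension. A minimal sketch with ViT-B/32 sizes, using random tensors as stand-ins for the pooled encoder outputs:

```python
import torch
from torch import nn

# Sizes for ViT-B/32: vision width 768, text width 512, shared space 512
vision_embed_dim, text_embed_dim, projection_dim = 768, 512, 512

# Linear maps into the shared space, without bias terms
visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False)
text_projection = nn.Linear(text_embed_dim, projection_dim, bias=False)

pooled_image = torch.randn(4, vision_embed_dim)  # stand-in for pooled CLS outputs
pooled_text = torch.randn(4, text_embed_dim)     # stand-in for pooled EOS outputs

image_embeds = visual_projection(pooled_image)
text_embeds = text_projection(pooled_text)
print(image_embeds.shape, text_embeds.shape)  # torch.Size([4, 512]) torch.Size([4, 512])
```

After projection both modalities have the same width, so their dot products are defined even though the encoders differ in size.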

Logit Scale

  • Learnable temperature parameter
  • Initialized to log(1/0.07) ≈ 2.66
  • Scales similarity scores for contrastive learning
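The stored parameter lives in log space and is exponentiated before use, which keeps the multiplier positive under gradient updates. The arithmetic for the initial value (the default logit_scale_init_value of 2.6592 is this number truncated):

```python
import math

temperature = 0.07
logit_scale_init = math.log(1 / temperature)  # value stored in the learnable parameter
multiplier = math.exp(logit_scale_init)        # factor applied to cosine similarities

print(f"logit_scale = {logit_scale_init:.4f}")  # logit_scale = 2.6593
print(f"multiplier = {multiplier:.2f}")         # multiplier = 14.29

# Cosine similarities lie in [-1, 1], so initial logits lie in [-14.29, 14.29]
```

The original CLIP training clipped this multiplier at 100, and the pretrained OpenAI checkpoints sit at roughly that value.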

Encoding Methods

Image Encoding

def encode_image(self, images):
    embeddings = self.get_image_features(images)
    embeddings = F.normalize(embeddings, dim=-1)
    return embeddings

Process:
  1. Vision transformer encoding
  2. CLS token extraction
  3. Visual projection
  4. L2 normalization

Text Encoding

def encode_text(self, text):
    embeddings = self.get_text_features(**text)
    embeddings = F.normalize(embeddings, dim=-1)
    return embeddings

Process:
  1. Text transformer encoding
  2. EOS token pooling
  3. Text projection
  4. L2 normalization

Multi-Modal Encoding

def encode_multimodal(self, images, text):
    text_embeddings = self.get_text_features(**text)
    image_embeddings = self.get_image_features(images)
    embeddings = text_embeddings + image_embeddings
    embeddings = F.normalize(embeddings, dim=-1)
    return embeddings

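Process:
  1. Encode text and image separately with get_text_features and get_image_features, both of which already apply their projection layers
  2. Sum the projected embeddings (neither is normalized before the sum)
  3. L2-normalize the sum

This enables composed queries (e.g., a reference image plus a text modification).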
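Because normalization happens only after the sum, the modality whose projected feature has the larger norm pulls the fused query toward itself. The toy sketch below uses random tensors as stand-ins for the outputs of get_text_features and get_image_features; `compose` is a hypothetical helper mirroring the three lines of encode_multimodal.

```python
import torch
import torch.nn.functional as F


def compose(text_features: torch.Tensor, image_features: torch.Tensor) -> torch.Tensor:
    """Sum projected (unnormalized) features, then L2-normalize, as encode_multimodal does."""
    return F.normalize(text_features + image_features, dim=-1)


torch.manual_seed(0)
text_feat = torch.randn(1, 512)
image_feat = torch.randn(1, 512)

fused = compose(text_feat, image_feat)
print(fused.norm(dim=-1))  # unit length

# Scale up the image feature: the fused query now points almost along the image
fused_big_image = compose(text_feat, 10 * image_feat)
cos_to_image = F.cosine_similarity(fused_big_image, image_feat).item()
cos_to_text = F.cosine_similarity(fused_big_image, text_feat).item()
print(cos_to_image > cos_to_text)  # True
```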

Unified Encoding Interface

def encode(self, images=None, text=None):
    images, text, data_type = self.data_process(images, text)
    if data_type == "images":
        return self.encode_image(images)
    elif data_type == "text":
        return self.encode_text(text)
    elif data_type == "multimodal":
        return self.encode_multimodal(images, text)

Dispatches on the data_type returned by data_process: images only, text only, or both (multimodal).

Data Processing

Input Format Handling

The data_process method handles various input formats:

Image Inputs:

  • Single path string: "/path/to/image.jpg"
  • List of paths: ["/path/1.jpg", "/path/2.jpg"]
  • PIL.Image object(s)
  • Automatically converts to pixel_values tensor

Text Inputs:

  • Single string: "a photo of a cat"
  • List of strings: ["cat", "dog", "bird"]
  • Automatically tokenizes and pads

Multi-Modal Inputs:

  • Images and text must be same type (both str or both list)
  • Lists must have same length
  • Combined for composed queries
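The rules above can be restated as a small dispatch function. This is a hypothetical sketch, not the file's data_process: it covers only path and string inputs (no PIL images, no preprocessing), and raising an error when both inputs are missing is an assumption, since that case is not documented.

```python
from typing import List, Optional, Union

Inputs = Optional[Union[str, List[str]]]


def infer_data_type(images: Inputs = None, text: Inputs = None) -> str:
    """Hypothetical restatement of the input rules (paths and strings only)."""
    if images is None and text is None:
        # Assumption: the source does not document this case
        raise ValueError("provide images, text, or both")
    if text is None:
        return "images"
    if images is None:
        return "text"
    # Multi-modal: both inputs must be the same kind
    if type(images) is not type(text):
        raise TypeError("images and text must both be str or both be list")
    # Lists must pair one image with one text
    if isinstance(images, list) and len(images) != len(text):
        raise ValueError("images and text lists must have the same length")
    return "multimodal"


print(infer_data_type(images="/path/to/image.jpg"))                  # images
print(infer_data_type(text=["cat", "dog", "bird"]))                  # text
print(infer_data_type(images="blue_shirt.jpg", text="make it red"))  # multimodal
```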

Preprocessing Pipeline

from modeling_MMRet_CLIP import CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# Set the processor used by data_process
model.set_processor("openai/clip-vit-base-patch32")

# Automatic preprocessing
images, text, data_type = model.data_process(
    images="/path/to/image.jpg",
    text="a photo of a cat"
)
# Returns:
# - images: preprocessed pixel_values
# - text: tokenized input_ids + attention_mask
# - data_type: "multimodal"

Contrastive Learning

Loss Function

import torch
import torch.nn.functional as F


def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    # Positive pairs sit on the diagonal: row i matches column i
    caption_loss = F.cross_entropy(
        similarity,
        torch.arange(len(similarity), device=similarity.device)
    )
    image_loss = F.cross_entropy(
        similarity.t(),
        torch.arange(len(similarity), device=similarity.device)
    )
    return (caption_loss + image_loss) / 2.0

Symmetric cross-entropy loss:

  • Image-to-text: Predict correct text for each image
  • Text-to-image: Predict correct image for each text
  • Diagonal elements are positive pairs
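A worked example makes the scale of the loss concrete. The snippet re-states clip_loss so it runs on its own and evaluates it on two toy similarity matrices: one that scores every pair equally (the loss equals ln N, chance level) and one that scores matching pairs far higher (the loss approaches 0).

```python
import math

import torch
import torch.nn.functional as F


def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    targets = torch.arange(len(similarity), device=similarity.device)  # positives on the diagonal
    caption_loss = F.cross_entropy(similarity, targets)    # each row vs. all columns
    image_loss = F.cross_entropy(similarity.t(), targets)  # each column vs. all rows
    return (caption_loss + image_loss) / 2.0


n = 4
uninformative = torch.zeros(n, n)  # every pair scored the same
confident = 10.0 * torch.eye(n)    # matching pairs scored far higher

print(f"{clip_loss(uninformative).item():.4f} vs ln(4) = {math.log(n):.4f}")  # 1.3863 vs ln(4) = 1.3863
print(clip_loss(confident).item())  # close to 0
```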

Similarity Computation

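# Excerpt from CLIPModel.forward: L2-normalize the projected embeddings
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

# Scaled cosine similarity (logit_scale is stored in log space)
logit_scale = self.logit_scale.exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.t()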
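A standalone version of the same computation, with the logit scale passed in explicitly instead of read from `self`, makes the shapes easy to check. `clip_logits` is a hypothetical helper; F.normalize is equivalent to the explicit division by the L2 norm apart from a small epsilon guard against zero vectors.

```python
import torch
import torch.nn.functional as F


def clip_logits(image_embeds: torch.Tensor, text_embeds: torch.Tensor, logit_scale: torch.Tensor):
    """Return (logits_per_image, logits_per_text) from projected embeddings."""
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    logits_per_text = text_embeds @ image_embeds.t() * logit_scale.exp()  # (n_text, n_image)
    return logits_per_text.t(), logits_per_text


torch.manual_seed(0)
logit_scale = torch.tensor(2.6592)  # initial value; exp() is about 14.3
logits_per_image, logits_per_text = clip_logits(torch.randn(3, 512), torch.randn(5, 512), logit_scale)
print(logits_per_image.shape, logits_per_text.shape)  # torch.Size([3, 5]) torch.Size([5, 3])
```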

Attention Mechanisms

Eager Attention (Default)

Standard PyTorch attention implementation with explicit Q, K, V computation.
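The core of the eager path is softmax(QKᵀ/√d)V with the full attention matrix materialized. The sketch below is not the file's CLIPAttention class: it omits the Q/K/V projections and dropout, and only checks the core computation (with and without the text encoder's causal mask) against PyTorch's fused `F.scaled_dot_product_attention`.

```python
import math

import torch
import torch.nn.functional as F


def eager_attention(q, k, v, causal: bool = False):
    """softmax(Q K^T / sqrt(head_dim)) V with every intermediate materialized."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # (batch, heads, q_len, k_len)
    if causal:
        # Text encoder: each token attends only to itself and earlier tokens
        q_len, k_len = scores.shape[-2:]
        allowed = torch.tril(torch.ones(q_len, k_len, device=q.device)).bool()
        scores = scores.masked_fill(~allowed, float("-inf"))
    weights = scores.softmax(dim=-1)  # the full attention matrix is stored here
    return weights @ v


# Base text encoder shape: 8 heads of width 64 (= 512 / 8), 77 positions
torch.manual_seed(0)
q, k, v = (torch.randn(2, 8, 77, 64) for _ in range(3))

for causal in (False, True):
    eager = eager_attention(q, k, v, causal=causal)
    fused = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    print(f"causal={causal}: max abs difference {(eager - fused).abs().max().item():.2e}")
```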

SDPA (Scaled Dot-Product Attention)

# Dropout is applied only in training mode
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states,
    attn_mask=attn_mask,
    dropout_p=self.dropout if self.training else 0.0,
    scale=self.scale
)

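Uses PyTorch's fused scaled_dot_product_attention, which dispatches to an optimized kernel (FlashAttention, memory-efficient, or math backend) and is usually faster than eager mode.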

Flash Attention 2

attn_output = _flash_attention_forward(
    query_states, key_states, value_states,
    attention_mask,
    q_len,
    dropout=dropout_rate,
    is_causal=causal_attention_mask is not None
)

Memory-efficient exact attention that avoids materializing the full attention matrix; the benefit grows with sequence length.

Usage Examples

Basic Image-Text Matching

import torch
from modeling_MMRet_CLIP import CLIPModel
from PIL import Image

# Load model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.set_processor("openai/clip-vit-base-patch32")
model.eval()

with torch.no_grad():
    # Encode image
    image = Image.open("cat.jpg")
    image_emb = model.encode(images=image)  # (1, 512)

    # Encode text
    texts = ["a cat", "a dog", "a bird"]
    text_emb = model.encode(text=texts)  # (3, 512)

    # Embeddings are L2-normalized, so the dot product is the cosine similarity
    similarity = (image_emb @ text_emb.T).squeeze()  # (3,)

    # Raw cosine similarities are close together; scale by the learned
    # temperature before the softmax, as CLIPModel.forward does
    probs = (model.logit_scale.exp() * similarity).softmax(dim=0)

print(f"Probabilities: {probs}")

Multi-Modal Query

# Reference image + text modification
ref_image = "blue_shirt.jpg"
modification = "make it red"

# Encode composed query
query_emb = model.encode(images=ref_image, text=modification)  # (1, 512)

# Encode candidate images
candidates = ["red_shirt.jpg", "blue_pants.jpg", "green_shirt.jpg"]
candidate_embs = model.encode(images=candidates)  # (3, 512)

# Find best match
similarity = (query_emb @ candidate_embs.T).squeeze()
best_idx = similarity.argmax()
print(f"Best match: {candidates[best_idx]}")

Batch Processing

# Batch encode images
image_paths = [f"image_{i}.jpg" for i in range(100)]
batch_size = 32

all_embeddings = []
for i in range(0, len(image_paths), batch_size):
    batch_paths = image_paths[i:i+batch_size]
    batch_emb = model.encode(images=batch_paths)
    all_embeddings.append(batch_emb)

all_embeddings = torch.cat(all_embeddings, dim=0)  # (100, 512)

Zero-Shot Classification

# Define classes
classes = ["cat", "dog", "bird", "car", "tree"]
templates = [f"a photo of a {c}" for c in classes]

# Encode class descriptions
class_embs = model.encode(text=templates)  # (5, 512)

# Classify image
test_image = "test.jpg"
image_emb = model.encode(images=test_image)  # (1, 512)

# Compute similarity
logits = (image_emb @ class_embs.T).squeeze() * 100  # 100 ≈ exp(logit_scale) of pretrained OpenAI CLIP; sharpens the softmax
probs = logits.softmax(dim=0)

# Get prediction
pred_idx = probs.argmax()
print(f"Prediction: {classes[pred_idx]} (confidence: {probs[pred_idx]:.2%})")

Training with Contrastive Loss

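import torch
from transformers import CLIPProcessor

# Prepare batch
images = torch.randn(32, 3, 224, 224)  # 32 preprocessed images (random stand-ins)
captions = ["caption for image 1", "caption for image 2", ...]  # 32 captions

# Tokenize captions (the text encoder accepts at most 77 tokens)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
text_inputs = processor(text=captions, return_tensors="pt", padding=True, truncation=True)

# Illustrative optimizer settings, not tuned values
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model.train()

# Forward pass
outputs = model(
    pixel_values=images,
    input_ids=text_inputs["input_ids"],
    attention_mask=text_inputs["attention_mask"],
    return_loss=True
)

loss = outputs.loss
optimizer.zero_grad()
loss.backward()
optimizer.step()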

Cross-Modal Retrieval

# Build image database
image_database = ["img1.jpg", "img2.jpg", ...]  # 10000 images
image_embs = []
for img in image_database:
    emb = model.encode(images=img)
    image_embs.append(emb)
image_embs = torch.cat(image_embs, dim=0)  # (10000, 512)

# Text query
query = "a sunset over the ocean"
query_emb = model.encode(text=query)  # (1, 512)

# Retrieve top-k images
k = 5
similarity = (query_emb @ image_embs.T).squeeze()  # (10000,)
top_k_indices = similarity.topk(k).indices

print(f"Top {k} matches:")
for idx in top_k_indices:
    print(f"  {image_database[idx]}: {similarity[idx]:.4f}")

Model Variants

The implementation includes several CLIP-based model classes:

CLIPTextModel

Text encoder only, for text-only tasks.

CLIPVisionModel

Vision encoder only, for image-only tasks.

CLIPTextModelWithProjection

Text encoder with projection layer.

CLIPVisionModelWithProjection

Vision encoder with projection layer.

CLIPForImageClassification

CLIP vision encoder with classification head.

CLIPModel

Full CLIP model with both encoders (main model).

Configuration

Key Parameters

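  • projection_dim: Shared embedding space dimension (512 for ViT-B/32)
  • text_embed_dim: Text encoder hidden size (512 for ViT-B/32)
  • vision_embed_dim: Vision encoder hidden size (768 for ViT-B/32)
  • logit_scale_init_value: Initial value of logit_scale, log(1/0.07) ≈ 2.6592
  • _attn_implementation: "eager", "sdpa", or "flash_attention_2" (normally selected through the attn_implementation argument of from_pretrained)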

Model Sizes

CLIP-ViT-B/32:

  • Vision: ViT-B/32 (~88M params)
  • Text: Transformer (~63M params)
  • Total: ~151M params

CLIP-ViT-L/14:

  • Vision: ViT-L/14 (304M params)
  • Text: Transformer (123M params)
  • Total: ~427M params
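These totals can be reproduced from the layer shapes listed under Architecture Components. The helpers below are hypothetical and assume the Hugging Face CLIP layout: biased attention and MLP projections with a 4x MLP width, LayerNorms with weight and bias (including the vision tower's pre- and post-LayerNorm), a bias-free patch convolution, bias-free output projections, 224x224 inputs, and a single scalar logit_scale.

```python
def encoder_params(hidden: int, mlp: int, layers: int) -> int:
    attn = 4 * (hidden * hidden + hidden)            # q, k, v, out projections, with bias
    norms = 2 * 2 * hidden                           # two LayerNorms (weight + bias)
    ffn = hidden * mlp + mlp + mlp * hidden + hidden  # fc1 and fc2, with bias
    return layers * (attn + norms + ffn)


def text_params(hidden: int, mlp: int, layers: int, proj_dim: int,
                vocab: int = 49408, max_positions: int = 77) -> int:
    return (vocab * hidden                # token embeddings
            + max_positions * hidden      # position embeddings
            + encoder_params(hidden, mlp, layers)
            + 2 * hidden                  # final LayerNorm
            + hidden * proj_dim)          # text_projection, no bias


def vision_params(hidden: int, mlp: int, layers: int, proj_dim: int,
                  patch: int, image_size: int = 224) -> int:
    positions = (image_size // patch) ** 2 + 1  # patches + class token
    return (3 * patch * patch * hidden         # patch-embedding conv, no bias
            + hidden                           # class embedding
            + positions * hidden               # position embeddings
            + 2 * 2 * hidden                   # pre- and post-LayerNorm
            + encoder_params(hidden, mlp, layers)
            + hidden * proj_dim)               # visual_projection, no bias


b32 = (text_params(512, 2048, 12, 512), vision_params(768, 3072, 12, 512, patch=32))
l14 = (text_params(768, 3072, 12, 768), vision_params(1024, 4096, 24, 768, patch=14))
for name, (text, vision) in (("ViT-B/32", b32), ("ViT-L/14", l14)):
    total = text + vision + 1  # +1 for logit_scale
    print(f"{name}: text {text:,}  vision {vision:,}  total {total:,}")
# ViT-B/32: text 63,428,096  vision 87,849,216  total 151,277,313
# ViT-L/14: text 123,650,304  vision 303,966,208  total 427,616,513
```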

Performance Optimization

Gradient Checkpointing

model.gradient_checkpointing_enable()

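Recomputes activations during the backward pass instead of storing them, which reduces training memory at the cost of extra compute (often around 20% slower training, depending on model and hardware).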
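The trade-off can be seen on a toy block with torch.utils.checkpoint.checkpoint, the PyTorch utility behind gradient checkpointing in Transformers (for CLIP it wraps each encoder layer while the model is in training mode). The module and sizes below are arbitrary stand-ins, not the CLIP layers themselves.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))  # toy stand-in for one layer
x = torch.randn(8, 64, requires_grad=True)

# Plain backward: activations inside `block` were stored during the forward pass
block(x).sum().backward()
grad_plain = x.grad.clone()
x.grad = None

# Checkpointed: only the block's input is kept; its activations are recomputed in backward
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(grad_plain, grad_ckpt))  # True
```

The gradients are unchanged; what changes is that the intermediate activations of each checkpointed block are not held in memory between the forward and backward passes.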

Mixed Precision Training

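import torch

scaler = torch.amp.GradScaler("cuda")  # rescales the loss to avoid float16 gradient underflow
with torch.autocast(device_type="cuda", dtype=torch.float16):
    outputs = model(pixel_values=images, input_ids=input_ids, return_loss=True)
    loss = outputs.loss
# Backward and optimizer step through the scaler
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

Runs most matrix multiplications in reduced precision, which typically speeds up training and lowers activation memory on GPUs with tensor cores.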

Flash Attention

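import torch
from modeling_MMRet_CLIP import CLIPModel

# from_pretrained keeps the pretrained weights; CLIPModel(config) would start from random ones
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float16,
                                  attn_implementation="flash_attention_2").to("cuda")

Requires the flash-attn package, a supported CUDA GPU, and fp16 or bf16 weights. Speedups are largest for long sequences; CLIP's sequences are short (77 text tokens, 50 to 257 image tokens), so the gain may be modest.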

Limitations

  • Fixed input resolution for vision encoder (typically 224x224)
  • Maximum text length of 77 tokens
  • Requires paired image-text data for training
  • May not generalize well to out-of-distribution data
  • Inherits biases from its web-scale image-text training data
