Implementation: OpenAI CLIP CLIP.forward()
| Knowledge Sources | |
|---|---|
| Domains | Vision, NLP, Contrastive_Learning |
| Last Updated | 2026-02-13 22:00 GMT |
Overview
Method on the CLIP model class that computes temperature-scaled cosine similarity logits between batches of image and text inputs in a single forward pass.
Description
The CLIP.forward() method performs the complete CLIP inference pipeline in one call: it encodes images through the vision encoder, encodes text through the text transformer, L2-normalizes both feature sets, and computes temperature-scaled cosine similarity logits. It returns two tensors: logits_per_image and logits_per_text (transposes of each other).
The learned logit_scale parameter (initialized to ln(1/0.07) ≈ 2.66) is exponentiated and applied as a scalar multiplier to the cosine similarities, controlling the sharpness of the output distribution.
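The effect of this temperature can be sketched in plain Python, independently of the CLIP codebase (the similarity values below are illustrative, not real CLIP outputs):

```python
import math

# CLIP initializes logit_scale to ln(1/0.07); exponentiating recovers
# the multiplier 1/0.07 ≈ 14.29 applied to the cosine similarities.
logit_scale_init = math.log(1 / 0.07)  # ≈ 2.659
scale = math.exp(logit_scale_init)     # ≈ 14.286

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Cosine similarities lie in [-1, 1]; without scaling the softmax is nearly flat.
sims = [0.30, 0.25, 0.10]
flat = softmax(sims)                        # close to uniform
sharp = softmax([scale * s for s in sims])  # much more peaked on the top match

print(flat, sharp)
```

Multiplying by the exponentiated scale before softmax is what turns small differences in cosine similarity into confident per-class probabilities.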
Usage
Call this method when you want the full CLIP similarity computation in a single forward pass. For zero-shot classification, apply softmax(dim=-1) to logits_per_image to get per-class probabilities. If you only need embeddings (not logits), use encode_image() and encode_text() separately.
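The relationship between the two paths can be sketched with stand-in features. This NumPy snippet uses random vectors in place of real encoder outputs (hypothetical features, not CLIP embeddings; `scale` plays the role of `logit_scale.exp()` at initialization), and reproduces the normalize-and-matmul step that forward() performs after encoding:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for encode_image()/encode_text() outputs:
# 2 images and 3 texts embedded in a 512-dim space.
image_features = rng.standard_normal((2, 512))
text_features = rng.standard_normal((3, 512))
scale = np.exp(np.log(1 / 0.07))  # value of logit_scale.exp() at init

# The same L2 normalization and matrix product as in forward():
image_features /= np.linalg.norm(image_features, axis=1, keepdims=True)
text_features /= np.linalg.norm(text_features, axis=1, keepdims=True)

logits_per_image = scale * image_features @ text_features.T  # shape [2, 3]
logits_per_text = logits_per_image.T                         # shape [3, 2]

print(logits_per_image.shape, logits_per_text.shape)
```

If you already have embeddings from encode_image() and encode_text(), applying these two lines of normalization and scaling yields the same logits that forward() would return.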
Code Reference
Source Location
- Repository: OpenAI CLIP
- File: clip/model.py
- Lines: L358-372
Signature
def forward(
    self,
    image: torch.Tensor,
    text: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute scaled cosine similarity between images and texts.

    Parameters
    ----------
    image : torch.Tensor
        Batch of preprocessed images, shape [B_img, 3, n_px, n_px].
    text : torch.Tensor
        Batch of tokenized texts, shape [B_txt, 77].

    Returns
    -------
    logits_per_image : torch.Tensor
        Shape [B_img, B_txt]. Scaled cosine similarity scores.
    logits_per_text : torch.Tensor
        Shape [B_txt, B_img]. Transpose of logits_per_image.
    """
    image_features = self.encode_image(image)
    text_features = self.encode_text(text)

    # Normalize features to unit length
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    # Cosine similarity scaled by the learned temperature
    logit_scale = self.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    return logits_per_image, logits_per_text
Import
import clip
model, preprocess = clip.load("ViT-B/32")
# Then call: logits_per_image, logits_per_text = model(image, text)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| image | torch.Tensor | Yes | Batch of preprocessed images, shape [B_img, 3, n_px, n_px] |
| text | torch.Tensor | Yes | Batch of tokenized texts, shape [B_txt, 77]. Output of clip.tokenize() |
Outputs
| Name | Type | Description |
|---|---|---|
| logits_per_image | torch.Tensor | Shape [B_img, B_txt]. Temperature-scaled cosine similarity. Apply softmax(dim=-1) for class probabilities. |
| logits_per_text | torch.Tensor | Shape [B_txt, B_img]. Transpose of logits_per_image. Apply softmax(dim=-1) for image retrieval probabilities. |
Usage Examples
Zero-Shot Image Classification
import clip
import torch
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# Prepare image
image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
# Prepare text labels
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
# Compute similarity
with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Label probs:", probs)
# e.g., [[0.9927, 0.0041, 0.0032]]
Batch Classification
import clip
import torch
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# Prepare multiple images
images = torch.stack([
    preprocess(Image.open(f"image_{i}.jpg")) for i in range(10)
]).to(device)
# Prepare class descriptions
classes = ["cat", "dog", "bird", "car", "plane"]
text = clip.tokenize([f"a photo of a {c}" for c in classes]).to(device)
with torch.no_grad():
    logits_per_image, _ = model(images, text)
    predictions = logits_per_image.argmax(dim=-1)

for i, pred in enumerate(predictions):
    print(f"Image {i}: {classes[pred]}")
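For the reverse direction, text-to-image retrieval, apply softmax over logits_per_text so each text query gets a distribution over candidate images. A minimal NumPy sketch with a synthetic logits matrix (hypothetical values; in practice this is the second tensor returned by model(images, text)):

```python
import numpy as np

# Hypothetical logits_per_text: 2 text queries over 3 candidate images.
logits_per_text = np.array([
    [21.0, 14.0, 9.0],   # e.g. query "a photo of a cat"
    [8.0, 23.0, 15.0],   # e.g. query "a photo of a dog"
])

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Per-query probability distribution over the candidate images
probs = softmax(logits_per_text)
best = probs.argmax(axis=-1)  # index of best-matching image per query
print(best)  # → [0 1]
```

Because logits_per_text is just the transpose of logits_per_image, this direction requires no extra forward pass.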