Implementation: OpenAI CLIP Zero-Shot Classifier
| Knowledge Sources | |
|---|---|
| Domains | Vision, NLP, Zero_Shot_Learning, Classification |
| Last Updated | 2026-02-13 22:00 GMT |
Overview
Pattern documentation for the zeroshot_classifier() function that constructs prompt-ensembled text classifier weights from CLIP text embeddings.
Description
The zeroshot_classifier() function is a user-defined pattern (not part of the CLIP package) demonstrated in the Prompt Engineering notebook (cell 15). It takes a list of class names and prompt templates, and constructs a classification weight matrix by encoding each template-class combination with CLIP's text encoder, normalizing, averaging per class, and normalizing again.
This function uses clip.tokenize() and model.encode_text() internally, operating under torch.no_grad() for efficiency. The result is a weight matrix of shape [embed_dim, num_classes] that can be used for zero-shot classification via a dot product with normalized image features.
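The normalize-average-normalize recipe can be illustrated without CLIP itself. In this minimal sketch, random NumPy vectors stand in for the text encoder's outputs, and the dimensions are illustrative rather than CLIP's actual sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim, num_classes, num_templates = 8, 3, 5

weights = []
for _ in range(num_classes):
    # Stand-in for model.encode_text() on the template-expanded prompts.
    emb = rng.normal(size=(num_templates, embed_dim))
    # L2-normalize each template embedding onto the unit sphere.
    emb /= np.linalg.norm(emb, axis=-1, keepdims=True)
    # Average the unit vectors, then re-normalize: the mean of unit
    # vectors is generally shorter than unit length.
    proto = emb.mean(axis=0)
    proto /= np.linalg.norm(proto)
    weights.append(proto)

# Stack per-class prototypes as columns: shape [embed_dim, num_classes].
W = np.stack(weights, axis=1)
print(W.shape)  # (8, 3)
```

Each column of `W` is a unit vector, so a dot product with a normalized image feature yields a cosine similarity per class.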
Usage
Use this function after defining class names and templates (Prompt Engineering step) and loading the CLIP model. The returned weight matrix is used in the Accuracy Evaluation step.
Code Reference
Source Location
- Repository: OpenAI CLIP
- File: notebooks/Prompt_Engineering_for_ImageNet.ipynb (cell 15)
- Internally calls: clip/clip.py:L205-245 (tokenize), clip/model.py:L343-356 (encode_text)
Interface Specification
```python
def zeroshot_classifier(classnames: List[str], templates: List[str]) -> torch.Tensor:
    """Construct zero-shot classifier weights from text embeddings.

    For each class, encodes all template-expanded texts, L2-normalizes,
    averages, and L2-normalizes again.

    Parameters
    ----------
    classnames : List[str]
        List of class names (e.g., 1000 ImageNet classes).
    templates : List[str]
        List of prompt template strings with a {} placeholder
        (e.g., 80 ImageNet templates).

    Returns
    -------
    torch.Tensor
        Zero-shot classifier weights, shape [embed_dim, num_classes],
        on CUDA. Each column is the L2-normalized ensembled text
        embedding for one class.

    Notes
    -----
    Requires `model` (CLIP) to be available in scope (closure).
    Operates under torch.no_grad().
    """
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]
            texts = clip.tokenize(texts).cuda()
            class_embeddings = model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights
```
Import
```python
# User-defined function — no package import
# Requires: clip, torch, tqdm, and a loaded CLIP model
import clip
import torch
from tqdm import tqdm
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| classnames | List[str] | Yes | List of class names (e.g., 1000 ImageNet classes, disambiguated) |
| templates | List[str] | Yes | List of prompt template strings with {} placeholder (e.g., 80 templates) |
| model | CLIP (closure) | Yes | Loaded CLIP model, available in the enclosing scope |
Outputs
| Name | Type | Description |
|---|---|---|
| zeroshot_weights | torch.Tensor | Classifier weight matrix, shape [embed_dim, num_classes], on CUDA. Each column is the L2-normalized ensembled text prototype for one class. |
Usage Examples
Building ImageNet Classifier Weights
```python
import clip
import torch
from PIL import Image
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Define classes and templates (typically loaded from data)
imagenet_classes = ["tench", "goldfish", "great white shark", ...]  # 1000 classes
imagenet_templates = [
    "a photo of a {}.",
    "a bad photo of the {}.",
    "a sculpture of a {}.",
    # ... 80 templates
]

def zeroshot_classifier(classnames, templates):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]
            texts = clip.tokenize(texts).to(device)
            class_embeddings = model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights

# Build classifier weights
zeroshot_weights = zeroshot_classifier(imagenet_classes, imagenet_templates)
# zeroshot_weights.shape: [512, 1000] for ViT-B/32

# Classify an image
image = preprocess(Image.open("photo.jpg")).unsqueeze(0).to(device)
with torch.no_grad():
    image_features = model.encode_image(image)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    logits = 100.0 * image_features @ zeroshot_weights
    probs = logits.softmax(dim=-1)
```
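The logits feed the notebook's Accuracy Evaluation step, which scores top-1 and top-5 accuracy. A minimal sketch of that metric, using a hypothetical `topk_accuracy` helper in plain Python rather than the notebook's torch.topk-based version:

```python
def topk_accuracy(logits, targets, k=5):
    """Fraction of samples whose true label is among the k highest logits.

    logits: one list of per-class scores per sample.
    targets: true class index per sample.
    """
    hits = 0
    for row, target in zip(logits, targets):
        # Indices of the k largest scores (ties broken by index order).
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += target in topk
    return hits / len(targets)

# Two samples, three classes: the first is correct at top-1, the second only at top-3.
logits = [[0.1, 0.9, 0.0], [0.8, 0.1, 0.1]]
targets = [1, 2]
print(topk_accuracy(logits, targets, k=1))  # 0.5
print(topk_accuracy(logits, targets, k=3))  # 1.0
```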