
Implementation:OpenAI CLIP Zeroshot Classifier

From Leeroopedia
Knowledge Sources
Domains Vision, NLP, Zero_Shot_Learning, Classification
Last Updated 2026-02-13 22:00 GMT

Overview

Pattern documentation for the zeroshot_classifier() function that constructs prompt-ensembled text classifier weights from CLIP text embeddings.

Description

The zeroshot_classifier() function is a user-defined pattern (not part of the CLIP package) demonstrated in the Prompt Engineering notebook (cell 15). It takes a list of class names and prompt templates, and constructs a classification weight matrix by encoding each template-class combination with CLIP's text encoder, normalizing, averaging per class, and normalizing again.

This function uses clip.tokenize() and model.encode_text() internally, operating under torch.no_grad() for efficiency. The result is a weight matrix of shape [embed_dim, num_classes] that can be used for zero-shot classification via dot product.
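The normalize-average-normalize ensembling step can be sketched in plain NumPy, with random vectors standing in for the output of CLIP's text encoder (a toy illustration; the shapes and names are illustrative, not part of the CLIP API):

```python
import numpy as np

rng = np.random.default_rng(0)
num_templates, embed_dim = 80, 512

# Random vectors stand in for encode_text() output for one class.
class_embeddings = rng.normal(size=(num_templates, embed_dim))

# L2-normalize each template embedding ...
class_embeddings /= np.linalg.norm(class_embeddings, axis=-1, keepdims=True)

# ... average across templates ...
class_embedding = class_embeddings.mean(axis=0)

# ... and renormalize: the mean of unit vectors is not itself unit-length.
class_embedding /= np.linalg.norm(class_embedding)

print(class_embedding.shape)  # (512,)
```

The final renormalization is what makes each column of the weight matrix a valid unit-norm prototype, so that a dot product with normalized image features is a cosine similarity.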

Usage

Use this function after defining class names and templates (Prompt Engineering step) and loading the CLIP model. The returned weight matrix is used in the Accuracy Evaluation step.

Code Reference

Source Location

  • Repository: OpenAI CLIP
  • File: notebooks/Prompt_Engineering_for_ImageNet.ipynb (cell 15)
  • Internally calls: clip/clip.py:L205-245 (tokenize), clip/model.py:L343-356 (encode_text)

Interface Specification

def zeroshot_classifier(classnames: List[str], templates: List[str]) -> torch.Tensor:
    """Construct zero-shot classifier weights from text embeddings.

    For each class, encodes all template-expanded texts, L2-normalizes,
    averages, and L2-normalizes again.

    Parameters
    ----------
    classnames : List[str]
        List of class names (e.g., 1000 ImageNet classes).

    templates : List[str]
        List of prompt template strings with {} placeholder
        (e.g., 80 ImageNet templates).

    Returns
    -------
    torch.Tensor
        Zero-shot classifier weights, shape [embed_dim, num_classes],
        on CUDA. Each column is the L2-normalized ensembled text
        embedding for one class.

    Notes
    -----
    Requires `model` (CLIP) to be available in scope (closure).
    Operates under torch.no_grad().
    """
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]
            texts = clip.tokenize(texts).cuda()
            class_embeddings = model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights

Import

# User-defined function — no package import
# Requires: clip, torch, tqdm, and a loaded CLIP model
import clip
import torch
from tqdm import tqdm

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| classnames | List[str] | Yes | List of class names (e.g., 1000 ImageNet classes, disambiguated) |
| templates | List[str] | Yes | List of prompt template strings with {} placeholder (e.g., 80 templates) |
| model | CLIP (closure) | Yes | Loaded CLIP model, available in the enclosing scope |

Outputs

| Name | Type | Description |
|------|------|-------------|
| zeroshot_weights | torch.Tensor | Classifier weight matrix, shape [embed_dim, num_classes], on CUDA. Each column is the L2-normalized ensembled text prototype for one class. |
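To make the shape contract concrete, here is a small NumPy mock (random stand-ins, not real CLIP outputs) showing how the weight matrix is consumed downstream:

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim, num_classes, batch = 512, 1000, 4

# Mock classifier weights: each column is a unit-norm class prototype.
W = rng.normal(size=(embed_dim, num_classes))
W /= np.linalg.norm(W, axis=0, keepdims=True)

# Mock L2-normalized image features, shape [batch, embed_dim].
x = rng.normal(size=(batch, embed_dim))
x /= np.linalg.norm(x, axis=-1, keepdims=True)

# Matrix product yields one cosine-similarity logit per class.
logits = x @ W
print(logits.shape)  # (4, 1000)
```

Because both operands are unit-normalized, every logit lies in [-1, 1] before the temperature scaling (the `100.0 *` factor in the usage example below) is applied.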

Usage Examples

Building ImageNet Classifier Weights

import clip
import torch
from PIL import Image
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Define classes and templates (typically loaded from data)
imagenet_classes = ["tench", "goldfish", "great white shark", ...]  # 1000 classes
imagenet_templates = [
    "a photo of a {}.",
    "a bad photo of the {}.",
    "a sculpture of a {}.",
    # ... 80 templates
]

def zeroshot_classifier(classnames, templates):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]
            texts = clip.tokenize(texts).to(device)
            class_embeddings = model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights

# Build classifier weights
zeroshot_weights = zeroshot_classifier(imagenet_classes, imagenet_templates)
# zeroshot_weights.shape: [512, 1000] for ViT-B/32

# Classify an image
image = preprocess(Image.open("photo.jpg")).unsqueeze(0).to(device)
with torch.no_grad():
    image_features = model.encode_image(image)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    logits = 100.0 * image_features @ zeroshot_weights
    probs = logits.softmax(dim=-1)
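Reading off predictions from `probs` is a top-k lookup; in the torch code above, `probs.topk(5)` does this directly. The idea can be sketched self-contained in NumPy on mock logits (running the real pipeline needs a GPU and the full ImageNet class list):

```python
import numpy as np

rng = np.random.default_rng(0)
classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead"]

# Mock logits for one image (real code: 100.0 * image_features @ zeroshot_weights).
logits = rng.normal(size=len(classnames))

# Numerically stable softmax over classes, then take the top entries.
probs = np.exp(logits - logits.max())
probs /= probs.sum()
top = np.argsort(probs)[::-1][:3]
for i in top:
    print(f"{classnames[i]}: {probs[i]:.3f}")
```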

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
