Implementation:Predibase Lorax CLIP Vision Encoder
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference, Vision_Language_Model |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides a tensor-parallel implementation of the CLIP (Contrastive Language-Image Pre-Training) vision and text encoders for use as vision components within vision-language models (VLMs) served by the LoRAX framework.
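To make the tensor-parallel layout concrete: in the usual column/row split (as popularized by Megatron-LM), a column-parallel layer shards its weight by output features, a row-parallel layer shards by input features, and the row-parallel partial results are summed across ranks. The sketch below checks that equivalence for a two-layer MLP on a single process. It uses plain `torch.nn.Linear` layers and a Python loop in place of real ranks, so it illustrates the idea under those assumptions and is not LoRAX's `TensorParallelColumnLinear`/`TensorParallelRowLinear` code. The sizes are arbitrary.

```python
import torch

torch.manual_seed(0)

# Toy sizes (arbitrary): hidden size, MLP intermediate size, simulated tensor-parallel world size.
hidden, intermediate, world_size = 8, 32, 4

x = torch.randn(2, hidden)  # a batch of 2 token vectors
fc1 = torch.nn.Linear(hidden, intermediate)  # up-projection
fc2 = torch.nn.Linear(intermediate, hidden)  # down-projection
act = torch.nn.GELU()

with torch.no_grad():
    # Reference: the unsharded MLP on a single device.
    reference = fc2(act(fc1(x)))

    # Column-parallel up-projection: each simulated rank owns a slice of fc1's output rows.
    w1_shards = fc1.weight.chunk(world_size, dim=0)  # each (intermediate / world_size, hidden)
    b1_shards = fc1.bias.chunk(world_size, dim=0)
    # Row-parallel down-projection: each rank owns the matching slice of fc2's input columns.
    w2_shards = fc2.weight.chunk(world_size, dim=1)  # each (hidden, intermediate / world_size)

    partial_outputs = []
    for w1, b1, w2 in zip(w1_shards, b1_shards, w2_shards):
        # The activation is elementwise, so each rank can apply it to its own shard.
        h_local = act(x @ w1.T + b1)
        partial_outputs.append(h_local @ w2.T)

    # Summing the partials stands in for the all-reduce; the row-parallel bias is added once.
    sharded = torch.stack(partial_outputs).sum(dim=0) + fc2.bias

print(torch.allclose(reference, sharded, atol=1e-5))  # True
```

No collective is needed between the two layers: the column-parallel output shard is exactly the input shard the row-parallel layer expects, so only one reduction happens per MLP.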
Description
This module implements OpenAI's CLIP architecture with tensor parallelism, primarily used as the vision encoder component for multi-modal models such as LLaVA. It includes both the vision transformer and text transformer, though the vision encoder is the primary component used in the LoRAX serving pipeline.
Key classes:
- CLIPVisionEmbeddings -- Converts pixel values into patch embeddings using a 2D convolution whose kernel size and stride both equal the patch size. Prepends a learnable class (CLS) embedding to the patch sequence and adds learned position embeddings via TensorParallelEmbedding.
- CLIPTextEmbeddings -- Token and position embeddings for the text encoder. Uses standard nn.Embedding for token and position lookups.
- CLIPAttention -- Multi-headed attention using a fused QKV projection via TensorParallelColumnLinear.load_multi (which loads the separate q_proj, k_proj, and v_proj weights into a single tensor) and an output projection via TensorParallelRowLinear. Supports both causal and non-causal attention masks.
- CLIPMLP -- Two-layer feed-forward network with a configurable activation, using TensorParallelColumnLinear for the up-projection and TensorParallelRowLinear for the down-projection.
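- CLIPEncoderLayer -- Single transformer encoder layer combining self-attention, an MLP, and layer normalization with residual connections. Uses a pre-norm layout: layer normalization is applied before the attention and MLP sub-blocks, and each sub-block's output is added back to its input through a residual connection.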
- CLIPEncoder -- Stack of CLIPEncoderLayer layers forming the core transformer encoder.
- CLIPVisionTransformer -- Vision transformer that passes pixel values through the patch embeddings, a pre-encoder layer norm, and the encoder stack. Returns a BaseModelOutputWithPooling containing the last hidden state.
- CLIPVisionModel -- High-level vision model wrapper around CLIPVisionTransformer.
- CLIPTextTransformer -- Text transformer for encoding text inputs with causal attention masking.
- CLIPModel -- Full CLIP model combining both vision and text transformers with projection layers and logit scaling for contrastive learning.
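The CLIPVisionEmbeddings description above can be sketched in plain PyTorch. This is a minimal single-device illustration, not the LoRAX class: the name `ToyCLIPVisionEmbeddings` is hypothetical, weights are randomly initialized rather than loaded from a checkpoint, and a regular `nn.Embedding` stands in for TensorParallelEmbedding.

```python
import torch
from torch import nn


class ToyCLIPVisionEmbeddings(nn.Module):
    """Hypothetical single-device stand-in for CLIPVisionEmbeddings (no tensor parallelism)."""

    def __init__(self, image_size: int, patch_size: int, num_channels: int, hidden_size: int):
        super().__init__()
        # Kernel size == stride == patch size: each output position sees exactly one patch.
        self.patch_embedding = nn.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False
        )
        # Learnable CLS vector, shared across the batch.
        self.class_embedding = nn.Parameter(torch.randn(hidden_size))
        self.num_patches = (image_size // patch_size) ** 2
        self.num_positions = self.num_patches + 1  # one extra position for the CLS token
        # Plain nn.Embedding here; the LoRAX module uses TensorParallelEmbedding for this table.
        self.position_embedding = nn.Embedding(self.num_positions, hidden_size)
        self.register_buffer(
            "position_ids", torch.arange(self.num_positions).unsqueeze(0), persistent=False
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        patches = self.patch_embedding(pixel_values)  # (B, hidden, H/p, W/p)
        patches = patches.flatten(2).transpose(1, 2)  # (B, num_patches, hidden)
        cls = self.class_embedding.expand(batch_size, 1, -1)  # (B, 1, hidden)
        embeddings = torch.cat([cls, patches], dim=1)  # CLS token is prepended
        return embeddings + self.position_embedding(self.position_ids)


torch.manual_seed(0)
emb = ToyCLIPVisionEmbeddings(image_size=224, patch_size=32, num_channels=3, hidden_size=64)
vision_tokens = emb(torch.randn(2, 3, 224, 224))
print(vision_tokens.shape)  # torch.Size([2, 50, 64])
```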
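The attention and encoder-layer bullets (fused QKV projection, optional causal mask, pre-norm residual blocks) fit together as in the following single-device sketch. `ToyCLIPEncoderLayer` is hypothetical: it uses ordinary `nn.Linear` layers where LoRAX uses TensorParallelColumnLinear.load_multi and TensorParallelRowLinear, and it hard-codes the quick_gelu activation instead of reading the activation from a config.

```python
import torch
from torch import nn


class ToyCLIPEncoderLayer(nn.Module):
    """Hypothetical single-device pre-norm encoder layer with a fused QKV projection."""

    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        # A single projection produces Q, K and V together (the "fused QKV" idea).
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.layer_norm2 = nn.LayerNorm(hidden_size)
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def self_attn(self, x: torch.Tensor, causal: bool) -> torch.Tensor:
        batch, seq_len, hidden = x.shape
        q, k, v = self.qkv_proj(x).split(hidden, dim=-1)
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        q, k, v = (
            t.reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            for t in (q, k, v)
        )
        scores = q @ k.transpose(-2, -1) / self.head_dim**0.5
        if causal:
            # Causal mask (text-encoder style): block attention to later positions.
            future = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
            scores = scores.masked_fill(future, float("-inf"))
        context = scores.softmax(dim=-1) @ v
        return self.out_proj(context.transpose(1, 2).reshape(batch, seq_len, hidden))

    def forward(self, x: torch.Tensor, causal: bool = False) -> torch.Tensor:
        # Pre-norm: LayerNorm before each sub-block, then a residual add.
        x = x + self.self_attn(self.layer_norm1(x), causal)
        h = self.fc1(self.layer_norm2(x))
        # quick_gelu: x * sigmoid(1.702 * x). The real CLIPMLP takes its activation from the config.
        return x + self.fc2(h * torch.sigmoid(1.702 * h))


torch.manual_seed(0)
layer = ToyCLIPEncoderLayer(hidden_size=64, num_heads=4, intermediate_size=256)
tokens = torch.randn(2, 10, 64)
with torch.no_grad():
    out = layer(tokens)  # vision-encoder style: no mask
    causal_out = layer(tokens, causal=True)  # text-transformer style
print(out.shape)  # torch.Size([2, 10, 64])
```

With `causal=True`, changing a later token leaves the outputs of earlier tokens unchanged; without the mask, every token attends to every other token, which is how the vision encoder consumes the CLS and patch tokens.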
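For the CLIPModel bullet, contrastive logits are typically computed as temperature-scaled cosine similarities between the projected image and text embeddings. The helper below is a hypothetical sketch that follows the Hugging Face CLIPModel convention of storing `logit_scale` in log space; that the LoRAX port does the same is an assumption. Random tensors stand in for the outputs of the projection layers.

```python
import math

import torch
import torch.nn.functional as F


def clip_logits(image_embeds: torch.Tensor, text_embeds: torch.Tensor, logit_scale: torch.Tensor):
    """Hypothetical helper: temperature-scaled cosine similarities between projected embeddings.

    Returns (logits_per_image, logits_per_text) with shapes (num_images, num_texts)
    and (num_texts, num_images).
    """
    # L2-normalize so the dot product becomes cosine similarity.
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    # logit_scale is kept in log space and exponentiated here.
    logits_per_text = logit_scale.exp() * (text_embeds @ image_embeds.t())
    return logits_per_text.t(), logits_per_text


torch.manual_seed(0)
logit_scale = torch.tensor(math.log(1 / 0.07))  # CLIP's initial temperature of 0.07
image_embeds = torch.randn(3, 16)  # stand-in for projected image features
text_embeds = torch.randn(4, 16)  # stand-in for projected text features
logits_per_image, logits_per_text = clip_logits(image_embeds, text_embeds, logit_scale)
print(logits_per_image.shape)  # torch.Size([3, 4])
```

Because both embeddings are unit-normalized, every logit lies within plus or minus `exp(logit_scale)`; the learned scale controls how sharp the softmax over candidates becomes during contrastive training.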
Usage
Used internally by the LoRAX server as the vision encoder component for vision-language models (e.g., LLaVA). The CLIPVisionTransformer processes image inputs into visual token embeddings that are then consumed by the language model. Loaded via the model registry.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/custom_modeling/clip.py
- Lines: 1-761
Signature
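```python
from typing import Optional
import torch
from torch import nn
from transformers import CLIPVisionConfig
from transformers.modeling_outputs import BaseModelOutputWithPooling

class CLIPVisionTransformer(nn.Module):
    def __init__(self, prefix, config: CLIPVisionConfig, weights):
        ...
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ) -> BaseModelOutputWithPooling:
        ...
```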
Import
```python
from lorax_server.models.custom_modeling.clip import CLIPVisionTransformer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| pixel_values | Optional[torch.FloatTensor] | Yes | Image pixel values of shape (batch_size, channels, height, width). Annotated as Optional, but must be provided. |
Outputs
| Name | Type | Description |
|---|---|---|
| last_hidden_state | torch.Tensor | Sequence of vision token embeddings of shape (batch_size, num_patches + 1, hidden_size). The first token is the CLS token. |
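The num_patches + 1 in the output shape comes from the patch grid: a square image of side `image_size` split into non-overlapping `patch_size` x `patch_size` patches gives `(image_size // patch_size) ** 2` patches, plus the prepended CLS token. The hypothetical helper below does this arithmetic; the comments name standard OpenAI CLIP configurations.

```python
def clip_vision_seq_len(image_size: int, patch_size: int) -> int:
    """Hypothetical helper: token count of last_hidden_state for a square input image."""
    # A stride-`patch_size` convolution with no padding yields image_size // patch_size patches per side.
    num_patches = (image_size // patch_size) ** 2
    return num_patches + 1  # plus the prepended CLS token


print(clip_vision_seq_len(224, 32))  # 50  (ViT-B/32 at 224 px)
print(clip_vision_seq_len(336, 14))  # 577 (ViT-L/14 at 336 px)
```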
Usage Examples
```python
# Internal usage within the LoRAX server for VLM models
from lorax_server.models.custom_modeling.clip import CLIPVisionTransformer

# Instantiated as part of a vision-language model pipeline;
# processes image pixel values into visual token embeddings.
```