Implementation:Predibase Lorax SigLIP Vision Encoder
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference, Vision_Language_Model |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides a tensor-parallel implementation of the SigLIP (Sigmoid Loss for Language Image Pre-Training) vision encoder for use as the vision component within vision-language models served by the LoRAX framework.
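The sketch below is a conceptual illustration of why the column-then-row sharding used by SiglipMLP gives the same result as an unsharded MLP. It uses plain NumPy in a single process and does not call LoRAX's TensorParallel* classes. The tanh-approximated GELU, the shapes, and the function names are assumptions chosen for illustration. Conceptually, the attention block follows the same pattern: Q, K and V are column-sharded and the output projection is row-sharded.

```python
import numpy as np

def gelu_tanh(x):
    # Tanh approximation of GELU (assumed here; the real module reads config.hidden_act)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def mlp_reference(x, w1, b1, w2, b2):
    # Unsharded MLP: fc1 -> activation -> fc2
    return gelu_tanh(x @ w1 + b1) @ w2 + b2

def mlp_tensor_parallel(x, w1, b1, w2, b2, world_size):
    # Column-parallel fc1: rank r owns a slice of fc1's output features (and bias),
    # so it computes its slice of the intermediate activations without communication.
    w1_shards = np.split(w1, world_size, axis=1)
    b1_shards = np.split(b1, world_size)
    # Row-parallel fc2: rank r owns the matching slice of fc2's input features
    # and produces a partial sum of the full output.
    w2_shards = np.split(w2, world_size, axis=0)
    partials = [
        gelu_tanh(x @ w1_r + b1_r) @ w2_r
        for w1_r, b1_r, w2_r in zip(w1_shards, b1_shards, w2_shards)
    ]
    # Summing the partials stands in for the all-reduce; fc2's bias is added once.
    return np.sum(partials, axis=0) + b2

rng = np.random.default_rng(0)
hidden, intermediate = 8, 16
x = rng.standard_normal((3, hidden))
w1, b1 = rng.standard_normal((hidden, intermediate)), rng.standard_normal(intermediate)
w2, b2 = rng.standard_normal((intermediate, hidden)), rng.standard_normal(hidden)
print(np.allclose(mlp_tensor_parallel(x, w1, b1, w2, b2, world_size=4),
                  mlp_reference(x, w1, b1, w2, b2)))  # True
```

The split is exact because the activation is elementwise. Each rank can therefore apply it to its own slice of the intermediate features, and only one reduction is needed per MLP.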
Description
This module implements Google's SigLIP vision encoder with tensor parallelism. SigLIP is similar to CLIP but uses a sigmoid-based contrastive loss instead of softmax, and notably does not include a CLS token in its vision transformer -- relying instead on multihead attention pooling or using all patch embeddings directly.
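Because there is no CLS token, the length of the vision sequence is exactly the number of patches. The arithmetic below is a minimal sketch that assumes the patch-embedding convolution uses kernel size and stride both equal to patch_size with "valid" padding, as in the Hugging Face SigLIP reference. The function names are hypothetical.

```python
def valid_conv_output(size, kernel, stride):
    # Output length of a convolution with "valid" padding (no padding added)
    return (size - kernel) // stride + 1

def vision_seq_len(image_size, patch_size, add_cls_token=False):
    # Patch embedding: kernel_size == stride == patch_size, so patches do not overlap
    side = valid_conv_output(image_size, patch_size, patch_size)
    num_patches = side * side
    # CLIP prepends one CLS token; SigLIP does not, so num_positions == num_patches
    return num_patches + 1 if add_cls_token else num_patches

print(vision_seq_len(224, 14))                      # 256 (SigLIP-style, no CLS)
print(vision_seq_len(224, 16, add_cls_token=True))  # 197 (CLIP-style, with CLS)
print(vision_seq_len(230, 14))                      # 256: "valid" padding drops leftover pixels
```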
Key classes:
- SiglipVisionEmbeddings -- Converts pixel values into patch embeddings via a 2D convolution with "valid" padding. Unlike CLIP, there is no CLS token; only patch embeddings and position embeddings (via TensorParallelEmbedding) are used. The number of positions equals the number of patches.
- SiglipAttention -- Multi-headed attention with separate Q, K, V projections via TensorParallelColumnLinear and an output projection via TensorParallelRowLinear. The scaling factor is applied after the QK dot product rather than to the queries beforehand. Attention weights are upcast to fp32 for the softmax to keep it numerically stable.
- SiglipMLP -- Two-layer feed-forward network with a configurable activation (config.hidden_act), using TensorParallelColumnLinear for fc1 and TensorParallelRowLinear for fc2.
- SiglipEncoderLayer -- Single transformer encoder layer with pre-norm architecture. Applies layer norm before self-attention and MLP, with residual connections around each.
- SiglipMultiheadAttentionPoolingHead -- Multihead attention pooling head that uses a learned probe vector as the query to attend over the encoder output. Produces a single pooled representation by selecting the first (and only) output token. Includes a residual MLP on top.
- SiglipEncoder -- Stack of SiglipEncoderLayer layers forming the vision transformer encoder backbone.
- SiglipVisionTransformer -- Top-level vision transformer that chains the embeddings and the encoder. Returns BaseModelOutputWithPooling with the last hidden state.
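The NumPy sketch below illustrates three behaviours described above: attention with post-matmul scaling and an fp32 softmax, a pre-norm encoder layer, and a probe-query pooling head that returns the first and only output token. It is an illustration under simplifying assumptions, not the LoRAX code. Biases and learned LayerNorm parameters are omitted, ReLU stands in for config.hidden_act, there is no tensor parallelism, and all function names are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # LayerNorm over the last axis; learned scale and shift omitted for brevity
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax_fp32(scores):
    # Compute the softmax in float32 whatever the input dtype, then cast back
    s = scores.astype(np.float32)
    e = np.exp(s - s.max(-1, keepdims=True))
    return (e / e.sum(-1, keepdims=True)).astype(scores.dtype)

def attention(q_in, kv_in, p, num_heads):
    # q_in: (batch, q_len, hidden); kv_in: (batch, kv_len, hidden); biases omitted
    b, q_len, hidden = q_in.shape
    head_dim = hidden // num_heads

    def split_heads(t):  # (batch, len, hidden) -> (batch, num_heads, len, head_dim)
        return t.reshape(b, -1, num_heads, head_dim).transpose(0, 2, 1, 3)

    q = split_heads(q_in @ p["wq"])
    k = split_heads(kv_in @ p["wk"])
    v = split_heads(kv_in @ p["wv"])
    # Post-matmul scaling: scale the QK^T product rather than the queries
    scores = (q @ k.transpose(0, 1, 3, 2)) * head_dim**-0.5
    out = softmax_fp32(scores) @ v
    return out.transpose(0, 2, 1, 3).reshape(b, q_len, hidden) @ p["wo"]

def mlp(x, p):
    # ReLU stands in for config.hidden_act in this sketch
    return np.maximum(x @ p["w1"], 0) @ p["w2"]

def encoder_layer(x, p, num_heads):
    # Pre-norm: LayerNorm before each sub-block, residual added around it
    h = layer_norm(x)
    x = x + attention(h, h, p, num_heads)
    return x + mlp(layer_norm(x), p)

def attention_pool(x, probe, p, num_heads):
    # A learned probe (1, 1, hidden) is the only query; the patches are keys/values
    query = np.broadcast_to(probe, (x.shape[0], 1, x.shape[-1]))
    pooled = attention(query, x, p, num_heads)
    pooled = pooled + mlp(layer_norm(pooled), p)  # residual MLP on top
    return pooled[:, 0]  # first (and only) token -> (batch, hidden)

rng = np.random.default_rng(0)
hidden, num_heads, inter = 8, 2, 16

def init(*shape):
    # Small random fp16 weights, mimicking half-precision inference
    return (0.1 * rng.standard_normal(shape)).astype(np.float16)

# One weight set is shared by the layer and the head only to keep the demo short
p = {"wq": init(hidden, hidden), "wk": init(hidden, hidden), "wv": init(hidden, hidden),
     "wo": init(hidden, hidden), "w1": init(hidden, inter), "w2": init(inter, hidden)}
patches = init(2, 16, hidden)  # (batch, num_patches, hidden)
encoded = encoder_layer(patches, p, num_heads)
pooled = attention_pool(encoded, init(1, 1, hidden), p, num_heads)
print(encoded.shape, pooled.shape)  # (2, 16, 8) (2, 8)
```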
Utility functions: trunc_normal_tf_, variance_scaling_, lecun_normal_, default_flax_embed_init -- initialization utilities matching TensorFlow/JAX conventions.
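The following sketch shows the standard deviation each initializer would target. It assumes these helpers follow the usual TF/JAX variance-scaling rule, as the equivalent helpers in Hugging Face's SigLIP port appear to: LeCun normal uses fan_in with a truncated normal, and the Flax embedding init uses fan_in with a plain normal. Only the std is computed; sampling is not shown, and the function names are hypothetical.

```python
import math

# Std of a unit normal truncated to [-2, 2]; dividing the target std by it
# compensates for the spread lost to truncation
TRUNC_NORMAL_STD = 0.87962566103423978

def variance_scaling_std(fan_in, fan_out=None, scale=1.0, mode="fan_in",
                         distribution="truncated_normal"):
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        raise ValueError(f"unsupported mode: {mode}")
    std = math.sqrt(scale / denom)
    if distribution == "truncated_normal":
        # Std to pass to a sampler that truncates at +/- 2 standard deviations
        return std / TRUNC_NORMAL_STD
    if distribution == "normal":
        return std
    raise ValueError(f"unsupported distribution: {distribution}")

def lecun_normal_std(fan_in):
    # LeCun normal: scale 1, fan_in mode, truncated normal
    return variance_scaling_std(fan_in, mode="fan_in", distribution="truncated_normal")

def flax_embed_std(fan_in):
    # Flax-style default embedding init: fan_in mode, plain (untruncated) normal
    return variance_scaling_std(fan_in, mode="fan_in", distribution="normal")

print(lecun_normal_std(768), flax_embed_std(768))
```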
Usage
Used internally by the LoRAX server as the vision encoder component for vision-language models that use SigLIP (e.g., PaliGemma). The SiglipVisionTransformer processes image inputs into visual token embeddings consumed by the language model. Loaded via the model registry.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/custom_modeling/siglip.py
- Lines: 1-386
Signature
```python
class SiglipVisionTransformer(nn.Module):
    def __init__(self, prefix, config: SiglipVisionConfig, weights):
        ...

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ) -> BaseModelOutputWithPooling:
        ...
```
Import
```python
from lorax_server.models.custom_modeling.siglip import SiglipVisionTransformer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| pixel_values | Optional[torch.FloatTensor] | Yes | Image pixel values of shape (batch_size, channels, height, width) |
Outputs
| Name | Type | Description |
|---|---|---|
| last_hidden_state | torch.Tensor | Field of the returned BaseModelOutputWithPooling: the sequence of vision token embeddings, of shape (batch_size, num_patches, hidden_size). Unlike CLIP, no CLS token is included. |
Usage Examples
```python
# Internal usage within the LoRAX server for VLM models
from lorax_server.models.custom_modeling.siglip import SiglipVisionTransformer
# Instantiated as part of a vision-language model pipeline (e.g., PaliGemma)
# Processes image pixel values into visual token embeddings
```