Implementation:Predibase Lorax VLM Causal LM
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides the Vision-Language Model (VLM) inference wrapper that extends FlashCausalLM to handle multimodal inputs combining images and text, for models such as LLaVA-NeXT, Idefics2, and PaliGemma.
Description
This module defines the batch handling and model wrapper for vision-language models that use a causal language model backbone.
Key classes:
- VlmCausalLMBatch (extends FlashCausalLMBatch): a dataclass batch type that adds vision-specific fields: pixel_values, pixel_attention_mask, and image_sizes. Key methods include:
  - batch_tokenized_inputs: processes multimodal request chunks (text and images), runs image processing through the model's processor, and replaces image chunks with the correct number of image tokens based on model type.
  - from_pb: constructs the batch from protobuf, handling both text tokenization and image preprocessing.
  - concatenate and filter: override the parent methods to clear pixel values after the first forward pass (images are only needed during prefill).
- VlmCausalLM (extends FlashCausalLM): the main model wrapper, which:
  - Initializes a processor (e.g., AutoProcessor) for image preprocessing alongside the text model.
  - Overrides forward to pass pixel values, attention masks, and image sizes to the underlying model during prefill, then clears them for subsequent decode steps.
  - Supports speculative decoding with proper tensor reshaping for speculative IDs.
  - Explicitly disables prefix caching, which is not yet compatible with VLMs.
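The chunk handling in batch_tokenized_inputs can be illustrated with a minimal pure-Python sketch. TextChunk, ImageChunk, expand_chunks and the replacement_for callback are hypothetical names introduced here, not LoRAX APIs, and the real method also runs the processor and tokenizer, which this sketch omits. What it shows is the core idea: each image chunk becomes a run of placeholder tokens whose length depends on the model type, so the text sequence reserves room for the visual features.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple, Union


@dataclass
class TextChunk:
    """Hypothetical stand-in for a text chunk of a multimodal request."""
    text: str


@dataclass
class ImageChunk:
    """Hypothetical stand-in for an image chunk; only its size is modeled."""
    height: int
    width: int


Chunk = Union[TextChunk, ImageChunk]


def expand_chunks(
    chunks: List[Chunk], replacement_for: Callable[[ImageChunk], str]
) -> Tuple[str, List[ImageChunk]]:
    """Flatten request chunks into one prompt plus the images to preprocess.

    Text is copied verbatim; each image is replaced by the placeholder string
    returned by `replacement_for`, reserving one position per visual feature.
    """
    parts: List[str] = []
    images: List[ImageChunk] = []
    for chunk in chunks:
        if isinstance(chunk, TextChunk):
            parts.append(chunk.text)
        elif isinstance(chunk, ImageChunk):
            images.append(chunk)  # collected for the (omitted) processor call
            parts.append(replacement_for(chunk))
        else:
            raise TypeError(f"unsupported chunk type: {type(chunk).__name__}")
    return "".join(parts), images


# Usage with a hypothetical fixed count of 4 placeholder tokens per image.
prompt, images = expand_chunks(
    [TextChunk("Describe "), ImageChunk(336, 336), TextChunk(" briefly.")],
    replacement_for=lambda image: "<image>" * 4,
)
# prompt == "Describe <image><image><image><image> briefly."
```

Keeping the replacement as a callback separates the model-independent chunk walk from the model-specific token count.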
Key functions:
- image_text_replacement - Generates the appropriate image token replacement string based on model type (Idefics2, LLaVA-NeXT, PaliGemma).
- get_number_of_features - Calculates the total number of visual features for a LLaVA-NeXT image including base features, unpadded features, and newline features.
- get_unpadded_features - Computes unpadded feature counts based on original image dimensions and patch grid.
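The feature arithmetic behind get_number_of_features, get_unpadded_features and image_text_replacement can be sketched as follows. This is an approximation under stated assumptions, not the LoRAX source: select_best_resolution is assumed to follow the "anyres" tile-selection rule Hugging Face transformers uses for LLaVA-NeXT, the config is a SimpleNamespace with Hugging Face-style field names (image_grid_pinpoints, vision_config.image_size, vision_config.patch_size, text_config.num_image_tokens), and the Idefics2 branch assumes a fixed 64 tokens per image with no image splitting.

```python
from types import SimpleNamespace
from typing import Optional, Sequence, Tuple


def select_best_resolution(
    original_size: Tuple[int, int], possible_resolutions: Sequence[Sequence[int]]
) -> Tuple[int, int]:
    """Return the (height, width) pinpoint that preserves the most image area."""
    original_height, original_width = original_size
    best_fit: Optional[Tuple[int, int]] = None
    max_effective, min_wasted = 0, float("inf")
    for height, width in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        down_w, down_h = int(original_width * scale), int(original_height * scale)
        effective = min(down_w * down_h, original_width * original_height)
        wasted = width * height - effective
        if effective > max_effective or (effective == max_effective and wasted < min_wasted):
            max_effective, min_wasted, best_fit = effective, wasted, (height, width)
    if best_fit is None:
        raise ValueError("no candidate resolutions")
    return best_fit


def get_unpadded_features(
    height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
) -> Tuple[int, int]:
    """Features covering real pixels, plus one newline feature per row."""
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width
    if width / height > current_width / current_height:
        # Image is wider than the grid: padding rows were added, so drop them.
        current_height = (height * current_width) // width
    else:
        # Image is taller than (or as wide as) the grid: drop padding columns.
        current_width = (width * current_height) // height
    return current_height * current_width, current_height


def get_number_of_features(height: int, width: int, config) -> int:
    image_size = config.vision_config.image_size
    npatches = image_size // config.vision_config.patch_size  # patches per tile side
    grid_h, grid_w = select_best_resolution((height, width), config.image_grid_pinpoints)
    unpadded, newlines = get_unpadded_features(
        height, width, npatches, grid_h // image_size, grid_w // image_size
    )
    base_features = npatches**2  # one tile for the downscaled whole image
    return unpadded + newlines + base_features


def image_text_replacement(config, height: int, width: int) -> str:
    if config.model_type == "llava_next":
        return "<image>" * get_number_of_features(height, width, config)
    if config.model_type == "paligemma":
        return "<image>" * config.text_config.num_image_tokens
    if config.model_type == "idefics2":
        # Assumed: 64 tokens per image wrapped in boundary tokens; no image splitting.
        return "<fake_token_around_image>" + "<image>" * 64 + "<fake_token_around_image>"
    raise ValueError(f"unsupported multimodal model_type: {config.model_type}")


llava_config = SimpleNamespace(
    model_type="llava_next",
    image_grid_pinpoints=[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
    vision_config=SimpleNamespace(image_size=336, patch_size=14),
)
print(get_number_of_features(336, 336, llava_config))  # 1176
print(get_number_of_features(672, 672, llava_config))  # 2928
```

With a 336-pixel tile and 14-pixel patches (24 × 24 = 576 base features), a 336 × 336 image selects the 336 × 672 pinpoint; removing the padding columns leaves 24 × 24 = 576 unpadded features plus 24 newline features, for 1176 in total.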
Usage
VlmCausalLM is the primary wrapper used when the LoRAX server loads a vision-language model. It is instantiated from the model registry with a specific model_class (e.g., LlavaNextForConditionalGeneration) and handles the complete inference lifecycle: batch construction from protobuf, the multimodal forward pass, and the token generation loop.
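The prefill/decode split in this lifecycle (forward, filter and concatenate) can be modeled without tensors. ToyVlmBatch, forward and fake_model below are hypothetical stand-ins, not LoRAX classes; a real batch also carries token IDs, KV-cache slots and pixel_attention_mask. The sketch only demonstrates the invariant described above: vision inputs reach the model once, during prefill, and any batch derived afterwards carries no pixel values.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple


@dataclass
class ToyVlmBatch:
    """Hypothetical stand-in for VlmCausalLMBatch: no tensors, no KV cache."""
    request_ids: List[int]
    pixel_values: Optional[List[Any]] = None
    image_sizes: Optional[List[Tuple[int, int]]] = None

    def filter(self, request_ids: List[int]) -> "ToyVlmBatch":
        # Images were consumed during prefill, so the filtered batch drops them.
        keep = set(request_ids)
        return ToyVlmBatch([r for r in self.request_ids if r in keep])

    @classmethod
    def concatenate(cls, batches: List["ToyVlmBatch"]) -> "ToyVlmBatch":
        # Merged batches only continue decoding, so no pixel values carry over.
        return cls([r for b in batches for r in b.request_ids])


def forward(batch: ToyVlmBatch, model: Callable[..., Any]) -> Any:
    """Send vision inputs only while the batch still holds them (prefill)."""
    out = model(batch.request_ids, pixel_values=batch.pixel_values, image_sizes=batch.image_sizes)
    # Later decode steps attend to cached keys/values, so the images are not resent.
    batch.pixel_values = None
    batch.image_sizes = None
    return out


saw_images: List[bool] = []


def fake_model(ids, pixel_values=None, image_sizes=None):
    saw_images.append(pixel_values is not None)
    return len(ids)


batch = ToyVlmBatch([1, 2], pixel_values=["img-a", "img-b"], image_sizes=[(336, 336), (672, 672)])
forward(batch, fake_model)  # prefill: pixel values are passed to the model
forward(batch, fake_model)  # decode: pixel values were cleared after prefill
merged = ToyVlmBatch.concatenate([batch.filter([2]), ToyVlmBatch([3])])
# saw_images == [True, False]; merged.request_ids == [2, 3]; merged.pixel_values is None
```

Clearing the fields instead of re-sending them avoids re-running the vision encoder on every decode step.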
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/vlm_causal_lm.py
- Lines: 1-420
Signature
```python
from typing import List, Optional, Tuple, Type

import torch
from transformers import AutoProcessor

# FlashCausalLM, FlashCausalLMBatch and AdapterBatchData are defined elsewhere in lorax_server.


class VlmCausalLMBatch(FlashCausalLMBatch):
    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]

    @classmethod
    def concatenate(cls, batches):
        ...

    def filter(self, request_ids: List[int]):
        ...

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
        ...

    @classmethod
    def from_pb(cls, pb, tokenizer, tokenizers, processor, config, dtype, device) -> "VlmCausalLMBatch":
        ...


class VlmCausalLM(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        model_class,
        *,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        batch_class=VlmCausalLMBatch,
        adapter_id: str,
        adapter_source: str,
        revision: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        **kwargs,
    ):
        ...

    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        ...

    def max_past(self) -> Optional[int]:
        ...

    def forward(
        self, batch: VlmCausalLMBatch, adapter_data: AdapterBatchData
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ...
```
Import
```python
from lorax_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_id | str | Yes | Hugging Face model identifier (constructor argument) |
| model_class | type | Yes | The PyTorch model class to instantiate, e.g., LlavaNextForConditionalGeneration (constructor argument) |
| processor_class | type | No | Processor class for image preprocessing (constructor argument; default: AutoProcessor) |
| adapter_id | str | Yes | Adapter identifier for LoRA loading (constructor argument) |
| adapter_source | str | Yes | Source of adapter weights, e.g., "hub" (constructor argument) |
| batch | VlmCausalLMBatch | Yes | forward() argument: batch containing text tokens and optional pixel values |
| adapter_data | AdapterBatchData | Yes | forward() argument: adapter weight data for the current batch |
Outputs
| Name | Type | Description |
|---|---|---|
| logits | torch.Tensor | Next-token logits over the vocabulary |
| speculative_logits | Optional[torch.Tensor] | Speculative decoding logits; None when speculative decoding is not in use |
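speculative_logits ties back to the speculative-ID reshaping mentioned for VlmCausalLM.forward. A common layout in flash-attention style servers, assumed here rather than taken from the LoRAX source, flattens each sequence's current token followed by its k draft tokens into one 1-D batch and advances the position IDs by 0..k. flatten_speculative_inputs is a hypothetical pure-Python helper reproducing that layout; in tensor terms it matches torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1).

```python
from typing import List, Tuple


def flatten_speculative_inputs(
    input_ids: List[int],
    position_ids: List[int],
    speculative_ids: List[List[int]],
) -> Tuple[List[int], List[int]]:
    """Lay out each sequence's current token followed by its draft tokens.

    Row-major flattening of a B x (k + 1) grid; positions advance by 0..k per row.
    """
    if not (len(input_ids) == len(position_ids) == len(speculative_ids)):
        raise ValueError("batch dimension mismatch")
    flat_ids: List[int] = []
    flat_pos: List[int] = []
    for token, position, drafts in zip(input_ids, position_ids, speculative_ids):
        row = [token] + list(drafts)
        flat_ids.extend(row)
        flat_pos.extend(position + offset for offset in range(len(row)))
    return flat_ids, flat_pos


ids, pos = flatten_speculative_inputs([10, 20], [5, 7], [[11, 12], [21, 22]])
# ids == [10, 11, 12, 20, 21, 22]; pos == [5, 6, 7, 7, 8, 9]
```

Flattening lets a single forward pass score every draft token, so each flattened position presumably receives its own row of logits to verify against.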
Usage Examples
```python
# Internal LoRAX server usage
from lorax_server.models.vlm_causal_lm import VlmCausalLM

# Instantiated by the model registry for VLM model types, e.g.:
# vlm = VlmCausalLM(
#     model_id="llava-hf/llava-v1.6-mistral-7b-hf",
#     model_class=LlavaNextForConditionalGeneration,
#     adapter_id="",
#     adapter_source="hub",
# )
```