Implementation:Predibase Lorax VLM Causal LM

From Leeroopedia


Knowledge Sources
Domains: Model_Architecture, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Provides the Vision-Language Model (VLM) inference wrapper that extends FlashCausalLM to handle multimodal inputs combining images and text for models such as LLaVA-NeXT, Idefics2, and PaliGemma.

Description

This module defines the batch handling and model wrapper for vision-language models that use a causal language model backbone.

Key classes:

  • VlmCausalLMBatch (extends FlashCausalLMBatch) - A dataclass batch type that adds vision-specific fields: pixel_values, pixel_attention_mask, and image_sizes. Key methods include:
    • batch_tokenized_inputs - Processes multimodal request chunks (text and images), runs image processing through the model's processor, and replaces image chunks with the correct number of image tokens based on model type.
    • from_pb - Constructs the batch from protobuf, handling both text tokenization and image preprocessing.
    • concatenate and filter - Override parent methods to clear pixel values after the first forward pass (images are only needed during prefill).
  • VlmCausalLM (extends FlashCausalLM) - The main model wrapper that:
    • Initializes a processor (e.g., AutoProcessor) for image preprocessing alongside the text model.
    • Overrides forward to pass pixel values, attention masks, and image sizes to the underlying model during prefill, then clears them for subsequent decode steps.
    • Supports speculative decoding, reshaping the input tensors to accommodate the speculative token IDs.
    • Explicitly disables prefix caching (not yet compatible with VLMs).
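The image-chunk replacement that batch_tokenized_inputs performs can be sketched in plain Python. Each image chunk in a request becomes a run of placeholder tokens, so the text sequence reserves one position per visual feature. This is a minimal sketch under stated assumptions. TextChunk, ImageChunk, build_prompt, and the fixed token count are hypothetical stand-ins, not LoRAX's API. The real method also runs the model's processor on the images, and it takes the token count from the model type.

```python
from dataclasses import dataclass
from typing import Callable, List, Union


# Hypothetical chunk types standing in for the protobuf request chunks.
@dataclass
class TextChunk:
    text: str


@dataclass
class ImageChunk:
    height: int
    width: int


def build_prompt(
    chunks: List[Union[TextChunk, ImageChunk]],
    image_token: str,
    num_image_tokens: Callable[[ImageChunk], int],
) -> str:
    """Concatenate text chunks and expand each image chunk into placeholder tokens."""
    parts = []
    for chunk in chunks:
        if isinstance(chunk, TextChunk):
            parts.append(chunk.text)
        elif isinstance(chunk, ImageChunk):
            # One placeholder per visual feature, so the sequence reserves exactly
            # the positions that the vision encoder's outputs will fill.
            parts.append(image_token * num_image_tokens(chunk))
        else:
            raise TypeError(f"unsupported chunk type: {type(chunk).__name__}")
    return "".join(parts)


# Toy usage: a fixed count of 4 tokens per image (real counts depend on the model type).
prompt = build_prompt(
    [TextChunk("Describe "), ImageChunk(336, 336), TextChunk(" please")],
    image_token="<image>",
    num_image_tokens=lambda img: 4,
)
print(prompt)  # → Describe <image><image><image><image> please
```

Passing the count as a callable keeps the expansion independent of any one model. LLaVA-NeXT's count depends on image size, while other model types use a count from the config.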

Key functions:

  • image_text_replacement - Generates the appropriate image token replacement string based on model type (Idefics2, LLaVA-NeXT, PaliGemma).
  • get_number_of_features - Calculates the total number of visual features for a LLaVA-NeXT image including base features, unpadded features, and newline features.
  • get_unpadded_features - Computes unpadded feature counts based on original image dimensions and patch grid.
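The LLaVA-NeXT feature count described above can be sketched as follows. The total is the unpadded grid features, plus one newline feature per remaining row, plus the base image's npatches² features. This is an illustrative sketch, not the LoRAX source. The get_number_of_features signature here is a simplified hypothetical: it takes image_size, patch_size, and the tile grid shape directly, whereas the real function derives them from the model config. The 336/14 values in the usage line are assumed CLIP ViT-L/14-336 settings, and the 2x2 grid is assumed for a 672x672 image.

```python
from typing import Tuple


def get_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    # Size of the full tiled feature grid, measured in patches.
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height
    if aspect_ratio > current_aspect_ratio:
        # Image is relatively wider than the grid: padding sits top and bottom.
        new_height = (original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        current_height -= 2 * padding
    else:
        # Image is relatively taller (or equal): padding sits left and right.
        new_width = (original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        current_width -= 2 * padding

    unpadded_features = current_height * current_width
    newline_features = current_height  # one newline feature per remaining row
    return unpadded_features, newline_features


def get_number_of_features(
    height: int,
    width: int,
    image_size: int,
    patch_size: int,
    num_patch_height: int,
    num_patch_width: int,
) -> int:
    # Hypothetical simplified signature: grid shape is passed in, not chosen from config.
    npatches = image_size // patch_size  # patches per side of one tile
    unpadded, newlines = get_unpadded_features(
        height, width, npatches, num_patch_height, num_patch_width
    )
    base_features = npatches**2  # the downscaled base image covers the whole picture
    return unpadded + newlines + base_features


print(get_number_of_features(672, 672, 336, 14, 2, 2))  # → 2928 (2304 + 48 + 576)
```

For a square image on a square grid, no padding is removed. For a 500x1000 image on the same 2x2 grid, half the rows are padding, which leaves 24x48 unpadded features.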

Usage

VlmCausalLM is the primary wrapper used when the LoRAX server loads a vision-language model. It is instantiated by the model registry with a specific model_class (e.g., LlavaNextForConditionalGeneration). It handles the complete inference lifecycle: batch construction from protobuf, the multimodal forward pass, and the token generation loop.
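The prefill/decode split in this lifecycle can be illustrated with a toy batch. Pixel values reach the model only on the first forward pass. Afterwards, filter and concatenate return batches without them, since images are needed only during prefill and later steps reuse the KV cache. This is a minimal sketch: ToyVlmBatch and toy_forward are hypothetical stand-ins that use strings instead of tensors. They do not reproduce VlmCausalLMBatch or VlmCausalLM.forward.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyVlmBatch:
    request_ids: List[int]
    pixel_values: Optional[List[str]]  # stand-in for per-image tensors

    def filter(self, request_ids: List[int]) -> "ToyVlmBatch":
        # Keep the requested entries and drop pixels: they were consumed at prefill.
        kept = [r for r in self.request_ids if r in request_ids]
        return ToyVlmBatch(request_ids=kept, pixel_values=None)

    @classmethod
    def concatenate(cls, batches: List["ToyVlmBatch"]) -> "ToyVlmBatch":
        # Merged batches are already past prefill, so no pixels are carried over.
        ids = [r for b in batches for r in b.request_ids]
        return cls(request_ids=ids, pixel_values=None)


def toy_forward(batch: ToyVlmBatch) -> str:
    # Hypothetical forward: vision inputs are used only when present (prefill).
    if batch.pixel_values is not None:
        return f"prefill with {len(batch.pixel_values)} image(s)"
    return "decode (text only)"


batch = ToyVlmBatch(request_ids=[1, 2], pixel_values=["img-a", "img-b"])
step1 = toy_forward(batch)  # → "prefill with 2 image(s)"
batch = batch.filter([2])
step2 = toy_forward(batch)  # → "decode (text only)"
```

Dropping the pixel tensors at batch-management boundaries frees their memory as soon as prefill is done. It also ensures a decode step never re-encodes an image.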

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/models/vlm_causal_lm.py
  • Lines: 1-420

Signature

class VlmCausalLMBatch(FlashCausalLMBatch):
    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]

    @classmethod
    def concatenate(cls, batches):
        ...
    def filter(self, request_ids: List[int]):
        ...
    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
        ...
    @classmethod
    def from_pb(cls, pb, tokenizer, tokenizers, processor, config, dtype, device) -> "VlmCausalLMBatch":
        ...

class VlmCausalLM(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        model_class,
        *,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        batch_class=VlmCausalLMBatch,
        adapter_id: str,
        adapter_source: str,
        revision: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        **kwargs,
    ):
        ...
    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        ...
    def max_past(self) -> Optional[int]:
        ...
    def forward(self, batch: VlmCausalLMBatch, adapter_data: AdapterBatchData) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ...

Import

from lorax_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch

I/O Contract

Inputs

Constructor (VlmCausalLM.__init__):

  • model_id (str, required) - HuggingFace model identifier
  • model_class (type, required) - The PyTorch model class to instantiate (e.g., LlavaNextForConditionalGeneration)
  • processor_class (type, optional) - Processor class for image preprocessing (default: AutoProcessor)
  • adapter_id (str, required) - Adapter identifier for LoRA loading
  • adapter_source (str, required) - Source of adapter weights (e.g., "hub")

Forward pass (VlmCausalLM.forward):

  • batch (VlmCausalLMBatch, required) - Batch containing text tokens and optional pixel values
  • adapter_data (AdapterBatchData, required) - Adapter weight data for the current batch

Outputs

VlmCausalLM.forward returns a tuple (logits, speculative_logits):

  • logits (torch.Tensor) - Next-token logits over the vocabulary
  • speculative_logits (Optional[torch.Tensor]) - Speculative decoding logits

Usage Examples

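# Internal LoRAX server usage: VlmCausalLM is normally constructed by the
# model registry rather than by user code.
from lorax_server.models.vlm_causal_lm import VlmCausalLM

# Illustrative instantiation (commented out): it needs a GPU-backed LoRAX
# server environment and the LlavaNextForConditionalGeneration model class
# that the registry passes in (its import is not shown here).
# vlm = VlmCausalLM(
#     model_id="llava-hf/llava-v1.6-mistral-7b-hf",
#     model_class=LlavaNextForConditionalGeneration,
#     adapter_id="",
#     adapter_source="hub",
# )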
