

Implementation:OpenGVLab InternVL LLaVA LLaMA Model

From Leeroopedia


Knowledge Sources
Domains: Model_Architecture, Language_Model, Multimodal
Last Updated: 2026-02-07 14:00 GMT

Overview

This module defines the LLaMA-based LLaVA model, which adds multimodal vision capabilities to the LLaMA language model through multiple inheritance: the backbone class mixes in LlavaMetaModel, and the causal-LM class mixes in LlavaMetaForCausalLM.
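
The multiple-inheritance arrangement can be illustrated with a minimal, dependency-free sketch. Every class below is a hypothetical stand-in (no torch or transformers involved). The pattern it assumes, following upstream LLaVA, is that LlavaMetaModel.__init__ first calls super().__init__(config) to build the LLaMA decoder and then adds vision modules when the config names a vision tower. Listing the mixin before LlamaModel places it earlier in the method resolution order (MRO), so its __init__ runs first and reaches the decoder through super().

```python
class StubLlamaConfig:
    """Stand-in for LlamaConfig (hypothetical)."""

    def __init__(self, mm_vision_tower=None):
        self.mm_vision_tower = mm_vision_tower


class StubLlamaModel:
    """Stand-in for LlamaModel: owns the text decoder (hypothetical)."""

    def __init__(self, config):
        self.config = config
        self.decoder = "llama-decoder-layers"


class StubLlavaMetaModel:
    """Stand-in for LlavaMetaModel: a mixin that adds vision modules (hypothetical)."""

    def __init__(self, config):
        # Delegate to the next class in the MRO (StubLlamaModel) first.
        super().__init__(config)
        # Then attach vision modules only if the config requests them.
        if getattr(config, "mm_vision_tower", None):
            self.vision_tower = config.mm_vision_tower
            self.mm_projector = "mm-projector"


class StubLlavaLlamaModel(StubLlavaMetaModel, StubLlamaModel):
    """Stand-in for LlavaLlamaModel: mixin first, decoder second."""

    def __init__(self, config):
        super().__init__(config)  # resolves to StubLlavaMetaModel.__init__


model = StubLlavaLlamaModel(StubLlamaConfig(mm_vision_tower="clip-vit"))
print([cls.__name__ for cls in StubLlavaLlamaModel.__mro__])
# ['StubLlavaLlamaModel', 'StubLlavaMetaModel', 'StubLlamaModel', 'object']
print(model.decoder, model.vision_tower)
# llama-decoder-layers clip-vit
```

In this sketch, reversing the base order would make StubLlamaModel.__init__ run first; because it does not call super(), the mixin's vision setup would then be skipped.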

Description

The llava_llama.py module defines three classes that constitute the primary LLaVA model implementation:

LlavaConfig: Extends LlamaConfig with a custom model_type = "llava_llama" for HuggingFace model registry identification.

LlavaLlamaModel: Inherits from both LlavaMetaModel and LlamaModel via multiple inheritance. LlavaMetaModel provides the vision tower integration and multimodal projector, while LlamaModel provides the transformer decoder backbone. Uses LlavaConfig as its config class.

LlavaLlamaForCausalLM: Inherits from LlamaForCausalLM and LlavaMetaForCausalLM. This is the main model class used for inference and training. Key implementation details:

  • __init__: Initializes LlavaLlamaModel as the backbone and creates the lm_head linear layer
  • forward: Calls prepare_inputs_labels_for_multimodal (from LlavaMetaForCausalLM) to splice the image features into the text embeddings, then passes the result through the LLaMA decoder. When labels are provided, computes a cross-entropy loss on shifted labels for next-token prediction.
  • prepare_inputs_for_generation: Supports autoregressive generation. Once past_key_values exist, only the last input_id is passed; inputs_embeds is used only on the first generation step, before any cache exists, and input_ids on every later step.
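
The image-fusion step inside forward can be sketched in plain Python. This is a simplified illustration, not the actual prepare_inputs_labels_for_multimodal: it handles one unpadded sequence, represents embeddings as lists, and assumes upstream LLaVA's conventions of an IMAGE_TOKEN_INDEX placeholder (-200) in input_ids and IGNORE_INDEX (-100) for label positions excluded from the loss. The function and helper names are hypothetical, and the real method also rebuilds the attention mask and pads the batch, which this sketch omits.

```python
IMAGE_TOKEN_INDEX = -200  # assumption: upstream LLaVA's image placeholder id
IGNORE_INDEX = -100       # assumption: label value skipped by the loss


def splice_image_features(input_ids, labels, embed_token, image_features):
    """Replace each image placeholder with that image's projected patch embeddings.

    input_ids:      token ids for one sequence (may contain IMAGE_TOKEN_INDEX)
    labels:         target ids aligned with input_ids
    embed_token:    callable mapping a text token id to its embedding vector
    image_features: one list of patch embeddings per placeholder, in order
    """
    new_embeds, new_labels = [], []
    images = iter(image_features)
    for token_id, label in zip(input_ids, labels):
        if token_id == IMAGE_TOKEN_INDEX:
            patches = next(images)
            new_embeds.extend(patches)
            # Image positions are never prediction targets.
            new_labels.extend([IGNORE_INDEX] * len(patches))
        else:
            new_embeds.append(embed_token(token_id))
            new_labels.append(label)
    return new_embeds, new_labels


def toy_embed(token_id):
    # Hypothetical 2-d embedding table, just for the demo.
    return [float(token_id), 0.0]


ids = [1, IMAGE_TOKEN_INDEX, 5, 6]
labels = [IGNORE_INDEX, IGNORE_INDEX, 5, 6]
one_image = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # 3 projected patches
embeds, new_labels = splice_image_features(ids, labels, toy_embed, [one_image])
print(len(embeds), new_labels)  # 6 [-100, -100, -100, -100, 5, 6]
```

Each placeholder grows the sequence by (number of patches - 1) positions, which is why the labels must be rebuilt alongside the embeddings rather than reused as-is.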

The module registers the model with HuggingFace's AutoConfig and AutoModelForCausalLM registries under the "llava_llama" model type, enabling automatic model loading.
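
In transformers, this kind of registration is normally two calls, AutoConfig.register("llava_llama", LlavaConfig) and AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM), presumably what this module does. The sketch below is a hypothetical, dependency-free imitation of the resulting two-level lookup (model_type string, then config class, then model class), meant only to show why LlavaConfig needs its own model_type: if it reused "llama", a saved LLaVA config could not be told apart from a plain LLaMA one. It is not the transformers implementation.

```python
# Hypothetical two-level registry: model_type -> config class -> model class.
CONFIG_REGISTRY = {}
MODEL_REGISTRY = {}


def register_config(model_type, config_cls):
    # Reject registrations whose class attribute disagrees with the key.
    if config_cls.model_type != model_type:
        raise ValueError(f"{config_cls.__name__}.model_type != {model_type!r}")
    CONFIG_REGISTRY[model_type] = config_cls


def register_model(config_cls, model_cls):
    MODEL_REGISTRY[config_cls] = model_cls


def auto_load(config_dict):
    """Resolve a saved config dict to a model instance via its model_type."""
    config_cls = CONFIG_REGISTRY[config_dict["model_type"]]  # KeyError if unknown
    params = {k: v for k, v in config_dict.items() if k != "model_type"}
    return MODEL_REGISTRY[config_cls](config_cls(**params))


class StubLlamaConfig:
    model_type = "llama"

    def __init__(self, **params):
        self.__dict__.update(params)


class StubLlavaConfig(StubLlamaConfig):
    model_type = "llava_llama"  # the distinct key is what routes loading here


class StubLlavaLlamaForCausalLM:
    config_class = StubLlavaConfig

    def __init__(self, config):
        self.config = config


register_config("llava_llama", StubLlavaConfig)
register_model(StubLlavaConfig, StubLlavaLlamaForCausalLM)

model = auto_load({"model_type": "llava_llama", "hidden_size": 4096})
print(type(model).__name__, model.config.hidden_size)  # StubLlavaLlamaForCausalLM 4096
```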

Usage

This class is instantiated by the model builder (builder.py) when loading LLaVA models. It is the primary model architecture used for most LLaVA configurations in the InternVL repository.

Code Reference

Source Location

Signature

class LlavaConfig(LlamaConfig):
    model_type = "llava_llama"

class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    config_class = LlavaConfig
    def __init__(self, config: LlamaConfig): ...

class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaConfig
    def __init__(self, config): ...
    def get_model(self) -> LlavaLlamaModel: ...
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]: ...
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                       attention_mask=None, inputs_embeds=None, **kwargs) -> dict: ...

Import

from llava.model.language_model.llava_llama import LlavaConfig, LlavaLlamaModel, LlavaLlamaForCausalLM

I/O Contract

Inputs

Name | Type | Required | Description
input_ids | torch.LongTensor | No | Token IDs (mutually exclusive with inputs_embeds in generation)
attention_mask | torch.Tensor | No | Attention mask for padding
past_key_values | List[torch.FloatTensor] | No | Cached key-value pairs for autoregressive generation
inputs_embeds | torch.FloatTensor | No | Pre-computed input embeddings (used after multimodal fusion)
labels | torch.LongTensor | No | Target token IDs for loss computation (shifted internally)
images | torch.FloatTensor | No | Image tensor(s) to fuse into the text embeddings via the multimodal projector
use_cache | bool | No | Whether to return past_key_values for caching
return_dict | bool | No | Whether to return a CausalLMOutputWithPast object
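
The input_ids / inputs_embeds switch during generation can be illustrated with a plain-Python sketch. It follows the logic described for prepare_inputs_for_generation (keep only the newest token once a cache exists; use inputs_embeds only on the first step), and, as an assumption based on upstream LLaVA, also forwards images from kwargs. It uses nested lists instead of tensors, and the function name is hypothetical.

```python
def prepare_inputs_for_generation_sketch(input_ids, past_key_values=None,
                                         attention_mask=None, inputs_embeds=None,
                                         **kwargs):
    """Hypothetical restatement of the per-step input selection.

    input_ids is a batch-first list of per-sequence token-id lists.
    """
    if past_key_values:
        # The cache already covers earlier tokens: feed only the newest id.
        input_ids = [seq[-1:] for seq in input_ids]
    if inputs_embeds is not None and past_key_values is None:
        # First step: use the pre-computed (possibly image-fused) embeddings.
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        # Later steps: embeddings are recomputed from the single new token id.
        model_inputs = {"input_ids": input_ids}
    model_inputs.update({
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
        "attention_mask": attention_mask,
        "images": kwargs.get("images"),  # assumption: images ride along in kwargs
    })
    return model_inputs


step0 = prepare_inputs_for_generation_sketch([[1, 2, 3]], inputs_embeds="EMB", images="IMG")
step1 = prepare_inputs_for_generation_sketch([[1, 2, 3, 4]], past_key_values=["kv"], inputs_embeds="EMB")
print(sorted(step0), step1["input_ids"])
```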

Outputs

Name | Type | Description
loss | torch.FloatTensor | Cross-entropy loss (only when labels are provided)
logits | torch.FloatTensor | Next-token prediction logits of shape (batch, seq_len, vocab_size)
past_key_values | tuple | Cached key-value pairs for generation (when use_cache=True)
hidden_states | tuple | Hidden states from all layers (when output_hidden_states=True)
attentions | tuple | Attention weights from all layers (when output_attentions=True)
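
The shifted loss can be written out explicitly. The sketch below computes mean next-token cross-entropy for a single sequence with plain Python lists; with PyTorch the same computation is typically nn.CrossEntropyLoss() applied to the flattened logits[..., :-1, :] and labels[..., 1:]. IGNORE_INDEX = -100 matches CrossEntropyLoss's default ignore_index; the function name is hypothetical.

```python
import math

IGNORE_INDEX = -100  # PyTorch CrossEntropyLoss's default ignore_index


def shifted_causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy for a single sequence.

    logits: seq_len rows of vocab_size raw scores (one row per position)
    labels: seq_len target ids; IGNORE_INDEX marks positions without a loss
    """
    # Position t predicts token t + 1: drop the last logit row and the first label.
    shift_logits = logits[:-1]
    shift_labels = labels[1:]
    total, count = 0.0, 0
    for row, target in zip(shift_logits, shift_labels):
        if target == IGNORE_INDEX:
            continue
        # Numerically stable log-sum-exp over the vocabulary.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]  # equals -log softmax(row)[target]
        count += 1
    # This sketch returns 0.0 when every target is ignored.
    return total / count if count else 0.0


# Uniform scores over a 3-token vocabulary: each counted position costs log(3).
uniform = [[0.0, 0.0, 0.0]] * 3
loss = shifted_causal_lm_loss(uniform, [IGNORE_INDEX, 1, 2])
print(round(loss, 4))  # 1.0986
```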

Usage Examples

Basic Usage

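import torch
from transformers import AutoModelForCausalLM

# Importing this module runs its AutoConfig / AutoModelForCausalLM registration,
# so checkpoints whose config declares model_type "llava_llama" can be resolved.
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM  # noqa: F401

# Models are normally loaded through builder.py rather than instantiated directly,
# but the registered class can also be loaded through the Auto API:
model = AutoModelForCausalLM.from_pretrained(
    "path/to/llava-model",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# input_ids, image_tensor and labels are assumed to come from the LLaVA
# preprocessing pipeline (tokenizer output with image placeholder tokens,
# image processor output); they are not constructed here.
outputs = model(
    input_ids=input_ids,
    images=image_tensor,  # should match the model's dtype (float16 here)
    labels=labels,
)
loss = outputs.loss  # populated only because labels were passed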
