Implementation:OpenGVLab InternVL LLaVA LLaMA Model
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Language_Model, Multimodal |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
This module defines the LLaMA-based LLaVA model by combining the LLaMA language model with multimodal vision capabilities through multiple inheritance from LlavaMetaModel and LlavaMetaForCausalLM.
Description
The llava_llama.py module defines three classes that constitute the primary LLaVA model implementation:
LlavaConfig: Extends LlamaConfig with a custom model_type = "llava_llama" for HuggingFace model registry identification.
LlavaLlamaModel: Inherits from both LlavaMetaModel and LlamaModel via multiple inheritance. LlavaMetaModel provides the vision tower integration and multimodal projector, while LlamaModel provides the transformer decoder backbone. Uses LlavaConfig as its config class.
LlavaLlamaForCausalLM: Inherits from LlamaForCausalLM and LlavaMetaForCausalLM. This is the main model class used for inference and training. Key implementation details:
- __init__: Initializes LlavaLlamaModel as the backbone and creates the lm_head linear layer.
- forward: Calls prepare_inputs_labels_for_multimodal (from LlavaMetaForCausalLM) to fuse image features into the text embeddings before passing them through the LLaMA decoder, then computes cross-entropy loss with shifted labels for next-token prediction.
- prepare_inputs_for_generation: Handles autoregressive generation by passing only the last input ID when past_key_values exist, and by switching between input_ids and inputs_embeds depending on the generation step.
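The base-class ordering used by LlavaLlamaModel can be illustrated with plain-Python stand-ins. This is a minimal sketch, not the real classes: LlamaModelStub, LlavaMetaModelStub, the dict-based config, and the "mm_vision_tower" key are hypothetical placeholders. It shows how listing the vision mixin first lets its __init__ run first and reach the decoder backbone through a cooperative super() call.

```python
class LlamaModelStub:
    """Hypothetical stand-in for LlamaModel (the transformer decoder backbone)."""

    def __init__(self, config):
        self.config = config
        self.layers = [f"decoder_layer_{i}" for i in range(config["num_hidden_layers"])]


class LlavaMetaModelStub:
    """Hypothetical stand-in for LlavaMetaModel (vision tower + projector)."""

    def __init__(self, config):
        # Cooperative super(): with the MRO below this resolves to LlamaModelStub.
        super().__init__(config)
        # Attach vision components only when the config names a vision tower.
        if "mm_vision_tower" in config:
            self.vision_tower = config["mm_vision_tower"]
            self.mm_projector = "projector"


class LlavaLlamaModelStub(LlavaMetaModelStub, LlamaModelStub):
    """Mixin first, backbone second, mirroring LlavaLlamaModel's base order."""


model = LlavaLlamaModelStub({"num_hidden_layers": 2, "mm_vision_tower": "vision-encoder"})
print([cls.__name__ for cls in LlavaLlamaModelStub.__mro__])
# ['LlavaLlamaModelStub', 'LlavaMetaModelStub', 'LlamaModelStub', 'object']
```

Because the decoder layers are built inside the super() call, they already exist by the time the mixin attaches its vision modules.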
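The module registers the model with the Hugging Face transformers AutoConfig and AutoModelForCausalLM registries under the "llava_llama" model type, enabling automatic model loading. Because the registration runs when the module is imported, the module must be imported before AutoModelForCausalLM can resolve the "llava_llama" type.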
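The two-step lookup behind this registration (model type string to config class, then config class to model class) can be sketched with a toy registry. This is an illustration under stated assumptions: ToyConfig, ToyLlavaModel, register_config, register_model, and auto_load are hypothetical names and do not reproduce the transformers internals.

```python
class ToyConfig:
    model_type = "llava_llama"

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


class ToyLlavaModel:
    def __init__(self, config):
        self.config = config


CONFIG_REGISTRY = {}  # step 1: model_type string -> config class
MODEL_REGISTRY = {}   # step 2: config class -> model class


def register_config(model_type, config_cls):
    if model_type in CONFIG_REGISTRY:
        raise ValueError(f"model_type {model_type!r} is already registered")
    CONFIG_REGISTRY[model_type] = config_cls


def register_model(config_cls, model_cls):
    MODEL_REGISTRY[config_cls] = model_cls


def auto_load(config_dict):
    """Resolve a saved config dict (like config.json) to a model instance."""
    config_cls = CONFIG_REGISTRY[config_dict["model_type"]]
    config = config_cls(**config_dict)
    return MODEL_REGISTRY[config_cls](config)


# Module-level registration: runs as soon as the defining module is imported.
register_config(ToyConfig.model_type, ToyConfig)
register_model(ToyConfig, ToyLlavaModel)

model = auto_load({"model_type": "llava_llama", "hidden_size": 4096})
print(type(model).__name__, model.config.hidden_size)  # ToyLlavaModel 4096
```

If the registering module is never imported, the first lookup fails, which is the toy equivalent of an unrecognized model type.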
Usage
LlavaLlamaForCausalLM is instantiated by the model builder (builder.py) when loading LLaVA models. It is the primary model architecture used for most LLaVA configurations in the InternVL repository.
Code Reference
Source Location
- Repository: OpenGVLab_InternVL
- File: internvl_chat_llava/llava/model/language_model/llava_llama.py
- Lines: 1-140
Signature
```python
class LlavaConfig(LlamaConfig):
    model_type = "llava_llama"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig): ...


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaConfig

    def __init__(self, config): ...

    def get_model(self) -> LlavaLlamaModel: ...

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]: ...

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      attention_mask=None, inputs_embeds=None, **kwargs) -> dict: ...
```
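The generation-step behavior described for prepare_inputs_for_generation (only the newest token once a cache exists, precomputed embeddings only on the first step) can be mirrored in plain Python. This is a hedged sketch, not the module's code: the function name is hypothetical, token IDs are nested lists instead of tensors, and forwarding images through kwargs is an assumption about how the image tensor reaches forward.

```python
def prepare_inputs_for_generation_sketch(input_ids, past_key_values=None,
                                         attention_mask=None, inputs_embeds=None,
                                         **kwargs):
    """Hypothetical pure-Python mirror of the generation-step input logic.

    input_ids: list of token-id lists (one per batch row).
    past_key_values: any truthy object once a key-value cache exists.
    """
    if past_key_values:
        # Earlier positions are already cached: feed only the newest token.
        input_ids = [row[-1:] for row in input_ids]

    if inputs_embeds is not None and past_key_values is None:
        # First step only: use the precomputed (e.g. image-fused) embeddings.
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update({
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
        "attention_mask": attention_mask,
        "images": kwargs.get("images"),  # assumption: images passed via kwargs
    })
    return model_inputs


first = prepare_inputs_for_generation_sketch([[1, 2, 3]], inputs_embeds="fused_embeds", use_cache=True)
later = prepare_inputs_for_generation_sketch([[1, 2, 3, 4]], past_key_values=("cache",), use_cache=True)
print(sorted(first))  # ['attention_mask', 'images', 'inputs_embeds', 'past_key_values', 'use_cache']
print(later["input_ids"])  # [[4]]
```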
Import
```python
from llava.model.language_model.llava_llama import LlavaConfig, LlavaLlamaModel, LlavaLlamaForCausalLM
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | No | Token IDs (mutually exclusive with inputs_embeds in generation) |
| attention_mask | torch.Tensor | No | Attention mask for padding |
| past_key_values | List[torch.FloatTensor] | No | Cached key-value pairs for autoregressive generation |
| inputs_embeds | torch.FloatTensor | No | Pre-computed input embeddings (used after multimodal fusion) |
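| labels | torch.LongTensor | No | Target token IDs for loss computation; shifted internally so that position t is scored against token t + 1. Positions set to -100 (the default ignore_index of PyTorch's CrossEntropyLoss) are ignored |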
| images | torch.FloatTensor | No | Image tensor(s) to fuse into text embeddings via multimodal projector |
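| use_cache | bool | No | Whether to return past_key_values for caching |
| output_attentions | bool | No | Whether to return attention weights from all layers |
| output_hidden_states | bool | No | Whether to return hidden states from all layers |
| return_dict | bool | No | Whether to return a CausalLMOutputWithPast object instead of a tuple |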
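How image features end up inside the text embeddings can be sketched with plain lists. This is a simplified illustration of the kind of splicing prepare_inputs_labels_for_multimodal performs, not its actual code: splice_image_features and embed_token are hypothetical, and the placeholder ID -200 and ignore value -100 are LLaVA-style conventions assumed here rather than confirmed for this fork.

```python
IMAGE_TOKEN_INDEX = -200  # assumption: placeholder ID marking where an image goes
IGNORE_INDEX = -100       # assumption: label value skipped by the loss


def splice_image_features(input_ids, labels, embed_token, image_features):
    """Replace each image placeholder with that image's projected feature vectors.

    input_ids, labels: token IDs for one sequence.
    embed_token: callable mapping a token ID to its embedding vector.
    image_features: one list of patch vectors per image, in order of appearance.
    """
    embeds, new_labels = [], []
    images = iter(image_features)
    for tok, lab in zip(input_ids, labels):
        if tok == IMAGE_TOKEN_INDEX:
            patches = next(images)
            embeds.extend(patches)
            # Image positions are never prediction targets.
            new_labels.extend([IGNORE_INDEX] * len(patches))
        else:
            embeds.append(embed_token(tok))
            new_labels.append(lab)
    return embeds, new_labels


def toy_embed(tok):
    # Toy 2-d "embedding" for illustration only.
    return [float(tok), 0.0]


ids = [1, IMAGE_TOKEN_INDEX, 5, 6]
labs = [IGNORE_INDEX, IGNORE_INDEX, 5, 6]
patches = [[[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]]]  # one image, three patch vectors
embeds, new_labels = splice_image_features(ids, labs, toy_embed, patches)
print(len(embeds), new_labels)  # 6 [-100, -100, -100, -100, 5, 6]
```

Each placeholder contributes one position per patch vector instead of a single position, so the fused sequence is longer than input_ids whenever an image has more than one patch.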
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.FloatTensor | Cross-entropy loss (only when labels provided) |
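| logits | torch.FloatTensor | Next-token prediction logits of shape (batch, seq_len, vocab_size), where seq_len is the length after multimodal fusion and can therefore differ from the length of input_ids when images are present |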
| past_key_values | tuple | Cached key-value pairs for generation (when use_cache=True) |
| hidden_states | tuple | Hidden states from all layers (when output_hidden_states=True) |
| attentions | tuple | Attention weights from all layers (when output_attentions=True) |
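The loss output is a shifted next-token cross-entropy. The arithmetic can be checked with a pure-Python sketch for a single sequence; shifted_causal_lm_loss is a hypothetical helper, and the ignore value of -100 is assumed to match the CrossEntropyLoss default. The real module works on batched tensors, but the pairing of logits at position t with the label at t + 1 is the same idea.

```python
import math

IGNORE_INDEX = -100  # assumed to match PyTorch CrossEntropyLoss's default ignore_index


def shifted_causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy for one sequence.

    logits: seq_len rows, each a list of vocab_size scores.
    labels: seq_len target IDs; IGNORE_INDEX positions are skipped.
    Logits at position t are scored against the label at position t + 1.
    """
    total, count = 0.0, 0
    for row, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue
        # Stable log-softmax: -log p(target) = logsumexp(row) - row[target].
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]
        count += 1
    return total / count if count else float("nan")


logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 5.0]]
labels = [IGNORE_INDEX, 0, 1]
loss = shifted_causal_lm_loss(logits, labels)
print(round(loss, 4))  # 0.6691
```

The last position has no following label, so its logits never contribute to the loss.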
Usage Examples
Basic Usage
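```python
import torch
from transformers import AutoModelForCausalLM

# Importing this module registers the "llava_llama" type with the Auto* classes.
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM

# The model is typically loaded via builder.py rather than instantiated directly,
# but once the module above is imported it can also be loaded via AutoModelForCausalLM:
model = AutoModelForCausalLM.from_pretrained(
    "path/to/llava-model",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Forward pass with images.
# input_ids, image_tensor and labels are assumed to be prepared beforehand
# (tokenized prompt, preprocessed image tensor, and target token IDs).
outputs = model(
    input_ids=input_ids,
    images=image_tensor,
    labels=labels,
)
loss = outputs.loss
```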