Implementation:Mit han lab Llm awq LlavaLlamaForCausalLM
| Knowledge Sources | |
|---|---|
| Domains | Vision, NLP, Multimodal |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
Implements the concrete LLaVA-LLaMA multimodal model by combining the LLaVA vision-language mixin with a LLaMA causal language model backbone.
Description
This module defines two classes. LlavaLlamaModel subclasses both LlavaMetaModel and Transformer (multiple inheritance). It inherits the vision tower and projector initialization from LLaVA and the LLaMA transformer architecture from Transformer.
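The multiple-inheritance pattern above can be illustrated with plain-Python stand-ins. The classes below are hypothetical stubs, not the tinychat classes. The sketch assumes that LlavaMetaModel.__init__ calls super().__init__(config), so Transformer is initialized next in the method resolution order (MRO). It also assumes the vision modules are built only when the config names a vision tower. The `mm_vision_tower` key is an illustrative assumption.

```python
from typing import Any, Dict


class Transformer:
    """Hypothetical stand-in for the tinychat LLaMA Transformer backbone."""

    def __init__(self, config: Dict[str, Any]) -> None:
        self.n_layers = config["num_hidden_layers"]


class LlavaMetaModel:
    """Hypothetical stand-in for the LLaVA mixin that adds vision modules."""

    def __init__(self, config: Dict[str, Any]) -> None:
        # Delegate to the next class in the MRO (Transformer) first.
        super().__init__(config)
        # Assumption: vision modules are only built when a tower is configured.
        if "mm_vision_tower" in config:
            self.vision_tower = f"vision_tower[{config['mm_vision_tower']}]"
            self.mm_projector = "mm_projector"


class LlavaLlamaModel(LlavaMetaModel, Transformer):
    """Gets vision setup from LlavaMetaModel and the backbone from Transformer."""

    def __init__(self, config: Dict[str, Any]) -> None:
        super().__init__(config)


print([cls.__name__ for cls in LlavaLlamaModel.__mro__])
# ['LlavaLlamaModel', 'LlavaMetaModel', 'Transformer', 'object']
```

Because LlavaMetaModel is listed first, its initializer runs before Transformer's. Cooperative super() calls then let both initializers run on a single object.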
LlavaLlamaForCausalLM inherits from LlamaForCausalLM and LlavaMetaForCausalLM to form a full multimodal causal language model. It initializes a LlavaLlamaModel backbone and a language-model head. Its forward method supports two multimodal embedding strategies:
- Special-token strategy: prepare_inputs_labels_for_multimodal, inherited from the base LLaVA architecture.
- Default patch-token strategy: default_inputs_embeds_for_multimodal. This method runs the vision tower (with support for CLIP vision encoders) and applies the mm_projector. It then optionally adds NEFTune noise and overwrites the embeddings at image patch token positions with the projected image features.

forward selects between these strategies using the special_token flag. It supports both full-sequence decoding and incremental decoding (via start_pos). prepare_inputs_for_generation wraps the parent class method so that images are propagated through the generation pipeline.
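The default patch-token path described above can be sketched for a single sequence using plain Python lists instead of torch tensors. The function names are hypothetical. The projector is modeled as a bias-free linear map (the real mm_projector may be a deeper MLP). The NEFTune step uses the scaling rule from the NEFTune paper: uniform noise in [-1, 1] multiplied by alpha / sqrt(L * d). Applying that noise to the projected image features is an assumption about where tinychat adds it.

```python
import math
import random
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def project(features: Sequence[Sequence[float]], weight: Sequence[Sequence[float]]) -> Matrix:
    """Bias-free linear projection: out[r][j] = sum_i features[r][i] * weight[i][j]."""
    return [
        [sum(row[i] * weight[i][j] for i in range(len(weight))) for j in range(len(weight[0]))]
        for row in features
    ]


def add_neftune_noise(features: Matrix, alpha: float, rng: random.Random) -> Matrix:
    """NEFTune-style noise: Uniform(-1, 1) scaled by alpha / sqrt(L * d)."""
    seq_len, dim = len(features), len(features[0])
    scale = alpha / math.sqrt(seq_len * dim)
    return [[x + rng.uniform(-1.0, 1.0) * scale for x in row] for row in features]


def default_inputs_embeds(
    input_ids: Sequence[int],
    inputs_embeds: Matrix,
    vision_features: Matrix,
    projector_weight: Matrix,
    patch_token_id: int,
    neftune_alpha: Optional[float] = None,
    rng: Optional[random.Random] = None,
) -> Matrix:
    """Overwrite embeddings at image patch token positions with projected image features."""
    image_features = project(vision_features, projector_weight)
    if neftune_alpha is not None:
        image_features = add_neftune_noise(image_features, neftune_alpha, rng or random.Random(0))

    # One image feature row is expected per patch placeholder token, in order.
    patch_positions = [i for i, tok in enumerate(input_ids) if tok == patch_token_id]
    if len(patch_positions) != len(image_features):
        raise ValueError(
            f"{len(patch_positions)} patch tokens but {len(image_features)} image features"
        )

    out = [list(row) for row in inputs_embeds]  # copy so text embeddings stay untouched
    for pos, feat in zip(patch_positions, image_features):
        out[pos] = list(feat)
    return out


embeds = default_inputs_embeds(
    input_ids=[1, 9, 9, 2],          # 9 is a hypothetical patch token id
    inputs_embeds=[[0.0, 0.0]] * 4,
    vision_features=[[1.0, 0.0], [0.0, 1.0]],
    projector_weight=[[2.0, 0.0], [0.0, 3.0]],
    patch_token_id=9,
)
print(embeds)  # [[0.0, 0.0], [2.0, 0.0], [0.0, 3.0], [0.0, 0.0]]
```

Text token embeddings pass through unchanged. Only the placeholder positions are replaced, so the sequence length never changes. The explicit count check catches a prompt whose number of patch tokens does not match the number of image features.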
Usage
Import LlavaLlamaForCausalLM when you need a ready-to-use LLaVA model backed by a LLaMA decoder for inference or fine-tuning with image-text inputs.
Code Reference
Source Location
- Repository: Mit_han_lab_Llm_awq
- File: tinychat/models/llava_llama.py
- Lines: 1-282
Signature
```python
class LlavaLlamaModel(LlavaMetaModel, Transformer):
    def __init__(self, config): ...

class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    def __init__(self, config, dev="cuda"): ...

    def get_model(self) -> LlavaLlamaModel: ...

    def default_inputs_embeds_for_multimodal(
        self, input_ids, inputs_embeds, images
    ) -> torch.Tensor: ...

    def forward(
        self, input_ids=None, start_pos=None, attention_mask=None,
        position_ids=None, past_key_values=None, inputs_embeds=None,
        labels=None, use_cache=None, output_attentions=None,
        output_hidden_states=None, images=None, return_dict=None,
        special_token=False
    ) -> Union[Tuple, CausalLMOutputWithPast]: ...

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ) -> dict: ...
```
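prepare_inputs_for_generation wraps the parent method so that images survive each step of the generation loop. Below is a minimal sketch of that wrapping pattern. ParentLM and its cache-trimming behavior are hypothetical stand-ins for a Hugging Face-style parent. The pop-then-reattach structure is an assumed implementation of the wrapper, not a copy of tinychat's code.

```python
from typing import Any, Dict, List, Optional


class ParentLM:
    """Hypothetical parent: trims input_ids to the newest token once a KV cache exists."""

    def prepare_inputs_for_generation(
        self, input_ids: List[int], past_key_values: Optional[Any] = None,
        inputs_embeds: Optional[Any] = None, **kwargs: Any,
    ) -> Dict[str, Any]:
        if past_key_values is not None:
            input_ids = input_ids[-1:]
        return {"input_ids": input_ids, "past_key_values": past_key_values,
                "inputs_embeds": inputs_embeds, **kwargs}


class MultimodalLM(ParentLM):
    """Wraps the parent method and carries `images` through to forward()."""

    def prepare_inputs_for_generation(
        self, input_ids: List[int], past_key_values: Optional[Any] = None,
        inputs_embeds: Optional[Any] = None, **kwargs: Any,
    ) -> Dict[str, Any]:
        # Take images out of kwargs before delegating to the parent.
        images = kwargs.pop("images", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        # Reattach them so the next forward() call receives them.
        if images is not None:
            inputs["images"] = images
        return inputs
```

Keeping the parent's logic intact and only adding the images key means cache handling and other generic inputs stay owned by the parent class.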
Import
```python
from tinychat.models.llava_llama import LlavaLlamaForCausalLM
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | No | Token IDs for the input; may contain image patch token placeholders |
| start_pos | int | No | Starting position for incremental decoding; None for full-sequence mode |
| attention_mask | torch.Tensor | No | Attention mask for the sequence |
| position_ids | torch.LongTensor | No | Position IDs for rotary embeddings |
| past_key_values | List[torch.FloatTensor] | No | KV cache from previous generation steps |
| inputs_embeds | torch.FloatTensor | No | Pre-computed embeddings; bypasses multimodal embedding if provided |
| labels | torch.LongTensor | No | Target labels for training loss computation |
| images | torch.FloatTensor | No | Image tensor(s) to encode; shape (B, C, H, W) |
| special_token | bool | No | If True, uses special-token strategy; otherwise uses default patch-token replacement |
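The dispatch rules stated in the input table can be summarized as two small decision functions. This is a hypothetical sketch. The precedence order (pre-computed embeddings first, then the special_token flag) and the text-only fallback when images is None are assumptions inferred from the descriptions on this page. They were not read from tinychat's forward.

```python
from typing import Any, Optional


def select_embedding_path(
    inputs_embeds: Optional[Any], images: Optional[Any], special_token: bool
) -> str:
    """Hypothetical summary of which embedding path forward() would take."""
    if inputs_embeds is not None:
        return "precomputed"  # caller-supplied embeddings bypass multimodal embedding
    if images is None:
        return "text_only"  # assumption: no images means plain token embeddings
    if special_token:
        return "prepare_inputs_labels_for_multimodal"
    return "default_inputs_embeds_for_multimodal"


def decoding_mode(start_pos: Optional[int]) -> str:
    """start_pos=None selects full-sequence mode; an int selects incremental decoding."""
    return "full_sequence" if start_pos is None else "incremental"
```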
Outputs
| Name | Type | Description |
|---|---|---|
| output | Union[Tuple, CausalLMOutputWithPast] | Model output containing logits and optionally loss, hidden states, and attentions |
Usage Examples
Forward pass with images
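```python
from tinychat.models.llava_llama import LlavaLlamaForCausalLM

# config, input_ids, image_tensor, attention_mask and labels are assumed to be
# prepared beforehand (model config, tokenized prompt, preprocessed images, targets).
model = LlavaLlamaForCausalLM(config, dev="cuda")
output = model.forward(
    input_ids=input_ids,
    images=image_tensor,            # shape (B, C, H, W)
    attention_mask=attention_mask,
    labels=labels,                  # labels are used for the training loss
)
loss = output.loss
```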
Incremental decoding with start_pos
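```python
# next_token_ids and current_position are assumed to come from the previous step.
output = model.forward(
    input_ids=next_token_ids,
    start_pos=current_position,
    images=None,  # Images already processed in prefill
)
logits = output  # this example uses the return value directly as the logits
```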
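The two calls above are typically combined into a prefill-then-decode loop. The sketch below uses a hypothetical ToyModel whose forward accepts the same input_ids, start_pos and images keywords and returns last-position logits as a plain list. Prefilling at start_pos=0 and greedy argmax selection are assumptions made for illustration. The point of the sketch is the bookkeeping: images are passed only once, and each new token is fed alone at the position immediately after everything already processed.

```python
from typing import Any, List, Optional, Tuple


class ToyModel:
    """Hypothetical stand-in: logits peak at (last input id + 1) mod vocab_size."""

    def __init__(self, vocab_size: int = 10) -> None:
        self.vocab_size = vocab_size
        self.calls: List[Tuple[List[int], Optional[int], bool]] = []

    def forward(self, input_ids: List[int], start_pos: Optional[int] = None,
                images: Optional[Any] = None) -> List[float]:
        self.calls.append((list(input_ids), start_pos, images is not None))
        logits = [0.0] * self.vocab_size
        logits[(input_ids[-1] + 1) % self.vocab_size] = 1.0
        return logits  # logits for the last position only


def greedy_generate(model: ToyModel, prompt_ids: List[int], images: Any,
                    max_new_tokens: int, eos_id: Optional[int] = None) -> List[int]:
    # Prefill: the whole prompt plus the images, starting at position 0 (assumption).
    logits = model.forward(input_ids=prompt_ids, start_pos=0, images=images)
    pos = len(prompt_ids)  # next free position after the prompt
    out: List[int] = []
    for step in range(max_new_tokens):
        next_id = max(range(len(logits)), key=logits.__getitem__)  # greedy argmax
        out.append(next_id)
        if next_id == eos_id or step == max_new_tokens - 1:
            break
        # Incremental step: one token, no images, at the next position.
        logits = model.forward(input_ids=[next_id], start_pos=pos, images=None)
        pos += 1
    return out


print(greedy_generate(ToyModel(), [3, 4], images="img", max_new_tokens=3))  # [5, 6, 7]
```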