Implementation:Mit han lab Llm awq LlavaLlamaForCausalLM

From Leeroopedia
Knowledge Sources
Domains Vision, NLP, Multimodal
Last Updated 2026-02-15 00:00 GMT

Overview

Implements the concrete LLaVA-LLaMA multimodal model by combining the LLaVA vision-language mixin with a LLaMA causal language model backbone.

Description

This module defines two classes. LlavaLlamaModel subclasses both LlavaMetaModel and Transformer (multiple inheritance), so it combines the vision tower and projector initialization from LLaVA with the LLaMA transformer architecture.
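
The two-parent pattern can be illustrated with stand-in classes. This is a minimal sketch, assuming LlavaMetaModel.__init__ follows the cooperative super().__init__(config) pattern used in upstream LLaVA; the class bodies, attribute values, and dict-based config are hypothetical and only show the order in which the parents are initialized.

```python
class Transformer:
    """Hypothetical stand-in for the LLaMA decoder stack."""

    def __init__(self, config):
        self.layers = [f"decoder_layer_{i}" for i in range(config["n_layers"])]


class LlavaMetaModel:
    """Hypothetical stand-in for the LLaVA mixin that adds vision modules."""

    def __init__(self, config):
        # Cooperative call: the next class in the MRO (Transformer) runs first.
        super().__init__(config)
        if "mm_vision_tower" in config:
            self.vision_tower = config["mm_vision_tower"]
            self.mm_projector = "mm_projector"


class LlavaLlamaModel(LlavaMetaModel, Transformer):
    def __init__(self, config):
        super().__init__(config)  # LlavaMetaModel.__init__ -> Transformer.__init__


model = LlavaLlamaModel({"n_layers": 2, "mm_vision_tower": "clip-vit"})
print([cls.__name__ for cls in LlavaLlamaModel.__mro__])
# ['LlavaLlamaModel', 'LlavaMetaModel', 'Transformer', 'object']
print(len(model.layers), model.vision_tower)
# 2 clip-vit
```

Listing the mixin first means its initializer wraps the backbone's: the decoder is built by the super() call, and the vision modules are attached afterwards.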

LlavaLlamaForCausalLM inherits from LlamaForCausalLM and LlavaMetaForCausalLM to form a complete multimodal causal language model. Its constructor builds a LlavaLlamaModel backbone and a language-model head.

The forward method handles two distinct multimodal embedding strategies:
(1) a special-token-based approach, via prepare_inputs_labels_for_multimodal inherited from the base LLaVA architecture; and
(2) a default patch-token-based approach, via default_inputs_embeds_for_multimodal. This method runs the vision tower (with CLIP support), applies the mm_projector, optionally adds NEFTune noise, and replaces the image patch token positions in the embedding sequence with the projected image features.

forward chooses between these strategies using the special_token flag, and supports both full-sequence decoding and incremental decoding (when start_pos is given). prepare_inputs_for_generation wraps the parent class method so that images are propagated through the generation pipeline.
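
The default patch-token strategy amounts to a position-wise overwrite of the text embeddings. The sketch below uses plain Python lists instead of tensors and assumes one placeholder token per projected image feature; IMAGE_PATCH_TOKEN_ID, replace_patch_embeddings, and neftune_noise are hypothetical names, and applying the noise to the projected features (rather than at some other point) is an assumption. The noise magnitude alpha / sqrt(L * d) follows the NEFTune formulation, where L is the sequence length and d the embedding width.

```python
import math
import random
from typing import List, Optional

Vector = List[float]

# Hypothetical placeholder id; the real value comes from the tokenizer/config.
IMAGE_PATCH_TOKEN_ID = 99999


def neftune_noise(features: List[Vector], alpha: float, rng: random.Random) -> List[Vector]:
    """Add NEFTune-style uniform noise in [-alpha/sqrt(L*d), +alpha/sqrt(L*d)]."""
    seq_len, dim = len(features), len(features[0])
    mag = alpha / math.sqrt(seq_len * dim)
    return [[x + rng.uniform(-mag, mag) for x in row] for row in features]


def replace_patch_embeddings(
    input_ids: List[int],
    inputs_embeds: List[Vector],
    image_features: List[Vector],
    neftune_alpha: Optional[float] = None,
    seed: int = 0,
) -> List[Vector]:
    """Overwrite each patch-token position with the next projected image feature."""
    positions = [i for i, tok in enumerate(input_ids) if tok == IMAGE_PATCH_TOKEN_ID]
    if len(positions) != len(image_features):
        raise ValueError(
            f"{len(positions)} patch tokens but {len(image_features)} image features"
        )
    if neftune_alpha is not None:
        image_features = neftune_noise(image_features, neftune_alpha, random.Random(seed))
    out = [list(row) for row in inputs_embeds]  # copy; leave the caller's list intact
    for pos, feat in zip(positions, image_features):
        out[pos] = list(feat)
    return out


ids = [1, IMAGE_PATCH_TOKEN_ID, IMAGE_PATCH_TOKEN_ID, 42]
text = [[0.0, 0.0], [9.0, 9.0], [9.0, 9.0], [0.5, 0.5]]
img = [[1.0, 2.0], [3.0, 4.0]]
print(replace_patch_embeddings(ids, text, img))
# [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0], [0.5, 0.5]]
```

The count check matters because a mismatch between placeholders and image features would otherwise silently misalign image content with sequence positions.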

Usage

Import LlavaLlamaForCausalLM when you need a ready-to-use LLaVA model backed by a LLaMA decoder for inference or fine-tuning with image-text inputs.

Code Reference

Source Location

Signature

class LlavaLlamaModel(LlavaMetaModel, Transformer):
    def __init__(self, config): ...

class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    def __init__(self, config, dev="cuda"): ...
    def get_model(self) -> LlavaLlamaModel: ...
    def default_inputs_embeds_for_multimodal(
        self, input_ids, inputs_embeds, images
    ) -> torch.Tensor: ...
    def forward(
        self, input_ids=None, start_pos=None, attention_mask=None,
        position_ids=None, past_key_values=None, inputs_embeds=None,
        labels=None, use_cache=None, output_attentions=None,
        output_hidden_states=None, images=None, return_dict=None,
        special_token=False
    ) -> Union[Tuple, CausalLMOutputWithPast]: ...
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ) -> dict: ...
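
A common reason to wrap prepare_inputs_for_generation is that the parent implementation builds its own input dict and drops keyword arguments it does not recognize, such as images. The following is a minimal sketch of a pop-and-reinsert pattern, assuming that behavior; BaseCausalLM and LlavaLikeLM are hypothetical stand-ins, not the actual tinychat or transformers classes.

```python
from typing import Any, Dict


class BaseCausalLM:
    """Hypothetical parent whose helper ignores unknown keyword arguments."""

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs) -> Dict[str, Any]:
        # Once a KV cache exists, only the newest token needs to be fed.
        if past_key_values is not None:
            input_ids = input_ids[-1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": kwargs.get("attention_mask"),
        }


class LlavaLikeLM(BaseCausalLM):
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs) -> Dict[str, Any]:
        # Take images out before delegating, since the parent would drop them...
        images = kwargs.pop("images", None)
        model_inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, **kwargs)
        # ...then put them back so forward() receives them.
        if images is not None:
            model_inputs["images"] = images
        return model_inputs


lm = LlavaLikeLM()
print(lm.prepare_inputs_for_generation([5, 6, 7], images="img")["images"])
# img
```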

Import

from tinychat.models.llava_llama import LlavaLlamaForCausalLM

I/O Contract

Inputs

Name | Type | Required | Description
input_ids | torch.LongTensor | No | Token IDs for the input; may contain image patch token placeholders
start_pos | int | No | Starting position for incremental decoding; None for full-sequence mode
attention_mask | torch.Tensor | No | Attention mask for the sequence
position_ids | torch.LongTensor | No | Position IDs for rotary embeddings
past_key_values | List[torch.FloatTensor] | No | KV cache from previous generation steps
inputs_embeds | torch.FloatTensor | No | Pre-computed embeddings; if provided, the multimodal embedding step is bypassed
labels | torch.LongTensor | No | Target labels for training loss computation
images | torch.FloatTensor | No | Image tensor(s) to encode, shape (B, C, H, W)
special_token | bool | No | If True, use the special-token strategy; otherwise use the default patch-token replacement
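
In the patch-token strategy, input_ids must contain exactly one placeholder per image feature that will replace it. The count follows from the vision encoder's geometry: for example, a ViT with 14-pixel patches on a 336×336 input yields (336 / 14)^2 = 24^2 = 576 patch features, assuming the CLS token is dropped. The helper names and the placeholder id below are hypothetical.

```python
from typing import List

# Hypothetical placeholder id for illustration; real values come from the tokenizer.
IMAGE_PATCH_TOKEN_ID = 99999


def num_patches(image_size: int, patch_size: int) -> int:
    """Patch-feature count for a square ViT input, excluding any CLS token."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2


def build_input_ids(prefix: List[int], suffix: List[int], n_patches: int) -> List[int]:
    """Insert one placeholder per image patch between two text segments."""
    return prefix + [IMAGE_PATCH_TOKEN_ID] * n_patches + suffix


n = num_patches(336, 14)
ids = build_input_ids([1, 319], [29871], n)
print(n, len(ids))
# 576 579
```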

Outputs

Name | Type | Description
output | Union[Tuple, CausalLMOutputWithPast] | Model output containing logits and, optionally, the loss, hidden states, and attentions
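
When labels are supplied in full-sequence mode, the loss field follows, by Hugging Face convention, the mean next-token cross-entropy: the logits at position t are scored against labels[t + 1], and positions labeled -100 are skipped (LLaVA-style training typically uses this to exclude image and prompt tokens from the loss). The following is a dependency-free sketch, assuming that convention applies here; causal_lm_loss is a hypothetical name.

```python
import math
from typing import List

IGNORE_INDEX = -100  # Hugging Face convention for positions excluded from the loss


def causal_lm_loss(logits: List[List[float]], labels: List[int]) -> float:
    """Mean next-token cross-entropy: position t predicts labels[t + 1]."""
    total, count = 0.0, 0
    for step_logits, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue
        m = max(step_logits)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(x - m) for x in step_logits))
        total += log_z - step_logits[target]  # -log softmax(step_logits)[target]
        count += 1
    if count == 0:
        raise ValueError("no supervised positions")
    return total / count


uniform = [[0.0] * 4 for _ in range(3)]
print(causal_lm_loss(uniform, [IGNORE_INDEX, 2, 3]))  # ≈ ln 4 ≈ 1.386
```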

Usage Examples

Forward pass with images

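from tinychat.models.llava_llama import LlavaLlamaForCausalLM

# Assumed to be prepared elsewhere: `config` (a LLaVA-LLaMA model config),
# `input_ids` containing image patch placeholders, `image_tensor` of shape
# (B, C, H, W), `attention_mask`, and `labels` for the training loss.
model = LlavaLlamaForCausalLM(config, dev="cuda")
output = model(  # nn.Module.__call__ dispatches to forward()
    input_ids=input_ids,
    images=image_tensor,
    attention_mask=attention_mask,
    labels=labels,
)
loss = output.loss  # set because labels were supplied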

Incremental decoding with start_pos

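# Decode step: feed only the newly generated token(s); start_pos is the
# position of the first of them in the full sequence.
output = model(
    input_ids=next_token_ids,
    start_pos=current_position,
    images=None,  # images were already consumed during prefill
)
logits = output  # in start_pos mode the return value is used directly as logits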
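
The start_pos bookkeeping in incremental decoding can be sketched as a greedy loop: one prefill call over the whole prompt at position 0, then one call per new token with start_pos advanced by one. Here step is a hypothetical stand-in for a call such as model(input_ids=..., start_pos=...), with images passed only on the prefill call; the sketch assumes start_pos counts the tokens already written to the KV cache.

```python
from typing import Callable, List, Optional

# step(tokens, start_pos) -> logits for the last fed position.
StepFn = Callable[[List[int], int], List[float]]


def greedy_decode(step: StepFn, prompt: List[int], max_new_tokens: int,
                  eos_id: Optional[int] = None) -> List[int]:
    """Prefill once, then feed one token per step while advancing start_pos."""
    logits = step(prompt, 0)   # prefill: the whole prompt starts at position 0
    start_pos = len(prompt)    # next position to be written in the KV cache
    generated: List[int] = []
    for _ in range(max_new_tokens):
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        generated.append(next_id)
        if next_id == eos_id or len(generated) == max_new_tokens:
            break
        logits = step([next_id], start_pos)
        start_pos += 1
    return generated


calls = []


def toy_step(tokens: List[int], start_pos: int) -> List[float]:
    """Toy model: always prefers (last token + 1) mod 10."""
    calls.append((list(tokens), start_pos))
    logits = [0.0] * 10
    logits[(tokens[-1] + 1) % 10] = 1.0
    return logits


print(greedy_decode(toy_step, [3, 4], max_new_tokens=3))  # [5, 6, 7]
print(calls)  # [([3, 4], 0), ([5], 2), ([6], 3)]
```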
