Implementation:Mit han lab Llm awq LlavaMetaForCausalLM
| Knowledge Sources | |
|---|---|
| Domains | Vision, Multimodal |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
Provides the base classes LlavaMetaModel (a plain mixin) and LlavaMetaForCausalLM (an abstract base class), which together define the vision-language integration layer for the LLaVA multimodal architecture.
Description
This module defines two foundational classes for the LLaVA multimodal framework. LlavaMetaModel serves as a mixin that initializes and manages the vision tower (CLIP encoder) and the multimodal projector (mm_projector) that bridges vision and language feature spaces. It provides methods to build the vision tower from configuration, initialize vision modules with optional FSDP support, and load pretrained MLP adapter weights.
LlavaMetaForCausalLM is an abstract base class (ABC) that provides the core multimodal inference logic. Its central method, prepare_inputs_labels_for_multimodal, takes standard causal LM inputs along with images and replaces image token placeholders in the input sequence with encoded image feature embeddings. It handles batched inputs, padding (both left and right), label alignment with IGNORE_INDEX for image tokens, and sequence truncation to respect model max length. The encode_images method runs images through the vision tower and projector to produce language-space features. The initialize_vision_tokenizer method adds special image tokens (patch token, start/end tokens) to the tokenizer and resizes embeddings accordingly.
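The placeholder replacement described above can be sketched in plain Python. This is a simplified illustration for a single sequence, not the module's implementation: the sentinel values IMAGE_TOKEN_INDEX = -200 and IGNORE_INDEX = -100 are assumptions made for the example, the function name splice_image_features is hypothetical, and strings stand in for embedding vectors.

```python
from typing import List, Sequence, Tuple

# Assumed sentinel values for illustration only; the real constants are
# defined by the project and may differ.
IMAGE_TOKEN_INDEX = -200
IGNORE_INDEX = -100


def splice_image_features(
    input_ids: Sequence[int],
    labels: Sequence[int],
    image_features: Sequence[Sequence[str]],
    max_length: int,
) -> Tuple[List[str], List[int]]:
    """Replace each image placeholder with that image's feature rows.

    Text tokens become "tok<id>" strings and image features are opaque
    strings; both stand in for embedding vectors.
    """
    embeds: List[str] = []
    new_labels: List[int] = []
    image_idx = 0
    for token, label in zip(input_ids, labels):
        if token == IMAGE_TOKEN_INDEX:
            feats = image_features[image_idx]
            image_idx += 1
            embeds.extend(feats)
            # Image positions never contribute to the loss.
            new_labels.extend([IGNORE_INDEX] * len(feats))
        else:
            embeds.append(f"tok{token}")
            new_labels.append(label)
    # Truncate so the expanded sequence respects the model's max length.
    return embeds[:max_length], new_labels[:max_length]


embeds, new_labels = splice_image_features(
    input_ids=[1, IMAGE_TOKEN_INDEX, 42, 43],
    labels=[IGNORE_INDEX, IGNORE_INDEX, 42, 43],
    image_features=[["img0_p0", "img0_p1", "img0_p2"]],
    max_length=5,
)
```

One placeholder expands into as many positions as the image has feature rows, so the output sequence is longer than the input and the labels must be expanded in step with it.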
Usage
Use LlavaMetaModel as a base class for the backbone model that owns the vision tower and the projector. Use LlavaMetaForCausalLM as a mixin for any causal language model that should accept multimodal image-text inputs; the concrete class must implement the abstract get_model method.
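A minimal sketch of this composition pattern follows, using stand-in classes rather than the real ones. All class names here (MetaModelSketch, MetaForCausalLMSketch, ToyLlavaModel, ToyLlavaForCausalLM) and the dictionary-based config are hypothetical; the sketch only shows how the abstract get_model hook lets the causal-LM mixin reach the vision tower held by the inner model.

```python
from abc import ABC, abstractmethod


class MetaModelSketch:
    """Stand-in for LlavaMetaModel: owns the vision tower."""

    def __init__(self, config: dict):
        # Hypothetical config key; a string stands in for the tower module.
        self.vision_tower = config.get("mm_vision_tower")

    def get_vision_tower(self):
        return self.vision_tower


class MetaForCausalLMSketch(ABC):
    """Stand-in for LlavaMetaForCausalLM: subclasses must supply get_model."""

    @abstractmethod
    def get_model(self):
        ...

    def get_vision_tower(self):
        # Delegate to the inner model, which holds the tower.
        return self.get_model().get_vision_tower()


class ToyLlavaModel(MetaModelSketch):
    """Backbone model that inherits vision-tower management."""


class ToyLlavaForCausalLM(MetaForCausalLMSketch):
    """Causal LM wrapper that gains multimodal helpers from the mixin."""

    def __init__(self, config: dict):
        self.model = ToyLlavaModel(config)

    def get_model(self):
        return self.model


lm = ToyLlavaForCausalLM({"mm_vision_tower": "toy-clip"})
tower = lm.get_vision_tower()
```

Because get_model is abstract, instantiating the mixin on its own raises TypeError; only a subclass that implements it can be constructed.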
Code Reference
Source Location
- Repository: Mit_han_lab_Llm_awq
- File: tinychat/models/llava_base/llava_arch.py
- Lines: 1-412
Signature
class LlavaMetaModel:
    def __init__(self, config): ...
    def get_vision_tower(self) -> nn.Module: ...
    def initialize_vision_modules(self, model_args, fsdp=None): ...

class LlavaMetaForCausalLM(ABC):
    @abstractmethod
    def get_model(self): ...
    def get_vision_tower(self): ...
    def encode_images(self, images) -> torch.Tensor: ...
    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask,
        past_key_values, labels, images
    ) -> Tuple: ...
    def initialize_vision_tokenizer(self, model_args, tokenizer): ...
Import
from tinychat.models.llava_base.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | Yes | Token IDs for the input sequence, including image placeholder tokens |
| position_ids | torch.LongTensor | No | Position IDs for each token; auto-generated if None |
| attention_mask | torch.Tensor | No | Attention mask; auto-generated as all-ones if None |
| past_key_values | List[torch.FloatTensor] | No | Cached key/value pairs from previous decoding steps |
| labels | torch.LongTensor | No | Target labels for loss computation; image positions filled with IGNORE_INDEX |
| images | torch.FloatTensor | Yes | Image tensor(s) of shape (B, C, H, W) or list of image tensors |
Outputs
| Name | Type | Description |
|---|---|---|
| input_ids | None | Set to None when inputs_embeds are produced |
| position_ids | torch.LongTensor | Updated position IDs aligned with new sequence length |
| attention_mask | torch.Tensor | Updated attention mask reflecting image embedding insertion and padding |
| past_key_values | List[torch.FloatTensor] | Passed through unchanged from input |
| inputs_embeds | torch.FloatTensor | Fused text and image embeddings ready for the language model |
| labels | torch.LongTensor | Labels with IGNORE_INDEX at image token positions and padding |
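The padding behaviour reflected in the attention_mask, position_ids, and labels outputs can be illustrated with a small pure-Python sketch. It is written under assumptions: the function name pad_batch is hypothetical, IGNORE_INDEX = -100 is an assumed value, padded positions are given position ID 0, and strings stand in for embedding rows. The real method operates on tensors and may differ in detail.

```python
from typing import List, Sequence, Tuple

IGNORE_INDEX = -100  # assumed value, for illustration only


def pad_batch(
    seqs: Sequence[Sequence[str]],
    label_seqs: Sequence[Sequence[int]],
    padding_side: str = "right",
    pad: str = "PAD",
) -> Tuple[List[List[str]], List[List[int]], List[List[bool]], List[List[int]]]:
    """Pad variable-length sequences to the batch maximum.

    Returns padded embeddings, labels, attention mask, and position IDs.
    """
    if padding_side not in ("left", "right"):
        raise ValueError(f"unknown padding_side: {padding_side!r}")
    max_len = max(len(s) for s in seqs)
    embeds, labels, mask, positions = [], [], [], []
    for seq, lab in zip(seqs, label_seqs):
        n = len(seq)
        pad_n = max_len - n
        pieces = (
            [pad] * pad_n,            # padding embeddings
            [IGNORE_INDEX] * pad_n,   # padding never contributes to the loss
            [False] * pad_n,          # padding is masked out of attention
            [0] * pad_n,              # assumed position ID for padding
        )
        real = (list(seq), list(lab), [True] * n, list(range(n)))
        if padding_side == "left":
            row = [p + r for p, r in zip(pieces, real)]
        else:
            row = [r + p for p, r in zip(pieces, real)]
        embeds.append(row[0])
        labels.append(row[1])
        mask.append(row[2])
        positions.append(row[3])
    return embeds, labels, mask, positions


right = pad_batch([["a", "b", "c"], ["d"]], [[1, 2, 3], [4]], "right")
left = pad_batch([["a", "b", "c"], ["d"]], [[1, 2, 3], [4]], "left")
```

With right padding the real tokens stay at the start of each row, and with left padding they are pushed to the end; in both cases the real tokens keep position IDs counted from zero.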
Usage Examples
Encoding images and preparing multimodal inputs
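# Within a model that inherits LlavaMetaForCausalLM

# encode_images can be called on its own to obtain language-space features.
image_features = self.encode_images(images)

# prepare_inputs_labels_for_multimodal takes the raw images (not the
# features computed above) and returns the fused inputs.
(
    input_ids, position_ids, attention_mask,
    past_key_values, inputs_embeds, labels
) = self.prepare_inputs_labels_for_multimodal(
    input_ids, position_ids, attention_mask,
    past_key_values, labels, images
)
# Pass inputs_embeds (input_ids is now None) to the language model decoder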
Initializing vision modules
# During model setup
model.get_model().initialize_vision_modules(model_args, fsdp=fsdp_config)
model.initialize_vision_tokenizer(model_args, tokenizer)
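What "adding special image tokens and resizing embeddings" involves can be sketched without any framework. Everything below is illustrative: ToyTokenizer and resize_embeddings are hypothetical helpers, the token strings are assumed names for the patch, start, and end tokens, and initialising new rows to the mean of the existing rows is an assumed convention rather than something the source confirms.

```python
from typing import Dict, List, Sequence


class ToyTokenizer:
    """Minimal vocabulary holder standing in for a real tokenizer."""

    def __init__(self, vocab: Sequence[str]):
        self.vocab: Dict[str, int] = {t: i for i, t in enumerate(vocab)}

    def add_tokens(self, tokens: Sequence[str]) -> int:
        """Add unseen tokens and return how many were actually added."""
        added = 0
        for token in tokens:
            if token not in self.vocab:
                self.vocab[token] = len(self.vocab)
                added += 1
        return added

    def __len__(self) -> int:
        return len(self.vocab)


def resize_embeddings(table: List[List[float]], new_size: int) -> List[List[float]]:
    """Grow an embedding table to new_size rows (never shrinks it).

    New rows are set to the column-wise mean of the existing rows, an
    assumed initialisation chosen for this sketch.
    """
    dim = len(table[0])
    mean_row = [sum(row[j] for row in table) / len(table) for j in range(dim)]
    num_new = new_size - len(table)
    return [list(row) for row in table] + [list(mean_row) for _ in range(num_new)]


tokenizer = ToyTokenizer(["a", "b"])
embeddings = [[1.0, 3.0], [3.0, 5.0]]

# Assumed token strings for the patch, start, and end tokens.
num_added = tokenizer.add_tokens(["<im_patch>", "<im_start>", "<im_end>"])
embeddings = resize_embeddings(embeddings, len(tokenizer))
```

The key invariant is that the embedding table has one row per vocabulary entry after the call, so the new token IDs index valid rows.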