Implementation:Mit han lab Llm awq VilaLlamaForCausalLM
| Knowledge Sources | |
|---|---|
| Domains | Vision, NLP, Multimodal |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
Implements a VILA vision-language model with a LLaMA language backbone, built on the LLaVA-style architecture, with support for VILA-style component initialization and chunk-prefilling inference.
Description
VilaLlamaForCausalLM combines LlavaMetaModel, LlavaMetaForCausalLM, and PreTrainedModel into a single class that provides a complete vision-language model built on the VILA architecture. Unlike the simpler LlavaLlamaForCausalLM, this class uses the VILA-style component initialization pattern from the llava library.
The init_vlm method builds the three core components: (1) the LLM backbone via LlamaForCausalLM, (2) the vision tower via VILA's build_vision_tower, and (3) the multimodal projector via VILA's build_mm_projector. An idempotency guard makes the method return early if these components already exist, so repeated calls do not re-initialize the model. Otherwise it extracts the per-component sub-configurations from the main config with get_model_config, defaults the model dtype to torch.float16 when the config does not specify one, and performs post-configuration once the components are built.
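The initialization flow above (guard, dtype default, config split, component construction, post-configuration) can be sketched with plain-Python stand-ins. This is a hypothetical illustration, not the TinyChat code: `ToyVLM`, `split_config`, and the tuple "components" are placeholders standing in for `get_model_config`, `LlamaForCausalLM`, `build_vision_tower`, and `build_mm_projector`, and the exact ordering of the dtype default relative to the config split is an assumption.

```python
from types import SimpleNamespace


def split_config(config):
    # Hypothetical stand-in for get_model_config: returns the three sub-configs.
    return config.llm_cfg, config.vision_tower_cfg, config.mm_projector_cfg


class ToyVLM:
    """Hypothetical model that mimics the init_vlm control flow."""

    def __init__(self, config):
        self.build_calls = 0
        self.init_vlm(config)

    def init_vlm(self, config):
        # Idempotency guard: skip if any component was already built.
        if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "mm_projector"):
            return
        # Default the dtype when the config does not specify one.
        if not hasattr(config, "model_dtype"):
            config.model_dtype = "torch.float16"
        llm_cfg, vision_cfg, projector_cfg = split_config(config)
        # Placeholders for LlamaForCausalLM, build_vision_tower, build_mm_projector.
        self.llm = ("llm", llm_cfg)
        self.vision_tower = ("vision_tower", vision_cfg)
        self.mm_projector = ("mm_projector", projector_cfg)
        self.build_calls += 1
        self.post_config()

    def post_config(self):
        # Placeholder for the post-configuration step.
        self.is_loaded = True


config = SimpleNamespace(llm_cfg="llm", vision_tower_cfg="vision", mm_projector_cfg="projector")
model = ToyVLM(config)
model.init_vlm(config)  # second call hits the guard and builds nothing
print(model.build_calls, config.model_dtype)  # 1 torch.float16
```

The guard is what lets the same method be reached from both construction and later loading paths without rebuilding (and thereby discarding) components that already hold weights.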
The forward method first calls freezed_module_patch (inherited from the VILA base). Unless inputs_embeds is already provided, it then calls prepare_inputs_labels_for_multimodal to fuse the image features into the input embeddings. Finally it dispatches to the LLM backbone, passing start_pos for incremental decoding and chunk_prefilling to process long prefill sequences in chunks for memory efficiency. If embeddings are available at that point, the backbone receives them and the token IDs are set to None; otherwise the raw token IDs are forwarded directly.
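The fusion and dispatch steps can be illustrated with a small pure-Python sketch. LLaVA-style code marks the image position in the prompt with a sentinel token id; the value -200 and the helpers `fuse_image_features`, `dispatch`, and `fake_llm` are assumptions for illustration, and embeddings are shown as short Python lists rather than tensors. The point to notice is that one image placeholder expands into many embedding positions, so the fused sequence is longer than `input_ids`.

```python
IMAGE_TOKEN_ID = -200  # assumed sentinel id marking where image features go


def fuse_image_features(input_ids, image_features, embed):
    """Replace each image sentinel with that image's feature vectors (hypothetical helper)."""
    fused, images = [], iter(image_features)
    for tok in input_ids:
        if tok == IMAGE_TOKEN_ID:
            fused.extend(next(images))  # one image expands to many embeddings
        else:
            fused.append(embed(tok))
    return fused


def dispatch(llm_forward, input_ids, start_pos, inputs_embeds, chunk_prefilling):
    # Mirrors the described branch: embeddings take priority over token ids.
    if inputs_embeds is not None:
        return llm_forward(input_ids=None, start_pos=start_pos,
                           inputs_embeds=inputs_embeds, chunk_prefilling=chunk_prefilling)
    return llm_forward(input_ids=input_ids, start_pos=start_pos,
                       inputs_embeds=None, chunk_prefilling=chunk_prefilling)


def embed(tok):
    return [float(tok)]  # toy 1-dimensional embedding


def fake_llm(input_ids=None, start_pos=0, inputs_embeds=None, chunk_prefilling=False):
    # Hypothetical backbone that only reports which input it received.
    return "embeds" if inputs_embeds is not None else "ids"


image = [[0.1], [0.2], [0.3], [0.4]]  # one image -> 4 feature vectors
ids = [1, IMAGE_TOKEN_ID, 5, 6]
fused = fuse_image_features(ids, [image], embed)
print(len(ids), len(fused))  # 4 7
print(dispatch(fake_llm, ids, 0, fused, True))  # embeds
print(dispatch(fake_llm, ids, 0, None, False))  # ids
```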
Usage
Import VilaLlamaForCausalLM when deploying a VILA-based multimodal model with LLaMA as the language backbone, particularly in TinyChat inference scenarios that require chunk prefilling.
Code Reference
Source Location
- Repository: Mit_han_lab_Llm_awq
- File: tinychat/models/vila_llama.py
- Lines: 1-109
Signature
```python
class VilaLlamaForCausalLM(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel):
    def __init__(self, config): ...
    def init_vlm(self, config=None, *args, **kwargs): ...
    def forward(
        self, input_ids=None, start_pos=None, attention_mask=None,
        position_ids=None, past_key_values=None, inputs_embeds=None,
        labels=None, use_cache=None, output_attentions=None,
        output_hidden_states=None, images=None, return_dict=None,
        special_token=False, chunk_prefilling=False
    ) -> Union[Tuple, CausalLMOutputWithPast]: ...
```
Import
```python
from tinychat.models.vila_llama import VilaLlamaForCausalLM
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | No | Token IDs for the input sequence |
| start_pos | int | No | Starting position for incremental decoding |
| attention_mask | torch.Tensor | No | Attention mask for the input |
| position_ids | torch.LongTensor | No | Position IDs for rotary embeddings |
| past_key_values | List[torch.FloatTensor] | No | KV cache from previous steps |
| inputs_embeds | torch.FloatTensor | No | Pre-computed input embeddings; skips multimodal fusion if provided |
| labels | torch.LongTensor | No | Target labels for training |
| images | torch.FloatTensor | No | Image tensors to encode via the vision tower |
| chunk_prefilling | bool | No | Enable chunk-based prefilling for long sequences (default False) |
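The interaction between start_pos and chunk_prefilling can be pictured as a schedule of slices: a long prefill is processed a slice at a time, and each slice starts at the number of positions already written to the KV cache by earlier slices. The helper below is a hypothetical illustration of that idea; the chunk size of 4 is arbitrary, and TinyChat's actual chunking policy is not reproduced here.

```python
def chunk_schedule(prefill_len, chunk_size):
    """Return (start_pos, end_pos) pairs covering [0, prefill_len) in order."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return [(start, min(start + chunk_size, prefill_len))
            for start in range(0, prefill_len, chunk_size)]


# A 10-position prefill split into chunks of 4: each chunk's start_pos is the
# number of positions already cached by the chunks before it.
print(chunk_schedule(10, 4))  # [(0, 4), (4, 8), (8, 10)]
```

Only one slice's intermediate activations need to exist at a time, which is where the memory saving for long prompts comes from; the KV cache itself still grows to the full prefill length.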
Outputs
| Name | Type | Description |
|---|---|---|
| outputs | Union[Tuple, CausalLMOutputWithPast] | Language model output with logits and optionally loss |
Usage Examples
Initializing and running VILA model
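```python
from tinychat.models.vila_llama import VilaLlamaForCausalLM

# `config` is assumed to be a VILA config that get_model_config can split into
# LLM, vision-tower, and projector sub-configs; `input_ids` holds the prompt
# tokens (including the image placeholder) and `image_tensor` the preprocessed image.
model = VilaLlamaForCausalLM(config)

# Prefill: fuse the image features into the prompt and process it in chunks
output = model.forward(
    input_ids=input_ids,
    start_pos=0,
    images=image_tensor,
    chunk_prefilling=True,
)
```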
Incremental decoding after prefill
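```python
# After the initial prefill, decode one token at a time.
# `current_pos` is the number of positions already written to the KV cache.
output = model.forward(
    input_ids=next_token,
    start_pos=current_pos,
)
```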
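The prefill and decode calls above imply a position-tracking rule: after prefill, start_pos should equal the length of the fused sequence already in the KV cache, which includes the image-feature positions and not just the text tokens, and it then advances by one per decoded token. The sketch below assumes that rule; `StubVLM` and `generate` are hypothetical, the stub only records the start_pos it receives and returns the number of positions it consumed instead of logits, and the image is assumed to expand to a fixed `image_len` positions.

```python
class StubVLM:
    """Hypothetical stand-in that records each forward call's start_pos."""

    def __init__(self, image_len):
        self.image_len = image_len  # positions one image expands to (assumed)
        self.calls = []

    def forward(self, input_ids, start_pos, images=None, chunk_prefilling=False):
        self.calls.append(start_pos)
        # Assumes at most one image placeholder, which expands to image_len positions.
        extra = self.image_len - 1 if images is not None else 0
        return len(input_ids) + extra  # positions consumed (a real model returns logits)


def generate(model, prompt_ids, image, max_new_tokens):
    # Prefill: consume the whole prompt (with image) starting at position 0.
    current_pos = model.forward(prompt_ids, start_pos=0, images=image, chunk_prefilling=True)
    next_token = [0]  # placeholder; a real loop would sample this from the logits
    for _ in range(max_new_tokens):
        model.forward(next_token, start_pos=current_pos)
        current_pos += 1  # each decode step appends one position to the cache
    return current_pos


model = StubVLM(image_len=4)
final_pos = generate(model, prompt_ids=[1, -200, 5, 6], image="img", max_new_tokens=3)
print(model.calls, final_pos)  # [0, 7, 8, 9] 10
```

Using `len(prompt_ids)` (4 here) as the first decode position would overwrite cached image positions, which is why the sketch takes the post-fusion length from the prefill step.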