Implementation:Mit han lab Llm awq NVILAQwen2
| Knowledge Sources | |
|---|---|
| Domains | NLP, Multimodal |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
Implements the NVILA multimodal model using Qwen2 as the language model backbone, combining NVILA's vision-language architecture with Qwen2's causal language model.
Description
This module defines LlavaLlamaConfig (extending LlavaConfig with model_type "llava_llama") and the main NVILAQwen2 class, which inherits from LlavaMetaModel, LlavaMetaForCausalLM, and PreTrainedModel.
NVILAQwen2 initializes by calling init_vlm (from LlavaMetaModel) to set up the vision tower, projector, tokenizer, and encoders, then loads a Qwen2ForCausalLM as the language backbone. The LLM is loaded from pretrained weights, moved to CPU for initial setup, and has its token embeddings resized to accommodate the media tokens added by the tokenizer. The llm constructor parameter controls whether the LLM is loaded at all, which is useful for vision-only configurations. At import time, the module skips weight initialization by patching torch.nn.init functions and setting modeling_utils._init_weights to False, which speeds up model construction because the weights are overwritten by the checkpoint anyway.
from_pretrained overrides the standard HuggingFace loading to delegate to load_pretrained (from LlavaMetaModel) when available, which handles the multi-component loading pattern (LLM, vision tower, projector each from separate subdirectories).
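The delegation pattern can be sketched with toy classes, under the assumption that the override simply checks for load_pretrained with hasattr. All class names below are hypothetical stand-ins, and the subdirectory names (llm, vision_tower, mm_projector) are assumed for illustration rather than taken from the source.

```python
class BaseLoader:
    """Stand-in for PreTrainedModel's single-directory loading."""

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        return ("base", path)


class MultiComponentMixin:
    """Stand-in for LlavaMetaModel's multi-component loader."""

    @classmethod
    def load_pretrained(cls, path, **kwargs):
        # Each component lives in its own subdirectory (names assumed).
        parts = ("llm", "vision_tower", "mm_projector")
        return ("multi", {name: f"{path}/{name}" for name in parts})


class DelegatingLoader:
    """Override that prefers load_pretrained when the class provides it."""

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        if hasattr(cls, "load_pretrained"):
            return cls.load_pretrained(path, **kwargs)
        return super().from_pretrained(path, **kwargs)


class MultimodalModel(DelegatingLoader, MultiComponentMixin, BaseLoader):
    pass


class PlainModel(DelegatingLoader, BaseLoader):
    pass


multi_kind, multi_paths = MultimodalModel.from_pretrained("ckpt")
plain_result = PlainModel.from_pretrained("ckpt")
```

MultimodalModel takes the multi-component path, while PlainModel falls through to the base loader because it has no load_pretrained attribute.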
The forward method implements the complete multimodal forward pass. It accepts both media (dict-based, for NVILA-style inputs) and images (tensor-based, for backward compatibility), converting images to media format. It calls _embed to fuse text and media embeddings, then passes the result through the Qwen2 LLM with standard causal LM arguments. When dpo_forward is True, it returns raw logits and labels for Direct Preference Optimization training.
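The backward-compatible handling of images can be sketched as a small normalization step. The helper name to_media is hypothetical, and raising an error when both arguments are supplied is an assumption about the intended behavior; plain lists stand in for image tensors so the sketch runs without PyTorch.

```python
def to_media(media=None, images=None):
    """Normalize legacy `images` input into the dict-based `media` format."""
    if images is not None:
        if media is not None:
            # Assumption: supplying both forms is treated as a caller error.
            raise ValueError("Pass either 'media' or 'images', not both.")
        # Iterating over a batch yields one entry per image.
        media = {"image": [image for image in images]}
    return media


from_images = to_media(images=["img0", "img1"])
passthrough = to_media(media={"image": ["img0"]})
nothing = to_media()
```

After this step the rest of the forward pass only needs to understand the media dict, which keeps the embedding logic free of legacy branches.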
Usage
Import NVILAQwen2 as the primary model class for NVILA deployments that use Qwen2 as the language backbone. It is typically loaded via from_pretrained from a checkpoint directory.
Code Reference
Source Location
- Repository: Mit_han_lab_Llm_awq
- File: tinychat/models/nvila_qwen2.py
- Lines: 1-157
Signature
class LlavaLlamaConfig(LlavaConfig):
    model_type = "llava_llama"

class NVILAQwen2(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel):
    config_class = LlavaLlamaConfig
    main_input_name = "input_embeds"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True

    def __init__(self, config: LlavaLlamaConfig = None, llm=True,
                 *args, **kwargs) -> None: ...

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        config=None, **kwargs): ...

    def forward(
        self, input_ids=None, media=None, images=None,
        media_config=None, attention_mask=None, position_ids=None,
        past_key_values=None, inputs_embeds=None, labels=None,
        packing=True, seqlens_in_batch=None, dpo_forward=False,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]: ...
Import
from tinychat.models.nvila_qwen2 import NVILAQwen2, LlavaLlamaConfig
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | No | Token IDs for the input sequence |
| media | Dict[str, List[torch.Tensor]] | No | Dictionary mapping media types to tensor lists (e.g., {"image": [tensor]}) |
| images | torch.FloatTensor | No | Image tensors (backward compatible); converted to media dict internally |
| media_config | Dict[str, Dict] | No | Per-media configuration (e.g., block_sizes for S2) |
| attention_mask | torch.Tensor | No | Attention mask |
| position_ids | torch.LongTensor | No | Position IDs for the sequence |
| past_key_values | List[torch.FloatTensor] | No | KV cache for incremental generation |
| inputs_embeds | torch.FloatTensor | No | Pre-computed embeddings; skips _embed if provided |
| labels | torch.LongTensor | No | Target labels for loss computation |
| dpo_forward | bool | No | If True, returns (logits, labels) for DPO training |
Outputs
| Name | Type | Description |
|---|---|---|
| outputs | Union[Tuple, CausalLMOutputWithPast] | Standard causal LM output with loss, logits, hidden states; or (logits, labels) if dpo_forward=True |
Usage Examples
Loading and running inference
from tinychat.models.nvila_qwen2 import NVILAQwen2

model = NVILAQwen2.from_pretrained("path/to/nvila-qwen2-checkpoint")

# Generate a response
response = model.generate_content(
    prompt="<image>\nWhat is shown in this image?",
    quant_llm=True,
)
Forward pass with media dict
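# input_ids, image_tensor, labels, and attention_mask are assumed to be prepared beforehand
outputs = model.forward(
    input_ids=input_ids,
    media={"image": [image_tensor]},
    labels=labels,
    attention_mask=attention_mask,
)
# With labels supplied, the returned CausalLMOutputWithPast carries the loss
loss = outputs.loss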
DPO training forward
# dpo_forward=True returns a (logits, labels) tuple instead of CausalLMOutputWithPast
logits, labels = model.forward(
    input_ids=input_ids,
    media={"image": [image_tensor]},
    labels=labels,
    dpo_forward=True,
)
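As a rough illustration of what a DPO trainer might do with the returned pair, the sketch below sums the log-probabilities of the target tokens for one sequence. It uses nested lists instead of tensors so it runs without PyTorch, and it rests on two assumptions not stated in the source: that labels are not pre-shifted (position t predicts token t+1) and that -100 marks ignored positions. The function names are hypothetical.

```python
import math

IGNORE_INDEX = -100  # assumed ignore marker, as commonly used for causal LM labels


def log_softmax(row):
    """Numerically stable log-softmax over one vocabulary-sized row."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def sequence_logprob(logits, labels):
    """Sum of log p(label[t+1] | prefix up to t) over non-ignored targets."""
    total = 0.0
    # Shift by one: the logits at position t score the token at position t+1.
    for step_logits, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue
        total += log_softmax(step_logits)[target]
    return total


# Three positions over a two-token vocabulary with uniform logits.
toy_logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
both_targets = sequence_logprob(toy_logits, [IGNORE_INDEX, 1, 0])
one_target = sequence_logprob(toy_logits, [IGNORE_INDEX, IGNORE_INDEX, 1])
```

In a DPO setup this quantity would typically be computed for the chosen and rejected responses under both the policy and a reference model; that comparison is outside the scope of this module.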