Implementation:Mit han lab Llm awq NVILAQwen2

From Leeroopedia
Knowledge Sources
Domains NLP, Multimodal
Last Updated 2026-02-15 00:00 GMT

Overview

Implements the NVILA multimodal model using Qwen2 as the language model backbone, combining NVILA's vision-language architecture with Qwen2's causal language model.

Description

This module defines LlavaLlamaConfig (extending LlavaConfig with model_type "llava_llama") and the main NVILAQwen2 class, which inherits from LlavaMetaModel, LlavaMetaForCausalLM, and PreTrainedModel.

NVILAQwen2 initializes by calling init_vlm (from LlavaMetaModel) to set up the vision tower, projector, tokenizer, and encoders, and then loads a Qwen2ForCausalLM as the language backbone. The LLM is loaded from pretrained weights, moved to CPU for initial setup, and has its token embeddings resized to accommodate the media tokens added by the tokenizer. The llm constructor parameter controls whether the LLM is loaded at all, which is useful for vision-only configurations.

At import time, the module also skips weight initialization by patching the torch.nn.init functions and setting modeling_utils._init_weights to False, which speeds up model construction.
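The import-time patch follows a common monkeypatching pattern: PyTorch layers such as nn.Linear look up functions like kaiming_uniform_ on the torch.nn.init module when they reset their parameters, so replacing those attributes with no-ops skips the work. Below is a minimal sketch of that pattern, assuming a SimpleNamespace stand-in for torch.nn.init so it runs without PyTorch. The helper name skip_init is hypothetical, and the modeling_utils._init_weights flag is not modeled.

```python
import contextlib
from types import SimpleNamespace

# Stand-in for torch.nn.init so the sketch runs without PyTorch (assumption).
# Each "real" initializer just reports that it ran.
fake_init = SimpleNamespace(
    kaiming_uniform_=lambda tensor, *args, **kwargs: "initialized",
    uniform_=lambda tensor, *args, **kwargs: "initialized",
    normal_=lambda tensor, *args, **kwargs: "initialized",
)


def _noop_init(tensor, *args, **kwargs):
    # Mirror the in-place init API: return the tensor untouched.
    return tensor


@contextlib.contextmanager
def skip_init(init_module, names=("kaiming_uniform_", "uniform_", "normal_")):
    """Temporarily replace the named init functions with no-ops (hypothetical helper)."""
    saved = {name: getattr(init_module, name) for name in names}
    try:
        for name in names:
            setattr(init_module, name, _noop_init)
        yield
    finally:
        # Restore the original initializers even if construction fails.
        for name, fn in saved.items():
            setattr(init_module, name, fn)


with skip_init(fake_init):
    print(fake_init.normal_("weights"))  # prints weights: init skipped
print(fake_init.normal_("weights"))  # prints initialized: original restored
```

The module described here applies its patch at import time; this sketch instead scopes the patch to a with block and restores the originals afterward, so code that later needs real initialization is unaffected.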

from_pretrained overrides the standard Hugging Face loading path and delegates to load_pretrained (from LlavaMetaModel) when it is available. load_pretrained handles the multi-component checkpoint layout, in which the LLM, vision tower, and projector are each loaded from a separate subdirectory.
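The delegation pattern can be sketched in plain Python. In this sketch HFStyleBase stands in for PreTrainedModel, all class names are hypothetical, and the subdirectory names llm, vision_tower, and mm_projector are assumptions about the checkpoint layout rather than a confirmed specification.

```python
from pathlib import Path


class HFStyleBase:
    """Stand-in for PreTrainedModel's single-directory loader (hypothetical)."""

    @classmethod
    def from_pretrained(cls, path, *model_args, config=None, **kwargs):
        return {"loader": "standard", "path": str(path)}


class DelegatingLoader(HFStyleBase):
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        config=None, **kwargs):
        # Prefer a multi-component loader if any class in the MRO provides one.
        if hasattr(cls, "load_pretrained"):
            return cls.load_pretrained(pretrained_model_name_or_path,
                                       config=config, **kwargs)
        # Otherwise fall back to the standard single-directory path.
        return super().from_pretrained(pretrained_model_name_or_path,
                                       *model_args, config=config, **kwargs)


class MultiComponentMixin:
    # Assumed subdirectory names; the real checkpoint layout may differ.
    COMPONENTS = ("llm", "vision_tower", "mm_projector")

    @classmethod
    def load_pretrained(cls, path, config=None, **kwargs):
        root = Path(path)
        parts = {name: root / name for name in cls.COMPONENTS}
        missing = [name for name, sub in parts.items() if not sub.is_dir()]
        if missing:
            raise FileNotFoundError(f"missing component directories: {missing}")
        # A real loader would build each component from its own directory here.
        return {"loader": "multi-component",
                "components": {name: str(sub) for name, sub in parts.items()}}


class PlainModel(DelegatingLoader):
    """No load_pretrained: uses the standard path."""


class SketchVLM(MultiComponentMixin, DelegatingLoader):
    """Has load_pretrained: uses the multi-component path."""
```

SketchVLM.from_pretrained(path) on a directory containing the three subdirectories returns the multi-component result, while PlainModel.from_pretrained(path) falls through to the base loader. Checking with hasattr keeps the override harmless for subclasses that do not mix in the multi-component loader.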

The forward method implements the complete multimodal forward pass. It accepts both media (dict-based, for NVILA-style inputs) and images (tensor-based, for backward compatibility), converting images to media format. It calls _embed to fuse text and media embeddings, then passes the result through the Qwen2 LLM with standard causal LM arguments. When dpo_forward is True, it returns raw logits and labels for Direct Preference Optimization training.
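The argument handling described above can be sketched with stub callables in place of _embed and the Qwen2 LLM. All names here are hypothetical. Rejecting calls that pass both media and images is a choice made by this sketch, and the attention_mask, position_ids, and past_key_values arguments that the real method threads through are omitted.

```python
from typing import Any, Callable, Dict, List, Optional

Media = Dict[str, List[Any]]


def normalize_media(media: Optional[Media], images: Optional[List[Any]]) -> Optional[Media]:
    """Fold the legacy `images` argument into the `media` dict (hypothetical helper)."""
    if images is not None:
        if media is not None:
            # Sketch choice: refuse ambiguous calls instead of merging.
            raise ValueError("pass either `media` or `images`, not both")
        media = {"image": list(images)}
    return media


def forward_sketch(embed: Callable, llm: Callable, input_ids=None, media=None,
                   images=None, inputs_embeds=None, labels=None, dpo_forward=False):
    media = normalize_media(media, images)
    if inputs_embeds is None:
        # `embed` stands in for _embed: fuse text and media embeddings and
        # return labels aligned with the fused sequence.
        inputs_embeds, labels = embed(input_ids, media, labels)
    outputs = llm(inputs_embeds=inputs_embeds, labels=labels)
    if dpo_forward:
        # DPO path: hand back raw logits and labels for the trainer.
        return outputs["logits"], labels
    return outputs


# Stub components so the sketch runs end to end.
def fake_embed(input_ids, media, labels):
    n_images = len(media["image"]) if media else 0
    return [f"tok{i}" for i in input_ids] + ["img"] * n_images, labels


def fake_llm(inputs_embeds, labels):
    return {"logits": [0.0] * len(inputs_embeds), "loss": 1.0}


out = forward_sketch(fake_embed, fake_llm, input_ids=[1, 2], images=["pixels"])
print(len(out["logits"]))  # 3: two text tokens plus one image placeholder
```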

Usage

Import NVILAQwen2 as the primary model class for NVILA deployments that use Qwen2 as the language backbone. It is typically loaded with from_pretrained from a checkpoint directory.

Code Reference

Source Location

Signature

class LlavaLlamaConfig(LlavaConfig):
    model_type = "llava_llama"

class NVILAQwen2(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel):
    config_class = LlavaLlamaConfig
    main_input_name = "input_embeds"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True

    def __init__(self, config: LlavaLlamaConfig = None, llm=True,
                 *args, **kwargs) -> None: ...
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        config=None, **kwargs): ...
    def forward(
        self, input_ids=None, media=None, images=None,
        media_config=None, attention_mask=None, position_ids=None,
        past_key_values=None, inputs_embeds=None, labels=None,
        packing=True, seqlens_in_batch=None, dpo_forward=False,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]: ...

Import

from tinychat.models.nvila_qwen2 import NVILAQwen2, LlavaLlamaConfig

I/O Contract

Inputs

Name | Type | Required | Description
input_ids | torch.LongTensor | No | Token IDs for the input sequence
media | Dict[str, List[torch.Tensor]] | No | Dictionary mapping media types to tensor lists (e.g., {"image": [tensor]})
images | torch.FloatTensor | No | Image tensors (backward compatible); converted to a media dict internally
media_config | Dict[str, Dict] | No | Per-media configuration (e.g., block_sizes for S2)
attention_mask | torch.Tensor | No | Attention mask
position_ids | torch.LongTensor | No | Position IDs for the sequence
past_key_values | List[torch.FloatTensor] | No | KV cache for incremental generation
inputs_embeds | torch.FloatTensor | No | Pre-computed embeddings; _embed is skipped if provided
labels | torch.LongTensor | No | Target labels for loss computation
dpo_forward | bool | No | If True, returns (logits, labels) for DPO training

Outputs

Name | Type | Description
outputs | Union[Tuple, CausalLMOutputWithPast] | Standard causal LM output with loss, logits, and hidden states; or (logits, labels) if dpo_forward=True

Usage Examples

Loading and running inference

from tinychat.models.nvila_qwen2 import NVILAQwen2

model = NVILAQwen2.from_pretrained("path/to/nvila-qwen2-checkpoint")

# Generate a response
response = model.generate_content(
    prompt="<image>\nWhat is shown in this image?",
    quant_llm=True,
)

Forward pass with media dict

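# input_ids, attention_mask, and labels come from the tokenized prompt;
# image_tensor is a preprocessed image (preparation not shown here)
outputs = model(
    input_ids=input_ids,
    media={"image": [image_tensor]},
    labels=labels,
    attention_mask=attention_mask,
)
loss = outputs.loss  # causal LM loss, computed because labels were passed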

DPO training forward

# Returns raw logits and labels instead of a CausalLMOutputWithPast
logits, labels = model(
    input_ids=input_ids,
    media={"image": [image_tensor]},
    labels=labels,
    dpo_forward=True,
)
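A DPO trainer typically turns the (logits, labels) pair into per-sequence log-probabilities and then into a preference loss. Below is a minimal pure-Python sketch for a single sequence, assuming nested lists instead of (batch, seq_len, vocab) tensors, -100 as the ignore index for masked label positions, and the usual causal shift in which the logits at position t score the label at t + 1. The function names and the beta default are illustrative, not part of this module.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100  # assumed ignore index for masked label positions


def log_softmax(row: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum for x in row]


def sequence_logprob(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Sum of log p(labels[t + 1] | logits[t]) over non-ignored positions."""
    total = 0.0
    for t in range(len(labels) - 1):
        target = labels[t + 1]
        if target == IGNORE_INDEX:
            continue
        total += log_softmax(logits[t])[target]
    return total


def dpo_loss(policy_chosen: float, policy_rejected: float,
             ref_chosen: float, ref_rejected: float, beta: float = 0.1) -> float:
    """-log(sigmoid(beta * margin)), computed in an overflow-safe way."""
    margin = beta * ((policy_chosen - ref_chosen) - (policy_rejected - ref_rejected))
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


# Uniform logits over a 2-token vocabulary: each scored label has p = 0.5.
logits = [[0.0, 0.0]] * 3
lp = sequence_logprob(logits, [IGNORE_INDEX, 0, 1])
print(lp)  # about -1.386, i.e. 2 * log(0.5)
```

The loss falls below log 2 when the policy prefers the chosen response more strongly than the reference model does, and rises above it otherwise.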
