Implementation:Mit han lab Llm awq NVILA LlavaArch
| Knowledge Sources | |
|---|---|
| Domains | Vision, Multimodal, Model_Architecture |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
Defines the NVILA-specific multimodal architecture base classes with advanced features including dynamic S2 multi-scale vision encoding, media token embedding fusion, streaming generation, and checkpoint save/load support.
Description
This module provides two foundational abstract base classes for the NVILA multimodal model family, significantly extending the original LLaVA architecture.
LlavaMetaModel(ABC) manages model initialization and the component lifecycle. The init_vlm method builds the tokenizer (via build_tokenizer), the vision tower, the multimodal projector, and the Hydra-instantiated image/video encoders. load_pretrained is a classmethod that constructs a VLM from a pretrained path or config inside the no_init_weights context, which skips random weight initialization, and then populates the LLM, vision tower, and projector from the pretrained weights. save_pretrained serializes each component (LLM, vision tower, projector) to its own subdirectory and updates the saved config accordingly.
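The component-wise layout that save_pretrained produces can be illustrated with a small pure-Python sketch. It handles configs only (no weights), and the subdirectory names (llm, vision_tower, mm_projector), the `*_cfg` keys in the top-level config, and the helpers save_components/load_components are assumptions for illustration, not necessarily what the real methods write.

```python
import json
import os
import tempfile
from typing import Any, Dict

# Hypothetical component names; the real save_pretrained may use different ones.
COMPONENTS = ("llm", "vision_tower", "mm_projector")


def save_components(output_dir: str, configs: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Write each component config to <output_dir>/<name>/config.json and a
    top-level config.json that records each component's subdirectory."""
    os.makedirs(output_dir, exist_ok=True)
    top_level: Dict[str, Any] = {}
    for name in COMPONENTS:
        sub_dir = os.path.join(output_dir, name)
        os.makedirs(sub_dir, exist_ok=True)
        with open(os.path.join(sub_dir, "config.json"), "w") as f:
            json.dump(configs[name], f)
        top_level[f"{name}_cfg"] = name  # relative path to the component
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(top_level, f)
    return top_level


def load_components(model_dir: str) -> Dict[str, Dict[str, Any]]:
    """Inverse of save_components: follow the top-level config to each component."""
    with open(os.path.join(model_dir, "config.json")) as f:
        top_level = json.load(f)
    loaded: Dict[str, Dict[str, Any]] = {}
    for name in COMPONENTS:
        with open(os.path.join(model_dir, top_level[f"{name}_cfg"], "config.json")) as f:
            loaded[name] = json.load(f)
    return loaded


# Round-trip three toy component configs through a temporary directory.
example = {
    "llm": {"hidden_size": 8},
    "vision_tower": {"patch_size": 14},
    "mm_projector": {"type": "mlp"},
}
with tempfile.TemporaryDirectory() as tmp:
    saved_top_level = save_components(tmp, example)
    restored = load_components(tmp)
```

Keeping each component in its own subdirectory lets a loader rebuild the LLM, vision tower, and projector independently, which matches the per-component loading that load_pretrained performs.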
The encode_images method supports two paths: standard encoding (vision tower followed by mm_projector) and dynamic S2 encoding. Dynamic S2 processes each image at multiple scales: it splits the image into chessboard sub-patches, encodes each sub-patch with the vision tower, merges the sub-patch features back with merge_chessboard, resizes every scale to a common target size by interpolation, concatenates the scales along the channel dimension, projects the result, and finally splits and merges again to produce the final feature sequence. The static helpers split_chessboard and merge_chessboard perform the spatial decomposition and recombination using einops rearrange operations, and merge_features_for_dynamic_s2 orchestrates the per-image multi-scale feature assembly.
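The tile bookkeeping behind split_chessboard and merge_chessboard can be sketched in NumPy, standing in for torch and einops. In this sketch tiles are stacked along the batch axis in row-major order, so merging is the exact inverse of splitting. The upsample_nearest and concat_scales helpers are hypothetical stand-ins for the resize-and-concatenate step: nearest-neighbour repetition replaces whatever interpolation mode the real code uses, and it assumes the target size is an integer multiple of each feature map's size.

```python
from typing import List

import numpy as np


def split_chessboard(x: np.ndarray, num_split_h: int, num_split_w: int) -> np.ndarray:
    """(B, C, H, W) -> (B * nh * nw, C, H // nh, W // nw); tile k = i * nw + j
    occupies batch rows [k * B, (k + 1) * B)."""
    _, _, h, w = x.shape
    assert h % num_split_h == 0 and w % num_split_w == 0
    th, tw = h // num_split_h, w // num_split_w
    tiles = [
        x[:, :, i * th:(i + 1) * th, j * tw:(j + 1) * tw]
        for i in range(num_split_h)
        for j in range(num_split_w)
    ]
    return np.concatenate(tiles, axis=0)


def merge_chessboard(x: np.ndarray, num_split_h: int, num_split_w: int) -> np.ndarray:
    """Inverse of split_chessboard: glue each row of tiles along the width,
    then stack the rows along the height."""
    n_tiles = num_split_h * num_split_w
    assert x.shape[0] % n_tiles == 0
    b = x.shape[0] // n_tiles
    rows = [
        np.concatenate(
            [x[(i * num_split_w + j) * b:(i * num_split_w + j + 1) * b] for j in range(num_split_w)],
            axis=-1,
        )
        for i in range(num_split_h)
    ]
    return np.concatenate(rows, axis=-2)


def upsample_nearest(x: np.ndarray, size: int) -> np.ndarray:
    """Nearest-neighbour resize of (B, C, h, w) to (B, C, size, size)."""
    _, _, h, w = x.shape
    assert size % h == 0 and size % w == 0
    return x.repeat(size // h, axis=2).repeat(size // w, axis=3)


def concat_scales(feature_maps: List[np.ndarray], size: int) -> np.ndarray:
    """Resize every scale to a common size and concatenate along channels."""
    return np.concatenate([upsample_nearest(f, size) for f in feature_maps], axis=1)
```

Stacking tiles on the batch axis lets the vision tower encode all sub-patches in one forward pass at its native resolution; the channel concatenation then gives each spatial location features from every scale.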
LlavaMetaForCausalLM(ABC) provides the inference and generation interface. The _embed method fuses text and media embeddings by: (1) extracting text embeddings via the LLM's embed_tokens, (2) encoding media through the registered encoders, (3) iterating through the token sequences to replace media token positions with their corresponding embeddings while marking the matching labels as IGNORE_INDEX, (4) truncating to model_max_length, and (5) padding the sequences into a batch. The private helper __embed_media_tokens dispatches each media type to the appropriate encoder, __truncate_sequence enforces the length limit, and __batchify_sequence pads the sequences with left or right alignment.
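The fusion steps of _embed can be illustrated with a NumPy sketch. It is not the method itself: embed_with_media is a hypothetical helper, IGNORE_INDEX = -100 is assumed (the usual Hugging Face convention), media embeddings are consumed in order across the batch, and a single media token id stands in for the per-type dispatch done by __embed_media_tokens.

```python
from typing import List, Tuple

import numpy as np

IGNORE_INDEX = -100  # assumed value; common Hugging Face convention


def embed_with_media(
    input_ids: List[List[int]],
    labels: List[List[int]],
    media_embeds: List[np.ndarray],
    token_embeds: np.ndarray,
    media_token_id: int,
    max_length: int,
    pad_left: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    media_iter = iter(media_embeds)  # media consumed in order of appearance
    dim = token_embeds.shape[1]
    seq_embeds, seq_labels = [], []
    for ids, labs in zip(input_ids, labels):
        embeds, out_labels = [], []
        for tok, lab in zip(ids, labs):
            if tok == media_token_id:
                feats = next(media_iter)  # (num_media_tokens, dim)
                embeds.append(feats)
                out_labels.extend([IGNORE_INDEX] * feats.shape[0])
            else:
                embeds.append(token_embeds[tok][None, :])  # text token lookup
                out_labels.append(lab)
        # Truncate each fused sequence to the model's maximum length.
        seq_embeds.append(np.concatenate(embeds, axis=0)[:max_length])
        seq_labels.append(np.array(out_labels[:max_length]))

    # Pad to the longest sequence; padding gets IGNORE_INDEX and mask False.
    batch_len = max(e.shape[0] for e in seq_embeds)
    out_embeds = np.zeros((len(seq_embeds), batch_len, dim), dtype=token_embeds.dtype)
    out_labels_arr = np.full((len(seq_embeds), batch_len), IGNORE_INDEX, dtype=np.int64)
    mask = np.zeros((len(seq_embeds), batch_len), dtype=bool)
    for k, (e, lab_arr) in enumerate(zip(seq_embeds, seq_labels)):
        n = e.shape[0]
        sl = slice(batch_len - n, batch_len) if pad_left else slice(0, n)
        out_embeds[k, sl] = e
        out_labels_arr[k, sl] = lab_arr
        mask[k, sl] = True
    return out_embeds, out_labels_arr, mask
```

Left padding is typically used for batched generation, so that every sequence ends at the same position and new tokens are appended directly after real content.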
The generate method performs inference-mode generation by embedding the inputs and delegating to the LLM's generate. generate_content provides a higher-level interface: it takes a text prompt, extracts and processes the media (images/videos), handles dynamic resolution and S2 processing, tokenizes the conversation, and decodes the output. benchmark times the vision encoding and LLM inference stages. prepare_media preprocesses the media contained in a conversation. stream_gen provides streaming generation with chunk prefilling support and expands image tokens for dynamic-resolution inputs when the Qwen tokenizer is used.
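Chunk prefilling, in general terms, feeds a long prompt to the model a few tokens at a time while tracking the absolute position of each chunk so the KV cache fills in order. The sketch below shows only that bookkeeping; chunked_prefill and its forward callback are hypothetical, and how stream_gen actually sizes and schedules its chunks is not described on this page.

```python
from typing import Callable

import numpy as np


def chunked_prefill(
    embeds: np.ndarray,
    chunk_size: int,
    forward: Callable[[np.ndarray, int], None],
    start_pos: int = 0,
) -> int:
    """Feed a (T, D) prompt embedding to `forward` in chunks of at most
    chunk_size tokens, passing each chunk's absolute start position.
    Returns the next free position (where decoding would continue)."""
    pos = start_pos
    for offset in range(0, embeds.shape[0], chunk_size):
        chunk = embeds[offset:offset + chunk_size]
        forward(chunk, pos)  # e.g. a model step that writes the KV cache at pos
        pos += chunk.shape[0]
    return pos
```

Bounding the chunk size caps the peak activation memory of the prefill phase, at the cost of more forward calls for long multimodal prompts.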
Usage
Import LlavaMetaModel and LlavaMetaForCausalLM as base classes when building NVILA-family multimodal models. They provide the complete vision-language integration layer that concrete model classes (e.g., NVILAQwen2) inherit from.
Code Reference
Source Location
- Repository: Mit_han_lab_Llm_awq
- File: tinychat/models/nvila/llava_arch.py
- Lines: 1-909
Signature
```python
from abc import ABC
from typing import Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedModel


class LlavaMetaModel(ABC):
    def init_vlm(self, config, *args, **kwargs): ...

    @classmethod
    def load_pretrained(cls, model_path_or_config, *args, **kwargs): ...

    def save_pretrained(self, output_dir, state_dict=None): ...
    def get_llm(self) -> PreTrainedModel: ...
    def get_lm_head(self) -> nn.Module: ...
    def get_vision_tower(self) -> nn.Module: ...
    def get_mm_projector(self) -> nn.Module: ...
    def post_config(self): ...

    @staticmethod
    def split_chessboard(x, num_split_h, num_split_w) -> torch.Tensor: ...

    @staticmethod
    def merge_chessboard(x, num_split_h, num_split_w) -> torch.Tensor: ...

    def merge_features_for_dynamic_s2(self, image_features, block_sizes): ...
    def encode_images(self, images, block_sizes=None) -> torch.Tensor: ...


class LlavaMetaForCausalLM(ABC):
    def _embed(
        self, input_ids, media, media_config, labels, attention_mask
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ...

    def generate(
        self, input_ids=None, media=None, media_config=None,
        attention_mask=None, quant_llm=True, **generation_kwargs
    ): ...

    def generate_content(self, prompt, generation_config=None, quant_llm=True) -> str: ...
    def benchmark(self, prompt, quant_llm) -> None: ...
    def prepare_media(self, conversation) -> Tuple[dict, dict]: ...

    def stream_gen(
        self, input_ids, media, media_cfg, start_pos,
        chunk_prefilling, quant_llm, attention_mask=None
    ) -> Tuple: ...
```
Import
```python
from tinychat.models.nvila.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Token IDs including media placeholder tokens |
| media | Dict[str, List[torch.Tensor]] | No | Dictionary mapping media type names ("image", "video") to lists of tensors |
| media_config | Dict[str, Dict[str, Any]] | No | Per-media-type configuration (e.g., block_sizes for dynamic S2) |
| labels | torch.Tensor | No | Target labels for training; media-token positions are filled with IGNORE_INDEX |
| attention_mask | torch.Tensor | No | Attention mask for the input sequence |
| images | torch.Tensor | Yes (encode_images) | Image tensors to encode through the vision tower |
| block_sizes | List[Optional[Tuple[int, ...]]] | No | Per-image tile block sizes for dynamic S2 processing |
| prompt | Union[str, List] | Yes (generate_content) | Text prompt optionally containing media tokens |
| quant_llm | bool | No | Whether to use quantized LLM forward pass (default True) |
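The expected shape of media and media_config can be made concrete with a small validation helper. check_media_inputs is hypothetical, and it assumes each block_sizes entry is either None or a two-element (rows, cols) tuple of positive integers; the real code may accept other forms.

```python
from typing import Any, Dict, List, Optional, Tuple


def check_media_inputs(
    media: Dict[str, List[Any]],
    media_config: Dict[str, Dict[str, Any]],
) -> None:
    """Hypothetical sanity check for the media / media_config pair."""
    # Every configured media type must actually be present in `media`.
    for media_type in media_config:
        if media_type not in media:
            raise KeyError(f"media_config refers to missing media type {media_type!r}")
    block_sizes: Optional[List[Optional[Tuple[int, ...]]]] = (
        media_config.get("image", {}).get("block_sizes")
    )
    if block_sizes is None:
        return
    # block_sizes is per-image, so it needs one entry per image tensor.
    if len(block_sizes) != len(media.get("image", [])):
        raise ValueError("block_sizes needs one entry per image")
    for bs in block_sizes:
        if bs is not None and (len(bs) != 2 or min(bs) < 1):
            raise ValueError(f"invalid block size {bs!r}")
```

Usage: `check_media_inputs({"image": [img_tensor]}, {"image": {"block_sizes": [(2, 2)]}})` passes silently, mirroring the media and media_config arguments in the _embed example on this page.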
Outputs
| Name | Type | Description |
|---|---|---|
| inputs_embeds | torch.Tensor | Fused text and media embeddings for the language model |
| labels | torch.Tensor | Aligned labels with IGNORE_INDEX at media positions |
| attention_mask | torch.Tensor | Updated attention mask after padding |
| image_features | torch.Tensor | Encoded and projected image features from encode_images |
| response | str | Decoded text response from generate_content |
Usage Examples
Embedding multimodal inputs
```python
# Within a model inheriting LlavaMetaForCausalLM
inputs_embeds, labels, attention_mask = self._embed(
    input_ids=input_ids,
    media={"image": [img_tensor]},
    media_config={"image": {"block_sizes": [(2, 2)]}},
    labels=labels,
    attention_mask=attention_mask,
)
```
Generating content from a prompt
```python
from transformers import GenerationConfig

response = model.generate_content(
    prompt="<image>\nDescribe this image in detail.",
    generation_config=GenerationConfig(max_new_tokens=256),
    quant_llm=True,
)
print(response)
```
Streaming generation
```python
output, length = model.stream_gen(
    input_ids=input_ids,
    media=media,
    media_cfg=media_config,
    start_pos=0,
    chunk_prefilling=True,
    quant_llm=True,
)
```