
Implementation:Mit han lab Llm awq NVILA LlavaArch

From Leeroopedia
Knowledge Sources
Domains Vision, Multimodal, Model_Architecture
Last Updated 2026-02-15 00:00 GMT

Overview

Defines the NVILA-specific multimodal architecture base classes with advanced features including dynamic S2 multi-scale vision encoding, media token embedding fusion, streaming generation, and checkpoint save/load support.

Description

This module provides two foundational abstract base classes for the NVILA multimodal model family, significantly extending the original LLaVA architecture.

LlavaMetaModel(ABC) manages model initialization and component lifecycle. The init_vlm method builds the tokenizer (via build_tokenizer), vision tower, multimodal projector, and Hydra-instantiated image/video encoders. load_pretrained is a classmethod that constructs a VLM from a pretrained path or config; it runs under the no_init_weights context to skip redundant weight initialization, then populates the LLM, vision tower, and projector from pretrained weights. save_pretrained serializes each component (LLM, vision tower, projector) to its own subdirectory and updates the config to match.
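The per-component save layout can be pictured with a small stand-alone sketch. It uses only the standard library and plain dicts in place of real model components. The subdirectory names (`llm`, `vision_tower`, `mm_projector`), the `<name>_cfg` keys, and the helper `save_components` are hypothetical stand-ins, not names confirmed by this page.

```python
import json
import tempfile
from pathlib import Path


def save_components(output_dir, components, top_config):
    """Write each component's config to its own subdirectory.

    `components` maps a (hypothetical) subdirectory name to a plain dict
    standing in for that component's config. The top-level config is
    updated to record where each component was written.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    updated = dict(top_config)
    for name, component_config in components.items():
        component_dir = output_dir / name
        component_dir.mkdir(exist_ok=True)
        (component_dir / "config.json").write_text(json.dumps(component_config))
        # Assumed convention: record the component location under "<name>_cfg".
        updated[f"{name}_cfg"] = name
    (output_dir / "config.json").write_text(json.dumps(updated))
    return updated


with tempfile.TemporaryDirectory() as tmp:
    saved = save_components(
        tmp,
        {
            "llm": {"hidden_size": 8},
            "vision_tower": {"image_size": 16},
            "mm_projector": {"type": "mlp"},
        },
        {"model_type": "toy_vlm"},
    )
    written = sorted(p.name for p in Path(tmp).iterdir())
```

Keeping each component in its own directory means any one of them can be reloaded or swapped without touching the others, which is presumably why the method is organized this way.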

The encode_images method supports two paths: standard encoding (vision tower followed by mm_projector) and dynamic S2 encoding. Dynamic S2 processes images at multiple scales in a fixed sequence: it splits each image into chessboard sub-patches, encodes each sub-patch through the vision tower, merges the features back via merge_chessboard, resizes them to a target scale via interpolation, concatenates the per-scale features along the channel dimension, projects the result, and then splits and merges again to produce the final feature sequence. The helper methods split_chessboard and merge_chessboard handle the spatial decomposition and recombination using einops rearrange operations. merge_features_for_dynamic_s2 orchestrates the per-image multi-scale feature assembly.
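To make the chessboard decomposition concrete, here is a minimal sketch on a plain 2D grid (a list of rows) rather than a batched tensor. The names `split_chessboard_2d` and `merge_chessboard_2d` are hypothetical, and the row-major tile order and list-based layout are assumptions for illustration, not the module's implementation.

```python
def split_chessboard_2d(grid, num_split_h, num_split_w):
    """Split an H x W grid into num_split_h * num_split_w tiles (row-major)."""
    height, width = len(grid), len(grid[0])
    assert height % num_split_h == 0 and width % num_split_w == 0
    tile_h, tile_w = height // num_split_h, width // num_split_w
    return [
        [row[j * tile_w:(j + 1) * tile_w] for row in grid[i * tile_h:(i + 1) * tile_h]]
        for i in range(num_split_h)
        for j in range(num_split_w)
    ]


def merge_chessboard_2d(tiles, num_split_h, num_split_w):
    """Inverse of split_chessboard_2d: stitch row-major tiles into one grid."""
    assert len(tiles) == num_split_h * num_split_w
    merged = []
    for i in range(num_split_h):
        # Tiles that sit side by side in the i-th band of the grid.
        band = tiles[i * num_split_w:(i + 1) * num_split_w]
        for r in range(len(band[0])):
            merged_row = []
            for tile in band:
                merged_row.extend(tile[r])
            merged.append(merged_row)
    return merged


grid = [[r * 4 + c for c in range(4)] for r in range(4)]
tiles = split_chessboard_2d(grid, 2, 2)
restored = merge_chessboard_2d(tiles, 2, 2)
```

Splitting a 4 x 4 grid into a 2 x 2 chessboard yields four 2 x 2 tiles, and merging them with the same split counts restores the original grid. In the tensor setting the tiles would typically be stacked along the batch dimension so the vision tower can encode them in one pass.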

LlavaMetaForCausalLM(ABC) provides the inference and generation interface. The _embed method fuses text and media embeddings by: (1) extracting text embeddings via the LLM's embed_tokens, (2) encoding media through registered encoders, (3) iterating through token sequences to replace media token positions with their corresponding embeddings while marking the labels at those positions as IGNORE_INDEX, (4) truncating to model_max_length, and (5) padding and batching the sequences. Three private helpers support this: __embed_media_tokens dispatches to the appropriate encoder for each media type, __truncate_sequence enforces length limits, and __batchify_sequence handles padding with left or right alignment.
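The fusion steps can be sketched with plain Python lists standing in for tensors. This is an illustration under stated assumptions, not the module's code: `MEDIA_TOKEN_ID`, `PAD_EMBED`, `toy_embed`, and the helpers `fuse`, `truncate`, and `batchify` are hypothetical, and `IGNORE_INDEX = -100` is an assumed value.

```python
IGNORE_INDEX = -100     # assumed value; a common "ignore this label" convention
MEDIA_TOKEN_ID = 999    # hypothetical placeholder id for a media token
PAD_EMBED = [0.0, 0.0]  # hypothetical padding embedding


def fuse(input_ids, labels, text_embed, media_embeds):
    """Replace each media placeholder with its embedding rows; mask its labels."""
    media_iter = iter(media_embeds)
    out_embeds, out_labels = [], []
    for token_id, label in zip(input_ids, labels):
        if token_id == MEDIA_TOKEN_ID:
            rows = next(media_iter)
            out_embeds.extend(rows)
            out_labels.extend([IGNORE_INDEX] * len(rows))
        else:
            out_embeds.append(text_embed(token_id))
            out_labels.append(label)
    return out_embeds, out_labels


def truncate(embeds, labels, max_length):
    """Keep at most max_length positions."""
    return embeds[:max_length], labels[:max_length]


def batchify(sequences, padding_side="right"):
    """Pad (embeds, labels) pairs to a common length and build attention masks."""
    max_len = max(len(embeds) for embeds, _ in sequences)
    batch = []
    for embeds, labels in sequences:
        pad = max_len - len(embeds)
        pad_embeds, pad_labels, pad_mask = [PAD_EMBED] * pad, [IGNORE_INDEX] * pad, [0] * pad
        mask = [1] * len(embeds)
        if padding_side == "right":
            batch.append((embeds + pad_embeds, labels + pad_labels, mask + pad_mask))
        else:
            batch.append((pad_embeds + embeds, pad_labels + labels, pad_mask + mask))
    return batch


def toy_embed(token_id):
    # Toy embedding: a 2-dimensional vector derived from the id.
    return [float(token_id), 0.5]


image_rows = [[9.0, 9.0], [8.0, 8.0], [7.0, 7.0]]  # one image expands to three rows
seq_a = truncate(*fuse([1, MEDIA_TOKEN_ID, 2], [1, MEDIA_TOKEN_ID, 2], toy_embed, [image_rows]), max_length=4)
seq_b = truncate(*fuse([3, 4], [3, 4], toy_embed, []), max_length=4)
batch = batchify([seq_a, seq_b], padding_side="left")
```

Note that a single media token expands into several embedding rows, so the fused sequence is longer than the token sequence. That is why truncation has to happen after fusion rather than before it.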

The generate method performs inference-mode generation by embedding inputs and delegating to the LLM's generate. generate_content provides a higher-level interface that takes a text prompt, extracts and processes media (images/videos), handles dynamic resolution and S2 processing, tokenizes the conversation, and decodes the output. benchmark runs timed profiling of the vision encoding and LLM inference. prepare_media pre-processes media from conversations. stream_gen enables streaming generation with chunk prefilling support, handling dynamic resolution image token expansion for the Qwen tokenizer.
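The general idea behind chunk prefilling can be sketched as follows: the prompt is fed to the model in fixed-size slices while a running position offset tracks how much of the cache is already filled. How stream_gen actually chunks its input is not described on this page, so `prefill_chunks` and `chunk_size` are hypothetical, and `start_pos` is only assumed to play the role of the cache offset.

```python
def prefill_chunks(num_tokens, chunk_size, start_pos=0):
    """Yield (start, end) position ranges covering num_tokens new tokens.

    `start_pos` is the number of positions already held in the cache, so
    the first chunk begins there rather than at zero.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    position = start_pos
    final = start_pos + num_tokens
    while position < final:
        end = min(position + chunk_size, final)
        yield position, end
        position = end


# Ten new tokens, fed four at a time, on top of three cached positions.
ranges = list(prefill_chunks(num_tokens=10, chunk_size=4, start_pos=3))
```

Feeding the prompt in slices bounds the peak activation memory of the prefill step, at the cost of several forward passes instead of one.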

Usage

Import LlavaMetaModel and LlavaMetaForCausalLM as base classes when building NVILA-family multimodal models. They provide the complete vision-language integration layer that concrete model classes (e.g., NVILAQwen2) inherit from.

Code Reference

Source Location

Signature

class LlavaMetaModel(ABC):
    def init_vlm(self, config, *args, **kwargs): ...
    @classmethod
    def load_pretrained(cls, model_path_or_config, *args, **kwargs): ...
    def save_pretrained(self, output_dir, state_dict=None): ...
    def get_llm(self) -> PreTrainedModel: ...
    def get_lm_head(self) -> nn.Module: ...
    def get_vision_tower(self) -> nn.Module: ...
    def get_mm_projector(self) -> nn.Module: ...
    def post_config(self): ...
    @staticmethod
    def split_chessboard(x, num_split_h, num_split_w) -> torch.Tensor: ...
    @staticmethod
    def merge_chessboard(x, num_split_h, num_split_w) -> torch.Tensor: ...
    def merge_features_for_dynamic_s2(self, image_features, block_sizes): ...
    def encode_images(self, images, block_sizes=None) -> torch.Tensor: ...

class LlavaMetaForCausalLM(ABC):
    def _embed(
        self, input_ids, media, media_config, labels, attention_mask
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ...
    def generate(self, input_ids=None, media=None, media_config=None,
        attention_mask=None, quant_llm=True, **generation_kwargs): ...
    def generate_content(self, prompt, generation_config=None,
        quant_llm=True) -> str: ...
    def benchmark(self, prompt, quant_llm) -> None: ...
    def prepare_media(self, conversation) -> Tuple[dict, dict]: ...
    def stream_gen(self, input_ids, media, media_cfg, start_pos,
        chunk_prefilling, quant_llm, attention_mask=None) -> Tuple: ...

Import

from tinychat.models.nvila.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM

I/O Contract

Inputs

| Name | Type | Required | Description |
| --- | --- | --- | --- |
| input_ids | torch.Tensor | Yes | Token IDs including media placeholder tokens |
| media | Dict[str, List[torch.Tensor]] | No | Dictionary mapping media type names ("image", "video") to lists of tensors |
| media_config | Dict[str, Dict[str, Any]] | No | Per-media-type configuration (e.g., block_sizes for dynamic S2) |
| labels | torch.Tensor | No | Target labels for training; image positions filled with IGNORE_INDEX |
| attention_mask | torch.Tensor | No | Attention mask for the input sequence |
| images | torch.Tensor | Yes (encode_images) | Image tensors to encode through the vision tower |
| block_sizes | List[Optional[Tuple[int, ...]]] | No | Per-image tile block sizes for dynamic S2 processing |
| prompt | Union[str, List] | Yes (generate_content) | Text prompt optionally containing media tokens |
| quant_llm | bool | No | Whether to use quantized LLM forward pass (default True) |

Outputs

| Name | Type | Description |
| --- | --- | --- |
| inputs_embeds | torch.Tensor | Fused text and media embeddings for the language model |
| labels | torch.Tensor | Aligned labels with IGNORE_INDEX at media positions |
| attention_mask | torch.Tensor | Updated attention mask after padding |
| image_features | torch.Tensor | Encoded and projected image features from encode_images |
| response | str | Decoded text response from generate_content |

Usage Examples

Embedding multimodal inputs

# Within a model inheriting LlavaMetaForCausalLM
inputs_embeds, labels, attention_mask = self._embed(
    input_ids=input_ids,
    media={"image": [img_tensor]},
    media_config={"image": {"block_sizes": [(2, 2)]}},
    labels=labels,
    attention_mask=attention_mask,
)

Generating content from a prompt

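from transformers import GenerationConfig

# `model` is an already-loaded model that inherits LlavaMetaForCausalLM.
response = model.generate_content(
    prompt="<image>\nDescribe this image in detail.",
    generation_config=GenerationConfig(max_new_tokens=256),
    quant_llm=True,
)
print(response)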

Streaming generation

output, length = model.stream_gen(
    input_ids=input_ids,
    media=media,
    media_cfg=media_config,
    start_pos=0,
    chunk_prefilling=True,
    quant_llm=True,
)
