Implementation:OpenGVLab InternVL QLLaMA Model

From Leeroopedia


Knowledge Sources
Domains: Language Modeling, Cross-Attention, Vision-Language Model
Last Updated: 2026-02-07 14:00 GMT

Overview

This module implements QLLaMA, a modified LLaMA model whose interleaved cross-attention layers let query tokens attend to visual features from a vision encoder. It serves as the language component of the InternVL-14B vision-language bridge.

Description

The QLLaMA implementation extends the standard LLaMA architecture with cross-attention capabilities for multimodal processing:

Core components:

  • LlamaRMSNorm: Root Mean Square normalization, automatically swapped for apex FusedRMSNorm when apex is available.
  • FixedLlamaRotaryEmbedding (aliased as LlamaRotaryEmbedding): Position encoding using rotary embeddings with float32 precision for numerical stability, computed via outer product rather than einsum.
  • LlamaMLP: Standard gated MLP with gate_proj, up_proj, and down_proj linear layers.
  • LlamaAttention: Standard multi-head self-attention with rotary position embeddings, supporting KV caching and attention mask handling.
  • LlamaCrossAttention: The key addition to standard LLaMA. Queries come from text/query tokens (projected through q_proj, then RMSNorm), while keys and values come from vision hidden states (projected through k_proj/v_proj, with a separate RMSNorm). The vision hidden dimension is hardcoded at 3200 (matching InternViT-6B). Supports repeat_time for broadcasting vision features across batch elements.
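
As a rough illustration of two of the components listed above, the following pure-Python sketch spells out the RMS normalization formula and the outer-product construction of rotary angles. It is not the module's code: the function names rms_norm and rotary_angles, the eps value, and the base of 10000 are assumptions chosen for illustration, and plain lists stand in for tensors.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Scale x by the reciprocal of its root mean square, then by a weight.

    eps is a hypothetical default; the real value comes from the model config.
    """
    mean_square = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_square + eps)
    return [w * v * scale for v, w in zip(x, weight)]


def rotary_angles(seq_len, head_dim, base=10000.0):
    """Return angles[t][i] = t * inv_freq[i].

    This is the outer product of positions and inverse frequencies.
    base=10000 is an assumed default, not read from the source file.
    """
    inv_freq = [base ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]
    return [[t * f for f in inv_freq] for t in range(seq_len)]


normed = rms_norm([3.0, 4.0], weight=[1.0, 1.0])
angles = rotary_angles(seq_len=4, head_dim=8)
```

After normalization the mean of the squared entries is close to 1, and each row of angles holds one angle per pair of head dimensions; the cosine and sine of these angles are what a rotary embedding applies to queries and keys.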

Architecture:

  • LlamaDecoderLayer: Combines self-attention, optional cross-attention (controlled by use_cross_attn flag), and MLP with residual connections. Cross-attention is applied only when the sequence length is at least num_query_token (96) and vision_hidden_states is provided, enabling cache-mode generation to skip cross-attention.
  • LlamaModel: Stacks decoder layers with cross-attention at regular intervals (every cross_attention_frequency layers). Provides both forward (standard with causal mask preparation) and forward_train (for training with pre-computed attention masks) methods. The forward method supports use_zero_attention_mask to zero out attention within query token positions.
  • LlamaForCausalLM: Adds a linear language model head for next-token prediction with cross-entropy loss.
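
The layer schedule and the gating rule described above can be sketched in a few lines. This is a minimal illustration under stated assumptions: the helper names cross_attention_flags and applies_cross_attention are hypothetical, and placing cross-attention on layers whose index is a multiple of cross_attention_frequency is an assumption about the offset, which this page does not specify.

```python
def cross_attention_flags(num_hidden_layers, cross_attention_frequency):
    """One boolean per decoder layer: True where cross-attention is attached.

    Assumption: layers whose index is a multiple of the frequency get one;
    the real module may use a different offset.
    """
    return [idx % cross_attention_frequency == 0
            for idx in range(num_hidden_layers)]


def applies_cross_attention(use_cross_attn, seq_len, has_vision_states,
                            num_query_token=96):
    """Gate from the description: flag set, enough tokens, vision given."""
    return use_cross_attn and seq_len >= num_query_token and has_vision_states


flags = cross_attention_flags(num_hidden_layers=32, cross_attention_frequency=4)

# Full pass over 96 query tokens plus 20 text tokens: the gate is open.
prefill = applies_cross_attention(flags[0], seq_len=96 + 20,
                                  has_vision_states=True)

# Cached generation step that feeds a single new token: the gate is closed.
cached_step = applies_cross_attention(flags[0], seq_len=1,
                                      has_vision_states=True)
```

With 32 layers and a frequency of 4, this schedule yields 8 cross-attention layers. The sequence-length test is what lets a cached decoding step, which carries fewer than num_query_token tokens, bypass cross-attention.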

Usage

Use QLLaMA as the language model component within the InternVL-14B architecture, where it processes learnable query tokens that attend to InternViT visual features through cross-attention layers to produce vision-language aligned representations.

Code Reference

Source Location

Signature

class LlamaModel(LlamaPreTrainedModel):
    def __init__(self, config: LlamaConfig):
        ...
    def forward(self, input_ids=None, attention_mask=None, position_ids=None,
                past_key_values=None, inputs_embeds=None,
                vision_hidden_states=None, repeat_time=1,
                use_cache=None, output_attentions=None,
                output_hidden_states=None, use_zero_attention_mask=None,
                return_dict=None):
        ...
    def forward_train(self, input_ids=None, attention_mask=None,
                      position_ids=None, past_key_values=None,
                      inputs_embeds=None, vision_hidden_states=None,
                      repeat_time=1, ...):
        ...

class LlamaForCausalLM(LlamaPreTrainedModel):
    def __init__(self, config):
        ...
    def forward(self, input_ids=None, attention_mask=None, position_ids=None,
                past_key_values=None, inputs_embeds=None,
                vision_hidden_states=None, labels=None, ...):
        ...

Import

from internvl_chat_llava.llava.model.multimodal_encoder.internvl_14b.modeling_qllama import (
    LlamaForCausalLM,
    LlamaModel,
    LlamaCrossAttention,
    _expand_mask,
    _make_causal_mask,
)

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor [batch, seq_len] | Yes (or inputs_embeds) | Token IDs for the language model |
| inputs_embeds | torch.FloatTensor [batch, seq_len, hidden_size] | No | Pre-computed input embeddings (alternative to input_ids) |
| vision_hidden_states | torch.FloatTensor [batch, vis_seq_len, 3200] | No | Visual features from InternViT for cross-attention |
| attention_mask | torch.Tensor [batch, seq_len] | No | Attention mask for text tokens |
| position_ids | torch.LongTensor [batch, seq_len] | No | Position IDs for rotary embeddings |
| labels | torch.LongTensor [batch, seq_len] | No | Target labels for language modeling loss |
| past_key_values | List[Tuple[torch.Tensor]] | No | Cached KV states for generation |
| use_zero_attention_mask | bool | No | Whether to zero out attention among query tokens |
| repeat_time | int | No | How many times to repeat vision features (default 1) |
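
The shape requirements in the input table can be checked before calling the model. The helper below is a hypothetical sketch, not part of the module: check_input_shapes works on plain shape tuples, the default hidden_size of 4096 is taken from the usage example on this page, and the rule that the vision batch times repeat_time must equal the text batch is an assumption about how repeat_time is applied.

```python
VISION_HIDDEN_SIZE = 3200  # stated on this page as matching InternViT-6B


def check_input_shapes(inputs_embeds, vision_hidden_states,
                       attention_mask=None, hidden_size=4096, repeat_time=1):
    """Validate shape tuples against the input table.

    Returns (batch, seq_len) on success and raises ValueError otherwise.
    """
    batch, seq_len, embed_dim = inputs_embeds
    if embed_dim != hidden_size:
        raise ValueError(f"inputs_embeds last dim {embed_dim} != {hidden_size}")

    vis_batch, _vis_seq_len, vis_dim = vision_hidden_states
    if vis_dim != VISION_HIDDEN_SIZE:
        raise ValueError(f"vision last dim {vis_dim} != {VISION_HIDDEN_SIZE}")

    # Assumption: vision features are repeated repeat_time times along the
    # batch axis so that they line up with the text batch.
    if vis_batch * repeat_time != batch:
        raise ValueError("vision batch * repeat_time must equal text batch")

    if attention_mask is not None and tuple(attention_mask) != (batch, seq_len):
        raise ValueError("attention_mask must be [batch, seq_len]")
    return batch, seq_len


ok = check_input_shapes((2, 128, 4096), (2, 257, 3200), (2, 128))
shared = check_input_shapes((4, 128, 4096), (2, 257, 3200), repeat_time=2)
```

With real tensors, passing tensor.shape for each argument would work the same way, since the helper only unpacks and compares sizes.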

Outputs

| Name | Type | Description |
|---|---|---|
| last_hidden_state | torch.FloatTensor [batch, seq_len, hidden_size] | Hidden states from the last decoder layer |
| logits | torch.FloatTensor [batch, seq_len, vocab_size] | Language model prediction logits (LlamaForCausalLM) |
| loss | torch.FloatTensor | Cross-entropy loss when labels are provided |
| past_key_values | tuple | Cached key-value states for autoregressive generation |
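
To make the relationship between logits, labels, and loss concrete, here is a small pure-Python sketch of a next-token cross-entropy. It assumes the shift-by-one convention common to Hugging Face causal language models (position t predicts the label at t + 1) and the conventional ignore value of -100; this page only states that the loss is a cross-entropy for next-token prediction, so both details are assumptions, and causal_lm_loss is a hypothetical name.

```python
import math

IGNORE_INDEX = -100  # label value conventionally skipped by the loss


def causal_lm_loss(logits, labels):
    """Mean cross-entropy where position t predicts labels[t + 1].

    logits: seq_len rows of vocab_size floats; labels: seq_len ints.
    Returns 0.0 when every target is ignored (a simplification).
    """
    total, count = 0.0, 0
    for row, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue
        peak = max(row)  # subtract the max for a stable log-sum-exp
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_norm - row[target]
        count += 1
    return total / count if count else 0.0


# Uniform logits over a vocabulary of 4 tokens, sequence length 3.
uniform_logits = [[0.0] * 4 for _ in range(3)]
loss = causal_lm_loss(uniform_logits, labels=[IGNORE_INDEX, 2, 1])
```

With uniform logits every token has probability 1/4, so the loss equals the natural log of the vocabulary size regardless of the targets.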

Usage Examples

Basic Usage

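import torch
from transformers import LlamaConfig

from internvl_chat_llava.llava.model.multimodal_encoder.internvl_14b.modeling_qllama import (
    LlamaForCausalLM,
)

# Configure QLLaMA with cross-attention; the two extra keyword arguments
# are stored on the config as attributes.
config = LlamaConfig(
    hidden_size=4096, num_attention_heads=32, num_hidden_layers=32,
    intermediate_size=11008, cross_attention_frequency=4,
    num_query_token=96,
)
model = LlamaForCausalLM(config)

# Placeholder inputs shaped as in the I/O Contract (random values and
# hypothetical sizes, for illustration only).
batch, seq_len, vis_seq_len = 1, 96 + 16, 257
query_and_text_embeds = torch.randn(batch, seq_len, config.hidden_size)
vision_features = torch.randn(batch, vis_seq_len, 3200)  # stands in for InternViT output
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)

# Forward with vision features (cross-attention enabled)
outputs = model(
    inputs_embeds=query_and_text_embeds,
    vision_hidden_states=vision_features,
    attention_mask=attention_mask,
    use_zero_attention_mask=True,
)

# Training with a pre-computed attention mask. precomputed_mask must be
# built by the caller; its expected layout is not documented on this page.
hidden_states = model.model.forward_train(
    inputs_embeds=query_and_text_embeds,
    vision_hidden_states=vision_features,
    attention_mask=precomputed_mask,
).last_hidden_state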
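
This page does not spell out what use_zero_attention_mask=True in the example above does to the mask. One plausible reading, sketched below as an assumption rather than as the module's behavior, is that the attention mask is additive (0 means "may attend", a large negative value means "blocked") and that the flag sets the query-by-query block to 0 so that query tokens can see each other in both directions. The helper names and the use of -inf are illustrative choices.

```python
NEG_INF = float("-inf")


def causal_additive_mask(seq_len):
    """mask[i][j] is 0.0 where token i may attend to token j, else -inf."""
    return [[0.0 if j <= i else NEG_INF for j in range(seq_len)]
            for i in range(seq_len)]


def zero_query_block(mask, num_query_token):
    """Set the top-left query-by-query block to 0.0.

    This is the assumed meaning of use_zero_attention_mask; the input
    mask is copied, not modified in place.
    """
    out = [row[:] for row in mask]
    for i in range(num_query_token):
        for j in range(num_query_token):
            out[i][j] = 0.0
    return out


# 3 query tokens followed by 2 text tokens.
mask = zero_query_block(causal_additive_mask(seq_len=5), num_query_token=3)
```

Under this reading, the first query token can attend to the other query tokens, while text tokens keep ordinary causal masking and query tokens still cannot attend to later text tokens.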
