Implementation:OpenGVLab InternVL QLLaMA Model
| Knowledge Sources | |
|---|---|
| Domains | Language Modeling, Cross-Attention, Vision-Language Model |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
This module implements QLLaMA, a modified LLaMA model with interleaved cross-attention layers that enable query tokens to attend to visual features from a vision encoder, serving as the language component of the InternVL-14B vision-language bridge.
Description
The QLLaMA implementation extends the standard LLaMA architecture with cross-attention capabilities for multimodal processing:
Core components:
- LlamaRMSNorm: Root Mean Square normalization. If apex is installed and loads correctly, the module is swapped for apex's fused FusedRMSNorm kernel; otherwise the pure PyTorch implementation is used.
- FixedLlamaRotaryEmbedding (aliased as LlamaRotaryEmbedding): Position encoding using rotary embeddings with float32 precision for numerical stability, computed via outer product rather than einsum.
- LlamaMLP: Standard gated MLP with gate_proj, up_proj, and down_proj linear layers.
- LlamaAttention: Standard multi-head self-attention with rotary position embeddings, supporting KV caching and attention mask handling.
- LlamaCrossAttention: The key addition to LLaMA. It is a cross-attention layer whose queries come from the text/query tokens (normalized with RMSNorm and projected through q_proj) and whose keys/values come from the vision hidden states (normalized with a separate RMSNorm and projected through k_proj/v_proj). The vision hidden size is hardcoded to 3200, matching InternViT-6B. The repeat_time argument repeats the vision keys/values along the batch dimension, so one batch of vision features can be shared by a larger batch of query sequences.
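As a rough illustration of the LlamaRMSNorm bullet, the sketch below shows the standard LLaMA RMSNorm computation together with an optional apex swap. It is a minimal sketch, not the file's code: the names `RMSNorm` and `RMSNormImpl` are illustrative, and eps=1e-6 is an assumption.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """y = weight * x / sqrt(mean(x^2) + eps), computed over the last dimension."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the mean square in float32 for stability, then cast back to the input dtype.
        x32 = x.to(torch.float32)
        variance = x32.pow(2).mean(-1, keepdim=True)
        return self.weight * (x32 * torch.rsqrt(variance + self.eps)).to(x.dtype)


# Use apex's fused kernel if it imports cleanly; otherwise keep the PyTorch version.
try:
    from functools import partial

    from apex.normalization import FusedRMSNorm

    RMSNormImpl = partial(FusedRMSNorm, eps=1e-6)
except Exception:  # apex not installed, or installed but failed to load
    RMSNormImpl = RMSNorm
```

Because the swap happens at import time, every layer built afterwards picks up the same implementation without further checks.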
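The rotary-embedding bullet mentions a float32 cos/sin cache whose frequency table is built with an outer product. The sketch below shows that computation with the common LLaMA rotate-half layout; the function names and the base of 10000 are assumptions, not taken from the file. `torch.outer(t, inv_freq)` produces the same matrix as `torch.einsum("i,j->ij", t, inv_freq)`.

```python
import torch


def build_rope_cache(dim: int, max_len: int, base: float = 10000.0):
    """Precompute RoPE cos/sin tables in float32 using an outer product."""
    # Inverse frequency for each channel pair: base^(-2i/dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    positions = torch.arange(max_len, dtype=torch.float32)
    # freqs[p, i] = p * inv_freq[i]
    freqs = torch.outer(positions, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)  # [max_len, dim]
    return emb.cos(), emb.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in two halves and map (x1, x2) -> (-x2, x1)
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
               position_ids: torch.Tensor) -> torch.Tensor:
    # x: [batch, heads, seq, dim]; gather per-position tables -> [batch, 1, seq, dim]
    cos = cos[position_ids].unsqueeze(1).to(x.dtype)
    sin = sin[position_ids].unsqueeze(1).to(x.dtype)
    return x * cos + rotate_half(x) * sin
```

Each channel pair is rotated by a position-dependent angle, so vector norms are preserved while relative position is encoded in the query-key dot product.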
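For the LlamaMLP bullet, a minimal sketch of the gated MLP built from gate_proj, up_proj, and down_proj. The SiLU activation and bias-free linear layers are assumptions based on the usual LLaMA configuration, and the class name `GatedMLP` is illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedMLP(nn.Module):
    """LLaMA-style gated MLP: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The SiLU-activated gate branch scales the up branch element-wise.
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```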
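To make the LlamaCrossAttention bullet concrete, here is a minimal sketch of the data flow: separate RMSNorms for the two streams, q_proj on the query tokens, k_proj/v_proj from the 3200-dim vision features, and keys/values repeated along the batch dimension by repeat_time. Assumptions: repeat_time tiles the batch with `Tensor.repeat` (so query batch i and i + vision_batch share an image), no attention mask is applied inside cross-attention, and the class and helper names are hypothetical.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class CrossAttentionSketch(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, vision_size: int = 3200):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.norm_q = SimpleRMSNorm(hidden_size)   # normalizes query tokens
        self.norm_kv = SimpleRMSNorm(vision_size)  # normalizes vision features
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(vision_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(vision_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        return x.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)  # [b, heads, n, head_dim]

    def forward(self, hidden_states, vision_hidden_states, repeat_time: int = 1):
        q = self._split_heads(self.q_proj(self.norm_q(hidden_states)))
        kv_in = self.norm_kv(vision_hidden_states)
        # Tile keys/values along batch so that q.shape[0] == repeat_time * vision batch.
        k = self._split_heads(self.k_proj(kv_in)).repeat(repeat_time, 1, 1, 1)
        v = self._split_heads(self.v_proj(kv_in)).repeat(repeat_time, 1, 1, 1)
        # Unmasked scaled dot-product attention: every query sees every vision token.
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        out = F.softmax(scores, dim=-1) @ v
        b, _, n, _ = out.shape
        return self.o_proj(out.transpose(1, 2).reshape(b, n, -1))
```

With this layout the output sequence length always equals the query length, regardless of how many vision tokens are attended to.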
Architecture:
- LlamaDecoderLayer: Combines self-attention, optional cross-attention (enabled by the use_cross_attn flag), and an MLP, with residual connections. Cross-attention runs only when vision_hidden_states is provided and the input sequence is at least num_query_token (96) tokens long. During cached autoregressive generation only the newest token is fed in, so the length check fails and cross-attention is skipped.
- LlamaModel: Stacks decoder layers, inserting cross-attention every cross_attention_frequency layers. It provides two entry points: forward, which prepares the causal attention mask itself, and forward_train, which expects an attention mask pre-computed by the caller. In forward, use_zero_attention_mask sets the additive attention mask to zero in the block where query tokens attend to query tokens; this removes the causal restriction there, so the query tokens attend to each other bidirectionally.
- LlamaForCausalLM: Adds a linear language model head for next-token prediction with cross-entropy loss.
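The layer schedule and the per-layer gating described above can be summarized in two small predicates. This is a sketch, not the file's code: it assumes layer idx receives cross-attention when `idx % cross_attention_frequency == 0`, and the function names are hypothetical. With the Basic Usage configuration (32 layers, frequency 4), that rule yields 8 cross-attention layers.

```python
from typing import List


def cross_attention_layer_indices(num_layers: int, frequency: int) -> List[int]:
    # Assumed rule: layer idx carries a cross-attention block when idx % frequency == 0.
    return [idx for idx in range(num_layers) if idx % frequency == 0]


def runs_cross_attention(layer_has_cross_attn: bool, seq_len: int,
                         has_vision_states: bool, num_query_token: int = 96) -> bool:
    # Mirrors the decoder-layer condition: during cached decoding only the new token
    # is fed in (seq_len == 1 < 96), so cross-attention is skipped.
    return layer_has_cross_attn and seq_len >= num_query_token and has_vision_states
```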
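The effect of use_zero_attention_mask is easiest to see on a small additive mask, where 0 means "may attend" and the dtype minimum means "blocked". The sketch below assumes a Hugging Face style 4D additive causal mask and query tokens placed at the front of the sequence; the helper names are hypothetical. Zeroing the query-by-query block lets the queries see each other in both directions, while text rows stay causal and queries still cannot see later text tokens.

```python
import torch


def causal_additive_mask(seq_len: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # 0 on and below the diagonal, dtype-min above it (future positions are blocked).
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype).triu(1)
    return mask[None, None]  # [1, 1, seq, seq]


def apply_zero_query_mask(mask: torch.Tensor, num_query_token: int) -> torch.Tensor:
    # Clear the query-query block so the leading query tokens attend bidirectionally.
    mask = mask.clone()
    mask[:, :, :num_query_token, :num_query_token] = 0
    return mask
```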
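For LlamaForCausalLM, next-token prediction means the logit at position t is scored against the label at position t + 1. A minimal sketch, assuming the loss follows the standard Hugging Face causal-LM pattern (shifted logits and labels, ignore_index=-100 for positions excluded from the loss); the function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Position t predicts token t+1: drop the last logit and the first label.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    # Labels equal to -100 (e.g. padding) are ignored when averaging.
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )
```

With uniform logits over a vocabulary of size V, the loss equals log V regardless of which labels are used, which makes a convenient sanity check.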
Usage
Use QLLaMA as the language model component within the InternVL-14B architecture, where it processes learnable query tokens that attend to InternViT visual features through cross-attention layers to produce vision-language aligned representations.
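The learnable query tokens have to be combined with any text embeddings before they reach the model as inputs_embeds. A minimal sketch, assuming the queries are prepended so they occupy the first num_query_token positions (the positions the decoder layers check against); `QueryTokenPrefix` and its initialization are hypothetical, not part of the file.

```python
import torch
import torch.nn as nn


class QueryTokenPrefix(nn.Module):
    """Hypothetical helper: prepend learnable query tokens to text embeddings."""

    def __init__(self, num_query_token: int = 96, hidden_size: int = 4096):
        super().__init__()
        self.query_tokens = nn.Parameter(torch.zeros(1, num_query_token, hidden_size))
        nn.init.normal_(self.query_tokens, std=0.02)

    def forward(self, text_embeds: torch.Tensor) -> torch.Tensor:
        # Share the same learned queries across the batch, then place them first.
        batch = text_embeds.shape[0]
        queries = self.query_tokens.expand(batch, -1, -1)
        return torch.cat([queries, text_embeds], dim=1)
```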
Code Reference
Source Location
- Repository: OpenGVLab_InternVL
- File: internvl_chat_llava/llava/model/multimodal_encoder/internvl_14b/modeling_qllama.py
- Lines: 1-1073
Signature
```python
class LlamaModel(LlamaPreTrainedModel):
    def __init__(self, config: LlamaConfig):
        ...

    def forward(self, input_ids=None, attention_mask=None, position_ids=None,
                past_key_values=None, inputs_embeds=None,
                vision_hidden_states=None, repeat_time=1,
                use_cache=None, output_attentions=None,
                output_hidden_states=None, use_zero_attention_mask=None,
                return_dict=None):
        ...

    def forward_train(self, input_ids=None, attention_mask=None,
                      position_ids=None, past_key_values=None,
                      inputs_embeds=None, vision_hidden_states=None,
                      repeat_time=1, ...):
        ...


class LlamaForCausalLM(LlamaPreTrainedModel):
    def __init__(self, config):
        ...

    def forward(self, input_ids=None, attention_mask=None, position_ids=None,
                past_key_values=None, inputs_embeds=None,
                vision_hidden_states=None, labels=None, ...):
        ...
```
Import
```python
from internvl_chat_llava.llava.model.multimodal_encoder.internvl_14b.modeling_qllama import (
    LlamaForCausalLM,
    LlamaModel,
    LlamaCrossAttention,
    _expand_mask,
    _make_causal_mask,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor [batch, seq_len] | Yes (or inputs_embeds) | Token IDs for the language model |
| inputs_embeds | torch.FloatTensor [batch, seq_len, hidden_size] | No | Pre-computed input embeddings (alternative to input_ids) |
| vision_hidden_states | torch.FloatTensor [batch / repeat_time, vis_seq_len, 3200] | No | Visual features from InternViT for cross-attention; with repeat_time > 1 they are repeated to match the query batch |
| attention_mask | torch.Tensor [batch, seq_len] | No | Padding mask over the full input sequence for forward; forward_train instead expects a pre-computed mask |
| position_ids | torch.LongTensor [batch, seq_len] | No | Position IDs for rotary embeddings |
| labels | torch.LongTensor [batch, seq_len] | No | Target labels for language modeling loss |
| past_key_values | List[Tuple[torch.Tensor]] | No | Cached KV states for generation |
| use_zero_attention_mask | bool | No | If true, removes the causal mask among the query tokens so they attend to each other bidirectionally |
| repeat_time | int | No | Number of times the vision keys/values are repeated along the batch dimension in cross-attention (default 1) |
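The Inputs table implies a few shape constraints that are easy to violate silently. A hypothetical pre-flight check, assuming that repeat_time multiplies the vision batch to match the query batch and that cross-attention needs at least num_query_token positions; the function is illustrative, not part of the module.

```python
import torch


def check_qllama_inputs(inputs_embeds: torch.Tensor, vision_hidden_states: torch.Tensor,
                        repeat_time: int = 1, num_query_token: int = 96,
                        vision_hidden_size: int = 3200) -> bool:
    """Raise on shape mismatches; return whether cross-attention would run."""
    batch, seq_len, _ = inputs_embeds.shape
    vis_batch, _, vis_dim = vision_hidden_states.shape
    if vis_dim != vision_hidden_size:
        raise ValueError(f"vision features need last dim {vision_hidden_size}, got {vis_dim}")
    if vis_batch * repeat_time != batch:
        raise ValueError(
            f"expected batch == vision batch * repeat_time ({vis_batch} * {repeat_time}), got {batch}"
        )
    # Too-short sequences are not an error: cross-attention is simply skipped.
    return seq_len >= num_query_token
```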
Outputs
| Name | Type | Description |
|---|---|---|
| last_hidden_state | torch.FloatTensor [batch, seq_len, hidden_size] | Hidden states from the last decoder layer |
| logits | torch.FloatTensor [batch, seq_len, vocab_size] | Language model prediction logits (LlamaForCausalLM) |
| loss | torch.FloatTensor | Cross-entropy loss when labels are provided |
| past_key_values | tuple | Cached key-value states for autoregressive generation |
Usage Examples
Basic Usage
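```python
import torch
from transformers import LlamaConfig

from internvl_chat_llava.llava.model.multimodal_encoder.internvl_14b.modeling_qllama import (
    LlamaForCausalLM,
)

# Configure QLLaMA with cross-attention (extra kwargs are stored on the config)
config = LlamaConfig(
    hidden_size=4096, num_attention_heads=32, num_hidden_layers=32,
    intermediate_size=11008, cross_attention_frequency=4,
    num_query_token=96,
)
model = LlamaForCausalLM(config)  # full-size model: allocates several GB of weights

# Placeholder inputs: 96 query tokens followed by 32 text tokens
batch, num_query, num_text = 2, 96, 32
seq_len = num_query + num_text
query_and_text_embeds = torch.randn(batch, seq_len, config.hidden_size)
vision_features = torch.randn(batch, 256, 3200)  # stand-in for InternViT-6B features
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)

# Forward with vision features (cross-attention runs because seq_len >= 96)
outputs = model(
    inputs_embeds=query_and_text_embeds,
    vision_hidden_states=vision_features,
    attention_mask=attention_mask,
    use_zero_attention_mask=True,
)

# Training with a pre-computed attention mask.
# Assumed format: 4D additive causal mask [batch, 1, seq, seq] (0 = attend, dtype min = blocked).
causal = torch.full((seq_len, seq_len), torch.finfo(torch.float32).min).triu(1)
precomputed_mask = causal[None, None].expand(batch, 1, seq_len, seq_len)
hidden_states = model.model.forward_train(
    inputs_embeds=query_and_text_embeds,
    vision_hidden_states=vision_features,
    attention_mask=precomputed_mask,
).last_hidden_state
```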