Implementation:OpenGVLab InternVL MPT Model
| Knowledge Sources | |
|---|---|
| Domains | Language Modeling, Transformer Architecture, LLaVA |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
This module implements the MosaicML Pretrained Transformer (MPT) architecture as a causal language model backbone for the LLaVA multimodal framework.
Description
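The implementation provides two main classes, both derived from MPTPreTrainedModel:

- MPTModel, the base transformer.
- MPTForCausalLM, the transformer plus a language model head.

MPTModel stacks the following, in order:

- a token embedding (SharedEmbedding);
- learned positional embeddings, which are omitted when ALiBi attention biases are enabled;
- embedding dropout;
- the MPTBlock layers;
- a final normalization layer.

MPTModel also supports three optional features:

- Prefix-LM masking via prefix_mask. Prefix tokens attend to each other bidirectionally, while the remaining tokens stay causal.
- Attention isolation between sequences packed into the same row, via sequence_id.
- Gradient checkpointing, which reduces activation memory during training.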
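The following pure-Python sketch shows how prefix_mask and sequence_id restrict attention. It builds the boolean "may attend" matrix for a single row, under three assumptions:

- 1 marks a prefix token.
- A query may attend to a key if the key is not in its future or the key is a prefix token.
- Tokens never attend across different sequence_id values.

The helper names `can_attend` and `show` are hypothetical. The module applies the same idea to an additive attention-bias tensor rather than building a boolean matrix.

```python
from typing import List, Optional, Sequence


def can_attend(
    seq_len: int,
    prefix_mask: Optional[Sequence[int]] = None,
    sequence_id: Optional[Sequence[int]] = None,
) -> List[List[bool]]:
    """Return allowed[q][k]: True if query position q may attend to key position k."""
    allowed = []
    for q in range(seq_len):
        row = []
        for k in range(seq_len):
            # Causal rule: a query sees keys at or before its own position.
            ok = k <= q
            # Prefix-LM rule: every query also sees all prefix keys,
            # so prefix tokens attend to each other bidirectionally.
            if prefix_mask is not None and prefix_mask[k]:
                ok = True
            # Packing rule: never attend across sequence boundaries.
            if sequence_id is not None and sequence_id[q] != sequence_id[k]:
                ok = False
            row.append(ok)
        allowed.append(row)
    return allowed


def show(matrix: List[List[bool]]) -> None:
    """Print one line per query: '1' = may attend, '.' = masked."""
    for row in matrix:
        print("".join("1" if v else "." for v in row))


# Prefix-LM: the first 2 of 4 tokens form the bidirectional prefix.
show(can_attend(4, prefix_mask=[1, 1, 0, 0]))  # 11.. / 11.. / 111. / 1111
# Packing: two sequences of length 2 share one row.
show(can_attend(4, sequence_id=[0, 0, 1, 1]))  # 1... / 11.. / ..1. / ..11
```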
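MPTForCausalLM wraps MPTModel and computes logits by reusing the input embedding matrix as the output projection (tied word embeddings). The logits can optionally be multiplied by a logit scale, given either as a number or as the string 'inv_sqrt_d_model', which means a factor of 1/sqrt(d_model). The model supports KV caching for efficient autoregressive generation. During beam search, it reorders the cached key/value states along the batch dimension so that they follow the selected beams.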
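The logit-scale option and beam-search cache reordering can be sketched in plain Python as follows. `resolve_logit_scale` and `reorder_cache` are hypothetical names, and lists stand in for tensors. The cache is assumed to be a list with one (key, value) tuple per layer, where the first dimension of each state is the batch (beam) dimension.

```python
import math
from typing import List, Optional, Sequence, Tuple, Union


def resolve_logit_scale(
    logit_scale: Optional[Union[float, str]], d_model: int
) -> Optional[float]:
    """Map the config value to the factor applied to the logits (None means no scaling)."""
    if logit_scale is None:
        return None
    if isinstance(logit_scale, str):
        if logit_scale == "inv_sqrt_d_model":
            return 1 / math.sqrt(d_model)
        raise ValueError(f"unsupported logit_scale string: {logit_scale!r}")
    return float(logit_scale)


def reorder_cache(
    past_key_values: Sequence[Tuple[list, ...]], beam_idx: Sequence[int]
) -> List[Tuple[list, ...]]:
    """Select, in every layer's cached states, the batch rows listed in beam_idx.

    With tensors this corresponds to state.index_select(0, beam_idx).
    """
    return [
        tuple([state[i] for i in beam_idx] for state in layer_past)
        for layer_past in past_key_values
    ]


print(resolve_logit_scale("inv_sqrt_d_model", 2048))  # ≈ 0.0221
# One layer, 3 beams; new beams 0 and 1 both continue from old beam 2.
past = [(["k0", "k1", "k2"], ["v0", "v1", "v2"])]
print(reorder_cache(past, [2, 2, 0]))  # [(['k2', 'k2', 'k0'], ['v2', 'v2', 'v0'])]
```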
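Key architectural features include:

- A configurable attention implementation (torch, flash, or triton), selected via attn_config['attn_impl'].
- ALiBi positional biases as an alternative to learned position embeddings.
- FSDP wrapping at the MPTBlock level, so each block forms its own FSDP unit.
- A 'mixed' init_device mode for distributed training. Local rank 0 initializes parameters on the CPU, and the other ranks use the meta device.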
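The sketch below computes ALiBi slopes and the causal bias that replaces learned position embeddings. The function names are hypothetical. It relies on two assumptions:

- An alibi_bias_max convention: head h of n receives slope 2^(-alibi_bias_max * h / n), with alibi_bias_max defaulting to 8. For power-of-two head counts this reproduces the original ALiBi slopes.
- For other head counts, the interleaving shown in the code, which is itself an assumption.

```python
import math
from typing import List


def alibi_slopes(n_heads: int, alibi_bias_max: int = 8) -> List[float]:
    """Per-head ALiBi slopes 2 ** (-alibi_bias_max * h / n) for h = 1..n.

    n is n_heads rounded up to a power of two. For other head counts, the
    entries at odd indices come first, then those at even indices, truncated.
    """
    n = 2 ** math.ceil(math.log2(n_heads))
    slopes = [1.0 / 2 ** (h * alibi_bias_max / n) for h in range(1, n + 1)]
    if n != n_heads:
        slopes = (slopes[1::2] + slopes[::2])[:n_heads]
    return slopes


def alibi_bias(n_heads: int, seq_len: int, alibi_bias_max: int = 8) -> List[List[float]]:
    """Causal ALiBi bias, one row per head: slope * (k - (seq_len - 1)) for key k.

    The same row is added for every query; the newest key gets 0 and older
    keys get increasingly negative biases.
    """
    offsets = list(range(1 - seq_len, 1))  # -(seq_len - 1), ..., -1, 0
    return [[m * d for d in offsets] for m in alibi_slopes(n_heads, alibi_bias_max)]


print(alibi_slopes(8))  # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
print(alibi_bias(2, 3))  # [[-0.125, -0.0625, 0.0], [-0.0078125, -0.00390625, 0.0]]
```

Using one bias row for all queries works because, once the causal mask removes future keys, slope * (k - (seq_len - 1)) differs from slope * (k - q) only by a constant per query row. Softmax ignores that constant.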
Usage
Use this model as an alternative language model backbone to LLaMA within the LLaVA multimodal framework, particularly when ALiBi attention or prefix-LM behavior is desired.
Code Reference
Source Location
- Repository: OpenGVLab_InternVL
- File: internvl_chat_llava/llava/model/language_model/mpt/modeling_mpt.py
- Lines: 1-331
Signature
```python
class MPTPreTrainedModel(PreTrainedModel):
    config_class = MPTConfig
    base_model_prefix = 'model'

class MPTModel(MPTPreTrainedModel):
    def __init__(self, config: MPTConfig):
        ...

    def forward(self, input_ids, past_key_values=None, attention_mask=None,
                prefix_mask=None, sequence_id=None, return_dict=None,
                output_attentions=None, output_hidden_states=None,
                use_cache=None, inputs_embeds=None):
        ...

class MPTForCausalLM(MPTPreTrainedModel):
    def __init__(self, config: MPTConfig):
        ...

    def forward(self, input_ids, past_key_values=None, attention_mask=None,
                prefix_mask=None, sequence_id=None, labels=None,
                return_dict=None, output_attentions=None,
                output_hidden_states=None, use_cache=None,
                inputs_embeds=None):
        ...
```
Import
```python
from internvl_chat_llava.llava.model.language_model.mpt.modeling_mpt import (
    MPTModel,
    MPTForCausalLM,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
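| input_ids | torch.LongTensor | Yes (or inputs_embeds) | Token IDs of shape [batch_size, seq_length]; seq_length must not exceed max_seq_len |
| inputs_embeds | torch.Tensor | No | Precomputed token embeddings of shape [batch_size, seq_length, d_model], passed instead of input_ids |
| attention_mask | torch.ByteTensor | No | Padding mask of shape [batch_size, seq_length] (1 = real token, 0 = padding), converted to bool internally; left padding is rejected in training mode |
| prefix_mask | torch.ByteTensor | No (required when attn_config['prefix_lm'] is True) | Marks prefix tokens (1) that are attended to bidirectionally in prefix-LM mode |
| sequence_id | torch.LongTensor | No (required in training when attn_config['attn_uses_sequence_id'] is True) | Per-token sequence IDs; tokens attend only to tokens with the same ID, isolating examples packed into one row |
| labels | torch.LongTensor | No | Target token IDs of shape [batch_size, seq_length] for the cross-entropy loss (MPTForCausalLM only) |
| past_key_values | List[Tuple[torch.FloatTensor]] | No | Per-layer cached key/value states from a previous forward pass, used for autoregressive generation |
| use_cache | bool | No | If True, the output includes past_key_values for reuse in the next step |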
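When several examples are packed into one row, prefix_mask and sequence_id are usually built alongside input_ids. The hypothetical helper below does this for (prompt, response) pairs of token IDs:

- Each prompt is treated as the prefix.
- Each example gets its own sequence_id.
- The row is right-padded, because the forward pass rejects left-padded batches in training mode.

The pad ID of 0 and the separate sequence_id for padding are illustrative assumptions. To feed a real model, each list would be wrapped in a tensor with a leading batch dimension.

```python
from typing import Dict, List, Sequence, Tuple


def pack_examples(
    examples: Sequence[Tuple[List[int], List[int]]],
    max_len: int,
    pad_id: int = 0,
) -> Dict[str, List[int]]:
    """Pack (prompt, response) token lists into one right-padded row of length max_len."""
    input_ids: List[int] = []
    prefix_mask: List[int] = []
    sequence_id: List[int] = []
    for seq_idx, (prompt, response) in enumerate(examples):
        tokens = prompt + response
        input_ids += tokens
        # Prompt tokens form the bidirectional prefix; response tokens stay causal.
        prefix_mask += [1] * len(prompt) + [0] * len(response)
        # All tokens of one example share an ID, isolating it from its neighbours.
        sequence_id += [seq_idx] * len(tokens)
    if len(input_ids) > max_len:
        raise ValueError("packed examples exceed max_len")
    n_pad = max_len - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * n_pad  # right padding only
    input_ids += [pad_id] * n_pad
    prefix_mask += [0] * n_pad
    sequence_id += [len(examples)] * n_pad  # padding gets its own ID
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "prefix_mask": prefix_mask,
        "sequence_id": sequence_id,
    }


row = pack_examples([([11, 12], [13]), ([21], [22, 23])], max_len=8)
for name, values in row.items():
    print(f"{name:>14}: {values}")
# sequence_id comes out as [0, 0, 0, 1, 1, 1, 2, 2]
```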
Outputs
| Name | Type | Description |
|---|---|---|
| last_hidden_state | torch.Tensor | Hidden states from the final layer [batch_size, seq_length, d_model] (MPTModel only) |
| logits | torch.Tensor | Language model logits [batch_size, seq_length, vocab_size] (MPTForCausalLM only) |
| loss | torch.Tensor | Scalar cross-entropy loss, returned by MPTForCausalLM only when labels are provided |
| past_key_values | List[Tuple[torch.Tensor]] | Per-layer key/value cache, populated when use_cache is True |
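To make the loss output concrete, here is a pure-Python sketch of a next-token cross-entropy. It assumes the common Hugging Face-style convention:

- Labels are passed aligned with input_ids and shifted inside the model.
- The final position and any label of -100 are ignored.
- The loss is the mean over the remaining positions.

The helper names are hypothetical.

```python
import math
from typing import List

IGNORE_INDEX = -100  # label value excluded from the loss


def shift_labels(labels: List[List[int]]) -> List[List[int]]:
    """Shift each row left by one so position i is scored against token i + 1."""
    return [row[1:] + [IGNORE_INDEX] for row in labels]


def causal_lm_loss(logits: List[List[List[float]]], labels: List[List[int]]) -> float:
    """Mean cross-entropy over non-ignored positions of logits[batch][pos][vocab]."""
    total, count = 0.0, 0
    for row_logits, row_labels in zip(logits, shift_labels(labels)):
        for scores, target in zip(row_logits, row_labels):
            if target == IGNORE_INDEX:
                continue
            # Numerically stable log-sum-exp: subtract the max before exponentiating.
            m = max(scores)
            log_z = m + math.log(sum(math.exp(s - m) for s in scores))
            total += log_z - scores[target]
            count += 1
    # No scored positions: report NaN instead of dividing by zero.
    return total / count if count else float("nan")


# Uniform logits over a 4-token vocabulary: every prediction costs log(4).
uniform = [[[0.0] * 4 for _ in range(3)]]
print(causal_lm_loss(uniform, [[1, 2, 3]]))  # ≈ 1.386, i.e. log(4)
```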
Usage Examples
Basic Usage
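```python
import torch
from internvl_chat_llava.llava.model.language_model.mpt.modeling_mpt import MPTForCausalLM
from internvl_chat_llava.llava.model.language_model.mpt.configuration_mpt import MPTConfig
config = MPTConfig(d_model=2048, n_heads=16, n_layers=24, vocab_size=50432)
# Pure-PyTorch attention (no flash/triton dependency) with prefix-LM masking enabled
config.attn_config = {**config.attn_config, 'attn_impl': 'torch', 'prefix_lm': True}
model = MPTForCausalLM(config)
# Dummy batch: 2 unpadded sequences of 16 tokens whose first 4 tokens form the prefix
input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
prefix_mask = torch.zeros_like(input_ids, dtype=torch.bool)
prefix_mask[:, :4] = True
# Forward pass with labels for training (prefix_mask is required when prefix_lm=True)
outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                prefix_mask=prefix_mask, labels=input_ids.clone())
loss = outputs.loss
# Inference forward pass with a prefix mask that also returns the KV cache
outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                prefix_mask=prefix_mask, use_cache=True)
past_key_values = outputs.past_key_values
```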