Implementation:OpenGVLab InternVL MPT Model

From Leeroopedia


Knowledge Sources
Domains: Language Modeling, Transformer Architecture, LLaVA
Last Updated: 2026-02-07 14:00 GMT

Overview

This module implements the MosaicML Pretrained Transformer (MPT) architecture as a causal language model backbone for the LLaVA multimodal framework.

Description

The implementation provides two main classes: MPTModel (the base transformer) and MPTForCausalLM (with a language model head). MPTModel stacks word embeddings (SharedEmbedding), optional positional embeddings (disabled when using ALiBi attention biases), embedding dropout, MPTBlock layers, and a final layer normalization. It supports prefix-LM masking via a prefix_mask input, sequence-level attention isolation via sequence_id, and gradient checkpointing for memory efficiency.
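The three masking inputs interact, and a small dependency-free sketch shows how. It follows the rules as we understand them from the MPT source. A causal rule applies first. Any key marked in `prefix_mask` is visible to every query, which makes attention bidirectional inside the prefix. Queries never attend across a `sequence_id` boundary or to padded keys. The real module builds additive float biases on tensors (a large negative value for each blocked pair), not boolean lists. `build_allowed_mask` is a hypothetical name used only for illustration.

```python
from typing import List, Optional


def build_allowed_mask(
    seq_len: int,
    prefix_mask: Optional[List[int]] = None,
    sequence_id: Optional[List[int]] = None,
    attention_mask: Optional[List[int]] = None,
) -> List[List[bool]]:
    """allowed[i][j] is True when query position i may attend to key position j.

    Hypothetical illustration, not the module's implementation.
    """
    allowed = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            # Causal rule: every query sees itself and earlier keys.
            ok = j <= i
            # Prefix-LM rule: keys inside the prefix are visible to every query,
            # which makes attention bidirectional within the prefix.
            if prefix_mask is not None and prefix_mask[j]:
                ok = True
            # Packed sequences: never attend across a sequence_id boundary.
            if sequence_id is not None and sequence_id[i] != sequence_id[j]:
                ok = False
            # Padding: keys with attention_mask == 0 are never attended to.
            if attention_mask is not None and not attention_mask[j]:
                ok = False
            row.append(ok)
        allowed.append(row)
    return allowed


# Two packed sequences (lengths 3 and 2); the first two tokens form a prefix.
mask = build_allowed_mask(5, prefix_mask=[1, 1, 0, 0, 0], sequence_id=[0, 0, 0, 1, 1])
```

In this example, token 0 can see token 1 because both are in the prefix. Tokens 3 and 4 see only each other, because they carry a different `sequence_id`.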

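MPTForCausalLM wraps MPTModel and adds an output head that is tied to the input embeddings, so the word-embedding matrix is reused as the output projection. Logits can optionally be scaled, either by a numeric value or by the string 'inv_sqrt_d_model' (a factor of 1/sqrt(d_model)). The model supports KV caching for efficient autoregressive generation and reorders the cached past key values during beam search.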
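The tied head, the logit scaling, and the beam-search cache reordering can be sketched in plain Python. The sketch assumes that weight tying means computing `hidden @ embedding.T`. It also assumes that cache reordering amounts to selecting entries along the batch (beam) dimension of every cached state, roughly what `index_select(0, beam_idx)` does on a tensor. All three function names are hypothetical, and nested lists stand in for tensors.

```python
import math
from typing import List, Optional, Sequence, Tuple, Union

Matrix = List[List[float]]


def resolve_logit_scale(logit_scale: Optional[Union[float, str]], d_model: int) -> Optional[float]:
    """A number is used as-is, 'inv_sqrt_d_model' becomes 1/sqrt(d_model), None means no scaling."""
    if logit_scale is None or isinstance(logit_scale, (int, float)):
        return logit_scale
    if logit_scale == "inv_sqrt_d_model":
        return 1.0 / math.sqrt(d_model)
    raise ValueError(f"unsupported logit_scale: {logit_scale!r}")


def tied_logits(hidden: Matrix, embedding: Matrix, scale: Optional[float] = None) -> Matrix:
    """Score each hidden state [d_model] against every row of the input embedding
    matrix [vocab_size, d_model], i.e. hidden @ embedding.T, then apply the scale."""
    logits = [[sum(h * e for h, e in zip(h_row, e_row)) for e_row in embedding] for h_row in hidden]
    if scale is not None:
        logits = [[v * scale for v in row] for row in logits]
    return logits


def reorder_cache(past: Sequence[Tuple[list, ...]], beam_idx: Sequence[int]) -> List[Tuple[list, ...]]:
    """For each layer, pick the batch (beam) entries named by beam_idx from every
    cached state: the list analogue of past_state.index_select(0, beam_idx)."""
    return [tuple([state[b] for b in beam_idx] for state in layer) for layer in past]
```

For example, `resolve_logit_scale("inv_sqrt_d_model", 4)` gives 0.5. `reorder_cache` with `beam_idx=[1, 1]` duplicates the cache of beam 1 into both slots, which happens when both surviving beams descend from the same parent.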

Key architectural features include configurable attention implementations (torch, flash, triton), ALiBi positional biases as an alternative to learned position embeddings, FSDP wrapping at the MPTBlock level, and mixed device initialization for distributed training.
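ALiBi replaces learned position embeddings with a per-head linear penalty on query-key distance. The sketch below computes head slopes as the geometric sequence from the ALiBi paper, 2^(-alibi_bias_max * k / n_heads) for k = 1..n_heads. It assumes `alibi_bias_max` corresponds to the MPT attention option of that name (default 8). The sketch handles only power-of-two head counts and uses the textbook relative form -slope * (i - j). The MPT code may lay the bias out differently, for example relative to the last key position. That variant yields the same attention weights after the causal mask and softmax, because softmax ignores a constant added to a whole row.

```python
from typing import List


def alibi_slopes(n_heads: int, alibi_bias_max: float = 8.0) -> List[float]:
    """Per-head slopes 2^(-alibi_bias_max * k / n_heads), k = 1..n_heads."""
    if n_heads <= 0 or n_heads & (n_heads - 1):
        raise ValueError("this sketch only supports power-of-two head counts")
    return [2.0 ** (-alibi_bias_max * k / n_heads) for k in range(1, n_heads + 1)]


def alibi_bias(n_heads: int, seq_len: int, alibi_bias_max: float = 8.0) -> List[List[List[float]]]:
    """bias[h][i][j] = -slope_h * (i - j) for keys at or before the query (j <= i).

    Future keys are left at 0 because the causal mask removes them anyway.
    """
    return [
        [[-s * (i - j) if j <= i else 0.0 for j in range(seq_len)] for i in range(seq_len)]
        for s in alibi_slopes(n_heads, alibi_bias_max)
    ]
```

The bias is added to the raw attention scores before the softmax. Nearby keys are penalized less than distant ones, and each head decays at its own rate.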

Usage

Use this model as an alternative language model backbone to LLaMA within the LLaVA multimodal framework, particularly when ALiBi attention or prefix-LM behavior is desired.

Code Reference

Source Location

Signature

class MPTPreTrainedModel(PreTrainedModel):
    config_class = MPTConfig
    base_model_prefix = 'model'

class MPTModel(MPTPreTrainedModel):
    def __init__(self, config: MPTConfig):
        ...
    def forward(self, input_ids, past_key_values=None, attention_mask=None,
                prefix_mask=None, sequence_id=None, return_dict=None,
                output_attentions=None, output_hidden_states=None,
                use_cache=None, inputs_embeds=None):
        ...

class MPTForCausalLM(MPTPreTrainedModel):
    def __init__(self, config: MPTConfig):
        ...
    def forward(self, input_ids, past_key_values=None, attention_mask=None,
                prefix_mask=None, sequence_id=None, labels=None,
                return_dict=None, output_attentions=None,
                output_hidden_states=None, use_cache=None,
                inputs_embeds=None):
        ...

Import

from internvl_chat_llava.llava.model.language_model.mpt.modeling_mpt import (
    MPTModel,
    MPTForCausalLM,
)

I/O Contract

Inputs

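| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | Yes (or inputs_embeds) | Token IDs of shape [batch_size, seq_length] |
| attention_mask | torch.ByteTensor | No | Padding mask (1 = real token, 0 = padding) |
| prefix_mask | torch.ByteTensor | No | Marks prefix tokens (1 = prefix) for prefix-LM mode; used only when attn_config['prefix_lm'] is True, in which case it is required |
| sequence_id | torch.LongTensor | No | Per-token sequence IDs that keep packed sequences from attending to one another; used only when attn_config['attn_uses_sequence_id'] is True |
| labels | torch.LongTensor | No | Target token IDs for computing the cross-entropy loss (MPTForCausalLM only) |
| past_key_values | List[Tuple[torch.FloatTensor]] | No | Cached key-value states for autoregressive generation |
| use_cache | bool | No | Whether to return past_key_values for caching |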
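When several examples are packed into one row, `prefix_mask`, `sequence_id`, and `attention_mask` must be built together. The hypothetical helper below takes `(token_ids, prefix_len)` pairs and produces the four parallel lists for one row. It pads on the right, marks padding with attention_mask 0, and gives padding its own sequence ID. These padding conventions are assumptions of the sketch, not requirements stated by the module.

```python
from typing import List, Sequence, Tuple


def pack_segments(
    segments: Sequence[Tuple[List[int], int]], pad_to: int, pad_id: int = 0
) -> Tuple[List[int], List[int], List[int], List[int]]:
    """Pack (token_ids, prefix_len) segments into one row.

    Returns (input_ids, attention_mask, prefix_mask, sequence_id). Hypothetical helper.
    """
    input_ids: List[int] = []
    prefix_mask: List[int] = []
    sequence_id: List[int] = []
    for seg_idx, (tokens, prefix_len) in enumerate(segments):
        if not 0 <= prefix_len <= len(tokens):
            raise ValueError("prefix_len must lie within the segment")
        input_ids.extend(tokens)
        # The first prefix_len tokens of each segment form its bidirectional prefix.
        prefix_mask.extend([1] * prefix_len + [0] * (len(tokens) - prefix_len))
        sequence_id.extend([seg_idx] * len(tokens))
    n_real = len(input_ids)
    if n_real > pad_to:
        raise ValueError("segments exceed pad_to")
    n_pad = pad_to - n_real
    # Right padding: real tokens first, then masked-out padding.
    attention_mask = [1] * n_real + [0] * n_pad
    input_ids += [pad_id] * n_pad
    prefix_mask += [0] * n_pad
    sequence_id += [len(segments)] * n_pad
    return input_ids, attention_mask, prefix_mask, sequence_id
```

Stacking such rows with `torch.tensor([...])` yields the [batch_size, seq_length] inputs listed in the table.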

Outputs

| Name | Type | Description |
|---|---|---|
| last_hidden_state | torch.Tensor | Hidden states from the final layer, shape [batch_size, seq_length, d_model] (MPTModel only) |
| logits | torch.Tensor | Language model logits, shape [batch_size, seq_length, vocab_size] (MPTForCausalLM only) |
| loss | torch.Tensor | Cross-entropy loss, returned when labels are provided (MPTForCausalLM only) |
| past_key_values | List[Tuple[torch.Tensor]] | Cached key-value states for generation, populated when use_cache=True |
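As far as we can tell from the MPT source, MPTForCausalLM expects labels aligned position-for-position with input_ids. It rolls them left by one token, sets the final position to -100, and calls cross-entropy. Setting a label to -100 therefore also excludes that token, such as a prompt token, from the loss. The pure-Python function below is a hypothetical single-sequence re-implementation of that behavior, not the module's code.

```python
import math
from typing import Sequence

IGNORE_INDEX = -100  # default ignore_index of torch.nn.functional.cross_entropy


def shifted_lm_loss(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean next-token cross-entropy for a single sequence.

    logits[t] predicts the token at position t + 1, so the labels are shifted
    left by one and the final position, which has no next token, is ignored.
    """
    targets = list(labels[1:]) + [IGNORE_INDEX]
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == IGNORE_INDEX:
            continue
        # Negative log-softmax of the target, using log-sum-exp for stability.
        peak = max(row)
        log_z = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_z - row[target]
        count += 1
    # Mirrors PyTorch's mean reduction, which yields NaN when every target is ignored.
    return total / count if count else float("nan")
```

With uniform logits over a two-token vocabulary, every counted position costs log 2, so the mean loss is log 2 whatever the labels are.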

Usage Examples

Basic Usage

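import torch
from internvl_chat_llava.llava.model.language_model.mpt.modeling_mpt import MPTForCausalLM
from internvl_chat_llava.llava.model.language_model.mpt.configuration_mpt import MPTConfig

# Pure-PyTorch attention (no Triton/FlashAttention needed); prefix_lm=True makes prefix_mask take effect
config = MPTConfig(d_model=2048, n_heads=16, n_layers=24, vocab_size=50432,
                   attn_config={'attn_impl': 'torch', 'prefix_lm': True})
model = MPTForCausalLM(config)

# Dummy batch of 2 unpadded sequences; the first 4 tokens form the bidirectional prefix
input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
prefix_mask = torch.zeros_like(input_ids)
prefix_mask[:, :4] = 1
labels = input_ids.clone()

# Forward pass with labels for training (prefix_mask is required when prefix_lm=True)
outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                prefix_mask=prefix_mask, labels=labels)
loss = outputs.loss

# Forward pass that also returns past_key_values for cached generation
outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                prefix_mask=prefix_mask, use_cache=True)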
