Implementation:OpenGVLab InternVL HF PrefixLM Converter
| Knowledge Sources | |
|---|---|
| Domains | Language Modeling, Attention Masking, Model Surgery |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
This module converts HuggingFace causal language models (GPT-2, GPT-Neo, GPT-NeoX, GPT-J, BLOOM, OPT) into prefix language models by monkey-patching their attention masking to support a bidirectional_mask input.
Description
The converter performs lightweight surgery on HuggingFace causal LMs to enable prefix-LM behavior. For each supported architecture, it replaces the model's forward and generate methods with wrapped versions that accept a bidirectional_mask tensor. When a bidirectional_mask is provided during forward, the causal attention mask is modified to allow bidirectional attention over prefix tokens (where mask=1) while retaining causal masking for target tokens (where mask=0). After the forward pass, the attention mask is restored to its original causal form.
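The effective attention pattern can be sketched in plain PyTorch. This is a minimal illustration that assumes the patched models compute the equivalent of a causal mask OR-ed column-wise with the prefix positions. The function name prefix_lm_mask is hypothetical and is not part of the converter.

```python
import torch


def prefix_lm_mask(bidirectional_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: boolean [batch, seq, seq] mask, True = query i may attend key j.

    bidirectional_mask: [batch, seq] with 1 for prefix tokens and 0 for target tokens.
    """
    _, seq = bidirectional_mask.shape
    # Standard causal mask: position i sees positions j <= i.
    causal = torch.tril(
        torch.ones(seq, seq, dtype=torch.bool, device=bidirectional_mask.device)
    )
    # Every query may also see every prefix key (column-wise OR). This makes the
    # prefix block fully bidirectional while target keys stay causal.
    prefix_cols = bidirectional_mask.bool()[:, None, :]  # [batch, 1, seq]
    return causal[None, :, :] | prefix_cols              # [batch, seq, seq]


# Three prefix tokens followed by two target tokens.
mask = prefix_lm_mask(torch.tensor([[1, 1, 1, 0, 0]]))
print(mask[0].int())
# rows 0-2: [1, 1, 1, 0, 0]; row 3: [1, 1, 1, 1, 0]; row 4: [1, 1, 1, 1, 1]
```

Prefix rows see the whole prefix but no targets. Target rows see the full prefix plus earlier targets only, which is ordinary causal masking.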
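For GPT-style models (GPT-2, GPT-Neo, GPT-NeoX, GPT-J), this is achieved by manipulating the bias buffer that stores each attention module's causal mask: the prefix columns are OR-ed into it before the forward pass, and it is reset to a lower-triangular mask afterwards. For BLOOM, custom _prepare_attn_mask and _build_alibi_tensor methods are injected. For OPT, the _prepare_decoder_attention_mask method is replaced. Generation is also patched so that the prompt is encoded with fully bidirectional attention. Causal behavior for the continuation then follows from HuggingFace's KV caching: after the prompt pass, generation feeds one new token per step, and that token can only attend to the cached earlier positions and itself.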
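The GPT-style patch-and-restore pattern can be illustrated on a toy module. ToyAttention, convert_toy_to_prefix_lm and prompt_forward are hypothetical names. The sketch assumes, as the description above does, a causal mask stored in a buffer named bias with shape [1, 1, max_len, max_len]. It is not the converter's actual code.

```python
from typing import Optional

import torch
from torch import nn


class ToyAttention(nn.Module):
    """Hypothetical stand-in for a GPT-style attention block whose causal mask
    lives in a boolean buffer named `bias` of shape [1, 1, max_len, max_len]."""

    def __init__(self, max_len: int = 8):
        super().__init__()
        causal = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
        self.register_buffer("bias", causal[None, None])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A real block would mask its attention scores with this slice;
        # the toy returns the slice so the effect of the patch is visible.
        seq = x.shape[1]
        return self.bias[:, :, :seq, :seq]


def convert_toy_to_prefix_lm(attn: ToyAttention) -> ToyAttention:
    """Hypothetical sketch of the patch-and-restore pattern."""
    original_forward = attn.forward  # bound method captured before patching
    max_len = attn.bias.shape[-1]

    def restore_causal() -> None:
        # Collapse any batch-expanded buffer back to a [1, 1, L, L] lower-triangular mask.
        attn.bias = torch.tril(attn.bias[0, 0])[None, None]

    def forward(x: torch.Tensor, bidirectional_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        if bidirectional_mask is None:
            return original_forward(x)  # unchanged causal behavior
        b, s = bidirectional_mask.shape
        if s > max_len:
            raise ValueError(f"sequence length {s} exceeds mask buffer length {max_len}")
        # Pad with zeros (target/causal) up to the buffer length.
        padded = torch.zeros(b, max_len, dtype=torch.bool, device=attn.bias.device)
        padded[:, :s] = bidirectional_mask.bool().to(padded.device)
        # [1, 1, L, L] | [b, 1, 1, L] -> [b, 1, L, L]: prefix columns visible to all queries.
        attn.bias = torch.logical_or(attn.bias, padded[:, None, None, :])
        try:
            return original_forward(x)
        finally:
            restore_causal()

    def prompt_forward(x: torch.Tensor) -> torch.Tensor:
        # Generation-time prompt pass: open the whole mask, then restore it.
        attn.bias = torch.ones_like(attn.bias)
        try:
            return original_forward(x)
        finally:
            restore_causal()

    attn.forward = forward                # instance attribute shadows the class method
    attn.prompt_forward = prompt_forward  # hypothetical name
    return attn


attn = convert_toy_to_prefix_lm(ToyAttention(max_len=6))
x = torch.zeros(1, 5, 4)
print(attn(x, bidirectional_mask=torch.tensor([[1, 1, 1, 0, 0]]))[0, 0].int())
print(torch.equal(attn.bias, torch.tril(torch.ones(6, 6, dtype=torch.bool))[None, None]))  # True
```

Restoring in a finally block is a design choice of this sketch: the buffer returns to its causal form even if the wrapped forward raises.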
The helper function add_bidirectional_mask_if_missing automatically constructs the bidirectional_mask from labels and attention_mask when not explicitly provided in a batch.
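The inference rule can be sketched as follows. It assumes the common HuggingFace convention that positions excluded from the loss carry the label -100. Under that assumption, a prefix token is any non-padding position whose label is -100. infer_bidirectional_mask is a hypothetical stand-in for add_bidirectional_mask_if_missing, not its actual source.

```python
from typing import Any, Dict

import torch

IGNORE_INDEX = -100  # assumed: HuggingFace's default ignore label for the LM loss


def infer_bidirectional_mask(batch: Dict[str, Any]) -> None:
    """Hypothetical sketch: add batch['bidirectional_mask'] in place."""
    if "bidirectional_mask" in batch:
        return  # an explicit mask always wins
    if "labels" not in batch or "attention_mask" not in batch:
        raise KeyError("cannot infer bidirectional_mask without labels and attention_mask")
    # Prefix = real (non-padding) tokens that are excluded from the loss.
    prefix = torch.eq(batch["attention_mask"], 1) & torch.eq(batch["labels"], IGNORE_INDEX)
    batch["bidirectional_mask"] = prefix.type_as(batch["attention_mask"])


batch = {
    "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 0]]),
    "labels": torch.tensor([[-100, -100, -100, 42, 7, -100]]),
}
infer_bidirectional_mask(batch)
print(batch["bidirectional_mask"])  # tensor([[1, 1, 1, 0, 0, 0]])
```

The trailing padding position also has label -100, but attention_mask == 0 keeps it out of the prefix.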
Usage
Use this converter to run one of the supported HuggingFace causal LMs in prefix-LM mode within the LLaVA codebase (the module ships alongside the MPT model code, but MPT itself is not among the supported CAUSAL_LM_TYPES). Bidirectional attention over the input prompt (including visual tokens) can improve contextual understanding before autoregressive generation.
Code Reference
Source Location
- Repository: OpenGVLab_InternVL
- File: internvl_chat_llava/llava/model/language_model/mpt/hf_prefixlm_converter.py
- Lines: 1-415
Signature
```python
def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
    ...


def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
    ...
```
Import
```python
from internvl_chat_llava.llava.model.language_model.mpt.hf_prefixlm_converter import (
    convert_hf_causal_lm_to_prefix_lm,
    add_bidirectional_mask_if_missing,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | CAUSAL_LM_TYPES (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM) | Yes | HuggingFace causal language model to convert |
| bidirectional_mask | torch.Tensor [batch_size, seq_length] | No | Byte tensor where 1=prefix (bidirectional) and 0=target (causal) |
| batch | Dict[str, Any] | Yes (for add_bidirectional_mask_if_missing) | Training batch dict; if bidirectional_mask is absent, it is inferred from attention_mask and labels |
Outputs
| Name | Type | Description |
|---|---|---|
| model | CAUSAL_LM_TYPES | The same model instance with patched forward/generate methods |
| batch | Dict[str, Any] | The input batch dict, modified in place to include bidirectional_mask (the function returns None) |
Usage Examples
Basic Usage
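```python
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from internvl_chat_llava.llava.model.language_model.mpt.hf_prefixlm_converter import (
    convert_hf_causal_lm_to_prefix_lm,
    add_bidirectional_mask_if_missing,
)

# Convert a GPT-2 model to a prefix LM (forward/generate are patched in place)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model = convert_hf_causal_lm_to_prefix_lm(model)

# Toy batch: prompt tokens form the prefix, answer tokens form the target
prompt_ids = tokenizer("Question: what is 2+2? Answer:")["input_ids"]
answer_ids = tokenizer(" 4")["input_ids"]
input_ids = torch.tensor([prompt_ids + answer_ids])
attention_mask = torch.ones_like(input_ids)
labels = input_ids.clone()
labels[:, : len(prompt_ids)] = -100  # HF ignore index: prompt is excluded from the loss

# Infer bidirectional_mask from labels/attention_mask, then run the forward pass
batch = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
add_bidirectional_mask_if_missing(batch)
outputs = model(**batch)
```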