Implementation:FlagOpen FlagEmbedding RetroMAE Modeling
| Knowledge Sources | |
|---|---|
| Domains | Machine_Learning, Natural_Language_Processing, Self_Supervised_Learning |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
RetroMAE model for pre-training text encoders with masked auto-encoding objectives optimized for retrieval tasks.
Description
RetroMAEForPretraining implements the RetroMAE (Retrieval-oriented Masked Auto-Encoder) architecture, which combines two components:
Encoder: A standard BERT-style masked language model that processes masked input and predicts original tokens. The [CLS] token embedding from the final layer serves as the sentence representation.
Enhanced Decoder: Uses a custom decoder layer (BertLayerForDecoder) that takes:
- Query: Position embeddings combined with the encoder's [CLS] token, expanded to all positions
- Key/Value: Decoder embeddings (concatenation of encoder [CLS] and original token embeddings from position 1 onward)
- Matrix attention mask: Allows each position to attend to a different masked pattern of the sequence
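The query and key/value inputs listed above can be sketched for a single sequence in plain Python. This is a minimal illustration, not the repository's code: `build_decoder_inputs` and its arguments are hypothetical names, vectors are plain lists, and any normalization or extra embedding terms the real embedding layer applies are omitted.

```python
from typing import List, Tuple

Vector = List[float]


def build_decoder_inputs(
    cls_vec: Vector,
    token_embs: List[Vector],
    pos_embs: List[Vector],
) -> Tuple[List[Vector], List[Vector]]:
    """Return (query, key_value) for one sequence (hypothetical helper).

    query[i]     = pos_embs[i] + cls_vec   ([CLS] copied to every position)
    key_value[0] = cls_vec                 (encoder [CLS] takes position 0)
    key_value[i] = token_embs[i], i >= 1   (original token embeddings)
    """
    if len(token_embs) != len(pos_embs):
        raise ValueError("token_embs and pos_embs must have the same length")
    # Expand the [CLS] vector to all positions and add each position embedding.
    query = [[p + c for p, c in zip(pos, cls_vec)] for pos in pos_embs]
    # Concatenate the encoder [CLS] with token embeddings from position 1 onward.
    key_value = [list(cls_vec)] + [list(t) for t in token_embs[1:]]
    return query, key_value


# Toy values: hidden size 2, sequence length 3.
cls_vec = [1.0, 2.0]
token_embs = [[9.0, 9.0], [0.5, 0.5], [0.25, 0.75]]
pos_embs = [[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]]
query, key_value = build_decoder_inputs(cls_vec, token_embs, pos_embs)
```

Note that the query carries no token information of its own: every position sees only the sentence embedding plus its position, so the decoder has to rely on the [CLS] vector and the visible key/value entries to reconstruct each token.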
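The decoder predicts the full original sequence from these varied masked contexts. The training loss combines the encoder MLM loss and the decoder reconstruction loss into a single value. The model supports gradient checkpointing for memory efficiency and exposes save_pretrained and from_pretrained methods that follow HuggingFace naming; note that from_pretrained takes model_args as its first argument, unlike the standard HuggingFace signature.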
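The combined objective can be illustrated with a small stdlib-only sketch. It assumes the two terms are added without weighting and that positions labelled `-100` are ignored (the default `ignore_index` of PyTorch's `CrossEntropyLoss`); both are assumptions for illustration, and `cross_entropy` and `combined_loss` are hypothetical helper names.

```python
import math
from typing import List

IGNORE_INDEX = -100  # assumed: positions with this label contribute no loss


def cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    """Mean negative log-likelihood over positions whose label is not ignored."""
    losses = []
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue
        # Log-sum-exp with max subtraction for numerical stability.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_z - row[label])
    if not losses:
        raise ValueError("no labelled positions")
    return sum(losses) / len(losses)


def combined_loss(enc_logits, enc_labels, dec_logits, dec_labels) -> float:
    """Encoder MLM loss plus decoder reconstruction loss (assumed unweighted)."""
    return cross_entropy(enc_logits, enc_labels) + cross_entropy(dec_logits, dec_labels)


# Toy example: vocabulary of 4 with uniform logits, so each term equals log(4).
uniform = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
total = combined_loss(uniform, [IGNORE_INDEX, 2], uniform, [1, 3])
```

In this toy setup the encoder term averages over the one masked position while the decoder term averages over both positions.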
Usage
Use this for pre-training BERT-like models on unlabeled text corpora with an objective specifically designed to produce high-quality embeddings for retrieval applications.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/baai_general_embedding/retromae_pretrain/modeling.py
- Lines: 1-102
Signature
class RetroMAEForPretraining(nn.Module):
    def __init__(self, bert: BertForMaskedLM, model_args: ModelArguments)
    def forward(self, encoder_input_ids, encoder_attention_mask, encoder_labels,
                decoder_input_ids, decoder_attention_mask, decoder_labels)
    @classmethod
    def from_pretrained(cls, model_args: ModelArguments, *args, **kwargs)
Import
from research.baai_general_embedding.retromae_pretrain.modeling import RetroMAEForPretraining
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| bert | BertForMaskedLM | Yes | Pre-trained BERT model serving as encoder |
| model_args | ModelArguments | Yes | Model configuration arguments |
| encoder_input_ids | Tensor | Yes | Masked input IDs for encoder [batch, seq_len] |
| encoder_attention_mask | Tensor | Yes | Attention mask for encoder [batch, seq_len] |
| encoder_labels | Tensor | Yes | Original tokens for MLM loss [batch, seq_len] |
| decoder_input_ids | Tensor | Yes | Original input IDs for decoder [batch, seq_len] |
| decoder_attention_mask | Tensor | Yes | Matrix attention mask [batch, seq_len, seq_len] |
| decoder_labels | Tensor | Yes | Labels for decoder reconstruction [batch, seq_len] |
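To make the `[batch, seq_len, seq_len]` shape of `decoder_attention_mask` concrete, the sketch below builds one such matrix per sequence with nested lists. The sampling scheme is an assumption for illustration only (independent random hiding, first position always visible, own token hidden) and is not necessarily what the repository's collator does. It also assumes the convention that 1 means "may attend" and 0 means "blocked". `build_matrix_attention_mask` is a hypothetical helper.

```python
import random
from typing import List


def build_matrix_attention_mask(
    seq_len: int, mask_prob: float, seed: int = 0
) -> List[List[int]]:
    """Return a [seq_len, seq_len] 0/1 matrix; row i is what position i may see."""
    rng = random.Random(seed)
    mask = []
    for i in range(seq_len):
        # Each row draws its own random pattern of hidden positions.
        row = [0 if rng.random() < mask_prob else 1 for _ in range(seq_len)]
        row[i] = 0  # assumed: a position cannot see its own token
        row[0] = 1  # assumed: the [CLS] slot stays visible (wins for row 0)
        mask.append(row)
    return mask


# One matrix per sequence gives the [batch, seq_len, seq_len] layout.
batch_size, seq_len = 2, 5
batch_mask = [
    build_matrix_attention_mask(seq_len, mask_prob=0.5, seed=b)
    for b in range(batch_size)
]
```

Because every row is sampled separately, each position reconstructs its token from a different visible subset of the sequence, which is what distinguishes this mask from the usual `[batch, seq_len]` padding mask passed to the encoder.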
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Tuple[Tensor] | One-element tuple; element 0 is the combined encoder MLM loss + decoder reconstruction loss |
Usage Examples
from transformers import BertForMaskedLM, AutoTokenizer
from research.baai_general_embedding.retromae_pretrain.modeling import RetroMAEForPretraining
from research.baai_general_embedding.retromae_pretrain.arguments import ModelArguments
from research.baai_general_embedding.retromae_pretrain.data import RetroMAECollator
# Initialize model
model_args = ModelArguments()
bert = BertForMaskedLM.from_pretrained("bert-base-uncased")
model = RetroMAEForPretraining(bert, model_args)
# Prepare data
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = RetroMAECollator(tokenizer=tokenizer)
texts = ["This is a sample text for pre-training.", ...]
batch = collator(texts)
# Forward pass (the output is indexed with [0] to get the combined loss)
outputs = model(
    encoder_input_ids=batch["encoder_input_ids"],
    encoder_attention_mask=batch["encoder_attention_mask"],
    encoder_labels=batch["encoder_labels"],
    decoder_input_ids=batch["decoder_input_ids"],
    decoder_attention_mask=batch["decoder_attention_mask"],
    decoder_labels=batch["decoder_labels"],
)
# Training
loss = outputs[0]
loss.backward()
# Save model
model.save_pretrained("output/retromae_model")
# Load model
model = RetroMAEForPretraining.from_pretrained(
    model_args,
    "output/retromae_model",
)