

Implementation:FlagOpen FlagEmbedding RetroMAE Modeling

From Leeroopedia


Knowledge Sources
Domains: Machine_Learning, Natural_Language_Processing, Self_Supervised_Learning
Last Updated: 2026-02-09 00:00 GMT

Overview

RetroMAE model for pre-training text encoders with masked auto-encoding objectives optimized for retrieval tasks.

Description

RetroMAEForPretraining implements the RetroMAE (Retrieval-oriented Masked Auto-Encoder) architecture combining:

Encoder: A standard BERT-style masked language model that reads the masked input and predicts the original tokens at the masked positions. Its final-layer [CLS] hidden state serves as the sentence representation.

Enhanced Decoder: A single custom decoder layer (BertLayerForDecoder) whose attention takes:

  • Query: the encoder's [CLS] hidden state, copied to every position and combined with that position's position embedding
  • Key/Value: the decoder embeddings, i.e. the encoder's [CLS] hidden state at position 0 followed by the embeddings of the original (unmasked) tokens from position 1 onward
  • Matrix attention mask: a [seq_len, seq_len] mask per example in which each position attends to its own randomly sampled subset of the sequence and is hidden from its own token
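To make the three decoder inputs concrete, here is a NumPy sketch under simplifying assumptions: one attention head, no learned query/key/value projections, no feed-forward sublayer, and position embeddings simply added to the copied [CLS] vector. The real BertLayerForDecoder is a full BERT layer, so treat this as an illustration only. The names build_matrix_mask and enhanced_decode and the visible_prob rate are hypothetical. Each mask row samples its own visible context, hides the row's own position, and always exposes position 0, which holds the sentence embedding.

```python
import numpy as np


def build_matrix_mask(seq_len, visible_prob, rng):
    """Hypothetical RetroMAE-style matrix mask: mask[i, j] == 1 means row i may attend to j."""
    # Each row draws its own random set of visible context positions.
    mask = (rng.random((seq_len, seq_len)) < visible_prob).astype(np.float32)
    # A position never sees its own token, so it must be rebuilt from context.
    np.fill_diagonal(mask, 0.0)
    # Position 0 holds the [CLS] sentence embedding and stays visible to every row
    # (this also guarantees each row has at least one visible key).
    mask[:, 0] = 1.0
    return mask


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def enhanced_decode(cls_vec, token_emb, pos_emb, mask):
    """Single-head, projection-free attention over RetroMAE-style decoder inputs.

    cls_vec:   [D]     encoder's final-layer [CLS] hidden state
    token_emb: [L, D]  embeddings of the original (unmasked) tokens
    pos_emb:   [L, D]  position embeddings
    mask:      [L, L]  1 = may attend, 0 = hidden
    """
    seq_len, dim = token_emb.shape
    # Query: the [CLS] vector copied to every position plus that position's embedding.
    query = np.broadcast_to(cls_vec, (seq_len, dim)) + pos_emb
    # Key/Value: [CLS] at position 0, original token embeddings from position 1 on.
    kv = np.concatenate([cls_vec[None, :], token_emb[1:]], axis=0)
    scores = query @ kv.T / np.sqrt(dim)
    scores = np.where(mask > 0, scores, -1e9)  # hidden positions get effectively zero weight
    return softmax(scores) @ kv  # [L, D]: one reconstruction state per position


rng = np.random.default_rng(0)
L, D = 6, 8
mask = build_matrix_mask(L, visible_prob=0.5, rng=rng)
out = enhanced_decode(rng.normal(size=D), rng.normal(size=(L, D)), rng.normal(size=(L, D)), mask)
print(out.shape)  # (6, 8)
```

Because row i never attends to its own key, changing token i's embedding leaves row i's output unchanged; the decoder has to recover that token from the sentence embedding plus the sampled context.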

The decoder reconstructs the full original sequence from these varied masked contexts. The training loss is the sum of the encoder MLM loss and the decoder reconstruction loss. The model supports gradient checkpointing for memory efficiency. Because it is a plain nn.Module, saving and loading go through its own save_pretrained and from_pretrained methods, which are modeled on the HuggingFace interface.
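A minimal pure-Python sketch of the combined objective follows. It assumes an unweighted sum of two token-level cross-entropies that skip labels equal to -100, the default ignore_index of PyTorch's CrossEntropyLoss. Logits are plain lists here, and cross_entropy and retromae_loss are illustrative names, not the module's API.

```python
import math

IGNORE_INDEX = -100  # same default as torch.nn.CrossEntropyLoss(ignore_index=...)


def cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over positions whose label is not ignore_index.

    logits: list of per-position score lists over the vocabulary
    labels: list of target ids, one per position
    Returns 0.0 when every label is ignored (a choice made for this sketch).
    """
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))  # stable log-sum-exp
        total += log_z - row[label]  # -log softmax(row)[label]
        count += 1
    return total / count if count else 0.0


def retromae_loss(enc_logits, enc_labels, dec_logits, dec_labels):
    """Assumed objective: unweighted sum of encoder MLM and decoder reconstruction losses."""
    return cross_entropy(enc_logits, enc_labels) + cross_entropy(dec_logits, dec_labels)


# Uniform logits over a 4-token vocabulary give log(4) per labeled position.
uniform = [0.0, 0.0, 0.0, 0.0]
loss = retromae_loss([uniform, uniform], [IGNORE_INDEX, 2],
                     [uniform, uniform, uniform], [1, IGNORE_INDEX, 3])
print(round(loss, 4))  # 2.7726, i.e. 2 * log(4)
```

Each loss is averaged over its own labeled positions before the two are added, so a sequence with few encoder masks still contributes a full-weight encoder term.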

Usage

Use this for pre-training BERT-like models on unlabeled text corpora with an objective specifically designed to produce high-quality embeddings for retrieval applications.

Code Reference

Source Location

Signature

class RetroMAEForPretraining(nn.Module):
    def __init__(self, bert: BertForMaskedLM, model_args: ModelArguments): ...

    # Returns a one-element tuple: (encoder MLM loss + decoder reconstruction loss,)
    def forward(self, encoder_input_ids, encoder_attention_mask, encoder_labels,
                decoder_input_ids, decoder_attention_mask, decoder_labels): ...

    def save_pretrained(self, output_dir: str): ...

    @classmethod
    def from_pretrained(cls, model_args: ModelArguments, *args, **kwargs): ...

Import

from research.baai_general_embedding.retromae_pretrain.modeling import RetroMAEForPretraining

I/O Contract

Inputs

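| Name | Type | Required | Description |
|---|---|---|---|
| bert | BertForMaskedLM | Yes | Pre-trained BERT masked-LM model used as the encoder (constructor argument) |
| model_args | ModelArguments | Yes | Model configuration arguments (constructor argument) |
| encoder_input_ids | Tensor [batch, seq_len] | Yes | Encoder input IDs with a subset of tokens masked |
| encoder_attention_mask | Tensor [batch, seq_len] | Yes | Padding attention mask for the encoder |
| encoder_labels | Tensor [batch, seq_len] | Yes | Original token IDs at the masked positions for the encoder MLM loss; -100 (ignored) elsewhere |
| decoder_input_ids | Tensor [batch, seq_len] | Yes | Original, unmasked input IDs fed to the decoder embeddings |
| decoder_attention_mask | Tensor [batch, seq_len, seq_len] | Yes | Matrix attention mask; row i marks the positions that position i may attend to |
| decoder_labels | Tensor [batch, seq_len] | Yes | Token IDs for decoder reconstruction; positions to ignore are set to -100 |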
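The encoder_input_ids / encoder_labels pair follows the usual masked-LM labeling convention. The sketch below applies that convention to a single sequence. It is a simplification and not the RetroMAECollator, which may differ, for instance by masking whole words. Here every selected token becomes [MASK] with no 80/10/10 split. The [MASK] id 103 (bert-base-uncased), the token ids, and the mask_for_encoder name are illustrative assumptions.

```python
import random

MASK_ID = 103        # [MASK] id in the bert-base-uncased vocabulary (assumption for this sketch)
IGNORE_INDEX = -100  # label value the MLM loss skips


def mask_for_encoder(token_ids, special_positions, mask_prob, rng):
    """Build (encoder_input_ids, encoder_labels) for one sequence.

    Simplified: each selected token becomes [MASK]; no 80/10/10 replacement
    and no whole-word grouping. Labels keep the original id only where the
    input was masked and hold IGNORE_INDEX everywhere else.
    """
    input_ids = list(token_ids)
    labels = [IGNORE_INDEX] * len(token_ids)
    for i, tid in enumerate(token_ids):
        if i in special_positions:
            continue  # never mask [CLS] / [SEP]
        if rng.random() < mask_prob:
            input_ids[i] = MASK_ID
            labels[i] = tid
    return input_ids, labels


ids = [101, 2023, 2003, 1037, 7099, 102]  # illustrative ids: [CLS] ... [SEP]
enc_ids, enc_labels = mask_for_encoder(ids, {0, len(ids) - 1}, 0.3, random.Random(0))
# decoder_input_ids would be the unmasked `ids`; only the encoder sees [MASK].
print(enc_ids, enc_labels)
```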

Outputs

| Name | Type | Description |
|---|---|---|
| loss | Tuple[Tensor] | One-element tuple holding the sum of the encoder MLM loss and the decoder reconstruction loss |
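The input shapes in the contract can be checked before calling forward. The helper below is a hypothetical sketch, not part of the module. It takes a mapping from input name to shape tuple so that it works with any array type and reports every mismatch against [batch, seq_len] and [batch, seq_len, seq_len]. It assumes, as the table does, that encoder and decoder inputs share the same seq_len.

```python
EXPECTED_RANK = {
    "encoder_input_ids": 2,
    "encoder_attention_mask": 2,
    "encoder_labels": 2,
    "decoder_input_ids": 2,
    "decoder_labels": 2,
    "decoder_attention_mask": 3,  # matrix mask: [batch, seq_len, seq_len]
}


def shape_problems(shapes):
    """Return human-readable mismatches against the I/O contract; [] means OK.

    `shapes` maps input name -> shape tuple.
    """
    missing = [name for name in EXPECTED_RANK if name not in shapes]
    if missing:
        return ["missing inputs: " + ", ".join(missing)]
    ref = tuple(shapes["encoder_input_ids"])
    if len(ref) != 2:
        return [f"encoder_input_ids: expected [batch, seq_len], got {ref}"]
    batch, seq_len = ref  # reference dimensions taken from the encoder input
    problems = []
    for name, rank in EXPECTED_RANK.items():
        expected = (batch, seq_len) if rank == 2 else (batch, seq_len, seq_len)
        got = tuple(shapes[name])
        if got != expected:
            problems.append(f"{name}: expected {expected}, got {got}")
    return problems


good = {name: (2, 5) for name in EXPECTED_RANK}
good["decoder_attention_mask"] = (2, 5, 5)
print(shape_problems(good))  # []
```

With a collator batch of tensors, the mapping can be built as {k: tuple(v.shape) for k, v in batch.items()}.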

Usage Examples

from transformers import BertForMaskedLM, AutoTokenizer
from research.baai_general_embedding.retromae_pretrain.modeling import RetroMAEForPretraining
from research.baai_general_embedding.retromae_pretrain.arguments import ModelArguments
from research.baai_general_embedding.retromae_pretrain.data import RetroMAECollator

# Initialize model
model_args = ModelArguments()
bert = BertForMaskedLM.from_pretrained("bert-base-uncased")
model = RetroMAEForPretraining(bert, model_args)

# Prepare data
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = RetroMAECollator(tokenizer=tokenizer)
texts = ["This is a sample text for pre-training."]  # ... more raw texts
batch = collator(texts)

# Forward pass
loss = model(
    encoder_input_ids=batch["encoder_input_ids"],
    encoder_attention_mask=batch["encoder_attention_mask"],
    encoder_labels=batch["encoder_labels"],
    decoder_input_ids=batch["decoder_input_ids"],
    decoder_attention_mask=batch["decoder_attention_mask"],
    decoder_labels=batch["decoder_labels"]
)

# Training: forward() returns a one-element tuple, so index the loss
loss[0].backward()

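# Save model (the encoder weights go to output/retromae_model/encoder_model)
model.save_pretrained("output/retromae_model")

# Load model: from_pretrained passes its extra arguments to the masked-LM loader,
# so point it at the saved encoder directory
model = RetroMAEForPretraining.from_pretrained(
    model_args,
    "output/retromae_model/encoder_model"
)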
