Implementation:FlagOpen FlagEmbedding LLARA Pretrain Data
| Knowledge Sources | |
|---|---|
| Domains | LLM Pretraining, Embedding, Data Processing |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Dataset and collator for pretraining LLARA embeddings with dual-task summarization and prediction objectives.
Description
This module provides the data-loading layer for pretraining LLARA (LLM adapted for dense RetrievAl) with a dual-task objective. TrainDatasetForEmbedding loads each input passage together with two targets (summarize and predict) and formats the input with a separate block of special suffix tokens per task: `<s1>`-`<s8>` for summarization and `<s9>`-`<s16>` for prediction. Stopwords can optionally be removed from the targets using NLTK's stopword list. EmbedCollator tokenizes and pads each batch, masks labels for the language-modeling loss, and encodes the two targets separately, so each objective gets its own token-ID tensor to predict from the special-token embeddings. Together, the two objectives train the embedding to capture both a backward-looking view of the passage (summarization) and a forward-looking one (prediction).
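The special-token layout described above can be pictured with plain strings. This is a minimal sketch that assumes both suffix blocks are appended to the same passage in a single sequence (consistent with `__getitem__` returning one input dict); the helper name `build_pretrain_input` is hypothetical, and the real data.py may wrap the passage in additional instruction text.

```python
# Suffix tokens named in the description: <s1>-<s8> (summarize), <s9>-<s16> (predict).
SUMMARIZE_TOKENS = [f"<s{i}>" for i in range(1, 9)]
PREDICT_TOKENS = [f"<s{i}>" for i in range(9, 17)]


def build_pretrain_input(passage: str) -> str:
    """Hypothetical helper: append both suffix blocks to one passage.

    Assumes a single sequence carries both tasks; the real prompt in
    data.py may add instruction text around the passage.
    """
    return passage + "".join(SUMMARIZE_TOKENS) + "".join(PREDICT_TOKENS)


example = build_pretrain_input("Machine learning is a subset of AI.")
print(example)
```

Keeping the two token blocks distinct is what lets a model read one embedding per task from the same forward pass.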
Usage
Use this module when adapting a pretrained LLM into LLARA or a similar LLM-based embedding model, when implementing dual-task objectives to improve representation quality, or when preparing input-target pairs from which the model learns both to summarize and to predict. It is intended for the pretraining phase that precedes fine-tuning.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/LLARA/pretrain/data.py
- Lines: 1-171
Signature
```python
class TrainDatasetForEmbedding(Dataset):
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        args: DataArguments,
    ):
        ...

    def __getitem__(self, item):
        """Returns (input_dict, output_summarize, output_predict)"""


@dataclass
class EmbedCollator(DataCollatorForSeq2Seq):
    cutoff_len: int = 512

    def __call__(self, features, return_tensors='pt'):
        """Returns dict with input_ids, labels, and output token IDs"""
```
Import
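```python
# data.py lives in research/LLARA/pretrain; run from that directory (or add it to sys.path).
from data import TrainDatasetForEmbedding, EmbedCollator
```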
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| train_data | str | Yes | Path to training JSONL file or directory |
| tokenizer | PreTrainedTokenizer | Yes | Tokenizer for the LLM |
| cutoff_len | int | Yes | Maximum sequence length |
| remove_stop_words | bool | No | Remove stopwords from outputs (default: False) |
| max_example_num_per_dataset | int | No | Max examples per file |
| cache_path | str | No | Cache directory for datasets |
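How train_data (a single file or a directory) interacts with max_example_num_per_dataset can be sketched with the standard library alone. The sketch assumes the cap is applied per file by random sampling; `load_pretrain_records` is a hypothetical name, and the actual module may load files through Hugging Face `datasets` and use cache_path, which this sketch ignores.

```python
import json
import os
import random
import tempfile


def load_pretrain_records(train_data, max_example_num_per_dataset=None, seed=0):
    """Hypothetical loader: read one JSONL file, or every file in a directory,
    keeping at most max_example_num_per_dataset randomly sampled rows per file."""
    if os.path.isdir(train_data):
        paths = [os.path.join(train_data, name) for name in sorted(os.listdir(train_data))]
    else:
        paths = [train_data]
    rng = random.Random(seed)
    records = []
    for path in paths:
        with open(path, encoding="utf-8") as f:
            rows = [json.loads(line) for line in f if line.strip()]
        if max_example_num_per_dataset is not None and len(rows) > max_example_num_per_dataset:
            rows = rng.sample(rows, max_example_num_per_dataset)
        records.extend(rows)
    return records


# Usage: a directory with a 5-row file and a 2-row file, capped at 3 rows per file.
with tempfile.TemporaryDirectory() as tmp:
    for name, n in [("a.jsonl", 5), ("b.jsonl", 2)]:
        with open(os.path.join(tmp, name), "w", encoding="utf-8") as f:
            for i in range(n):
                f.write(json.dumps({"input": f"{name}-{i}"}) + "\n")
    demo = load_pretrain_records(tmp, max_example_num_per_dataset=3)

print(len(demo))  # 3 (capped) from a.jsonl + 2 from b.jsonl
```

A per-file cap keeps one large source from dominating the mixture when several datasets share a directory.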
Outputs
| Name | Type | Description |
|---|---|---|
| input_ids | Tensor | Tokenized input sequences |
| attention_mask | Tensor | Attention masks |
| labels | Tensor | Labels for the language-modeling loss; -100 marks positions the loss ignores (e.g. special tokens) |
| output_summarize_ids | Tensor | Tokenized summarization outputs |
| output_predict_ids | Tensor | Tokenized prediction outputs |
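A minimal sketch of the padding and label masking behind input_ids, attention_mask, and labels, using plain lists instead of tensors. It assumes right padding and that both padding positions and a caller-supplied set of special-token IDs receive -100, the default ignore_index of PyTorch's cross-entropy loss; `pad_and_mask` and `special_ids` are hypothetical names, not the EmbedCollator API.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_and_mask(batch_ids, pad_id, cutoff_len=512, special_ids=frozenset()):
    """Hypothetical sketch: truncate each sequence to cutoff_len, right-pad to
    the longest one, and mask padding and special tokens out of the labels."""
    batch_ids = [ids[:cutoff_len] for ids in batch_ids]
    width = max(len(ids) for ids in batch_ids)
    input_ids, attention_mask, labels = [], [], []
    for ids in batch_ids:
        pad = width - len(ids)
        input_ids.append(ids + [pad_id] * pad)
        attention_mask.append([1] * len(ids) + [0] * pad)
        labels.append(
            [IGNORE_INDEX if tok in special_ids else tok for tok in ids]
            + [IGNORE_INDEX] * pad
        )
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


batch = pad_and_mask([[5, 6, 7], [8]], pad_id=0, special_ids={7})
print(batch["labels"])  # [[5, 6, -100], [8, -100, -100]]
```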
Usage Examples
```python
# Example 1: Create pretraining dataset
from dataclasses import dataclass

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from data import TrainDatasetForEmbedding, EmbedCollator


# Local stand-in for the pretraining script's DataArguments; add fields if
# data.py reads attributes not listed here.
@dataclass
class DataArguments:
    train_data: str = "./pretrain_data"
    cutoff_len: int = 512
    remove_stop_words: bool = False  # matches the documented default
    max_example_num_per_dataset: int = 100000
    cache_path: str = ".cache"


# Initialize
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Llama-2 tokenizers define no pad token, but the collator must pad batches.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token
args = DataArguments()
dataset = TrainDatasetForEmbedding(tokenizer=tokenizer, args=args)
collator = EmbedCollator(tokenizer=tokenizer, cutoff_len=args.cutoff_len)
dataloader = DataLoader(
    dataset,
    batch_size=8,
    collate_fn=collator,
    shuffle=True,
)

# Inspect one batch
for batch in dataloader:
    print(f"Input IDs: {batch['input_ids'].shape}")
    print(f"Labels: {batch['labels'].shape}")
    print(f"Summarize outputs: {batch['output_summarize_ids'].shape}")
    print(f"Predict outputs: {batch['output_predict_ids'].shape}")
    break
```
```python
# Example 2: Data format
# Training data is JSONL: one complete JSON object per line, for example
# {"input": "Machine learning is a subset of artificial intelligence...", "output_summarize": "machine learning AI subset", "output_predict": "algorithms data patterns learn"}
```
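Because each line must be a complete JSON object, checking records before training can catch malformed files early. This hypothetical validator assumes the three field names from Example 2; it is not part of data.py.

```python
import json

# Field names taken from the Example 2 record format.
REQUIRED_KEYS = ("input", "output_summarize", "output_predict")


def parse_record(line: str) -> dict:
    """Hypothetical validator: parse one JSONL line and check its fields."""
    record = json.loads(line)
    missing = [key for key in REQUIRED_KEYS if key not in record]
    if missing:
        raise KeyError(f"record is missing fields: {missing}")
    return record


# json.dumps without indent writes the object on a single line, as JSONL requires.
good_line = json.dumps({
    "input": "Machine learning is a subset of artificial intelligence...",
    "output_summarize": "machine learning AI subset",
    "output_predict": "algorithms data patterns learn",
})
record = parse_record(good_line)

try:
    parse_record('{"input": "no targets here"}')
    missing_detected = False
except KeyError:
    missing_detected = True
```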
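```python
# Example 3: With stopword removal
# Stopword filtering is configured on the dataset through DataArguments; the
# collator takes no extra option, so the one from Example 1 is reused.
import nltk

nltk.download("stopwords")  # needed once if NLTK's stopword corpus is not installed

dataset_filtered = TrainDatasetForEmbedding(
    tokenizer=tokenizer,
    args=DataArguments(remove_stop_words=True),
)
dataloader_filtered = DataLoader(
    dataset_filtered, batch_size=8, collate_fn=collator, shuffle=True
)
```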
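A whitespace-level sketch of what stopword removal does to a target string. The small STOP_WORDS set is a stand-in for NLTK's English list (`nltk.corpus.stopwords.words("english")`), and splitting on whitespace is an assumption; the module may filter at a different granularity.

```python
# Stand-in for nltk.corpus.stopwords.words("english"), which requires
# nltk.download("stopwords") and contains many more words.
STOP_WORDS = {"a", "an", "the", "is", "of", "and", "to", "in", "from"}


def remove_stop_words(text: str, stop_words=STOP_WORDS) -> str:
    """Drop stopwords (case-insensitively) from a whitespace-split target."""
    return " ".join(word for word in text.split() if word.lower() not in stop_words)


filtered = remove_stop_words("The model learns patterns from the data")
print(filtered)  # model learns patterns data
```

Filtering targets this way leaves the content words, which are the more informative ones for the summarize and predict objectives to recover.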
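```python
# Example 4: Training loop integration
import torch

# `model` must be your LLARA model: an nn.Module whose forward() accepts the
# keyword arguments below and returns an object with a `.loss` attribute.
# A bare nn.Module() has no parameters and no forward(), so it cannot stand in here.
model = ...  # replace with your LLARA model instance
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(3):
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        # Forward pass with dual objectives; keyword names must match
        # your model's forward() signature.
        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'],
            output_summarize=batch['output_summarize_ids'],
            output_predict=batch['output_predict_ids'],
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```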