Implementation:FlagOpen FlagEmbedding BGE M3 Data
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Data Loading, Multi-Modal Retrieval, Information Retrieval |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A specialized data loading system for BGE-M3 model training that implements dynamic batch sizing based on document length to optimize GPU memory usage.
Description
This module implements SameDatasetTrainDataset, a PyTorch Dataset class for training the BGE-M3 (Multi-Functionality, Multi-Linguality, Multi-Granularity) embedding model. Its central feature is dynamic batch sizing: the batch size is adjusted to document length so that GPU memory is used fully without out-of-memory errors.
The dataset ensures that all samples in a batch come from the same task, which keeps training stable. It also supports knowledge distillation with teacher scores, parallel corpora with a bidirectional loss, configurable text shuffling for data augmentation, automatic merging and filtering of small datasets, and instruction prefixes for queries and passages.
The module builds on HuggingFace's datasets library and its caching, shuffles deterministically under seed control, and distributes batches across GPUs for distributed training.
Usage
Use this dataset class when training the BGE-M3 model to automatically handle varying document lengths, batch construction, and distributed training scenarios with optimal memory efficiency.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/BGE_M3/data.py
- Lines: 1-301
Signature
```python
class SameDatasetTrainDataset(Dataset):
    def __init__(
        self,
        args: DataArguments,
        batch_size: int,
        seed: int,
        process_index: int = 0,
        num_processes: int = 1,
    ): ...

    def __getitem__(self, _): ...
    def __len__(self) -> int: ...
    def refresh_epoch(self): ...
    def create_batch_data(self, batch_raw_data): ...
    def shuffle_text(self, text: str) -> str: ...

    @staticmethod
    def get_file_batch_size(file: str, batch_size: int, train_group_size: int) -> int: ...


@dataclass
class EmbedCollator(DataCollatorWithPadding):
    query_max_len: int = 32
    passage_max_len: int = 128

    def __call__(self, features): ...
```
Import
```python
from data import SameDatasetTrainDataset, EmbedCollator
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| args | DataArguments | Yes | Training data arguments with paths and configurations |
| batch_size | int | Yes | Base batch size, used when no length-specific value from the Dynamic Batch Sizing tables applies |
| seed | int | Yes | Random seed for deterministic shuffling |
| process_index | int | No | Index of current GPU process (default: 0) |
| num_processes | int | No | Total number of GPU processes (default: 1) |
Outputs
Each __getitem__ call returns the tuple (queries, passages, teacher_scores, pqloss_flag):
| Name | Type | Description |
|---|---|---|
| queries | List[str] | Query strings (with optional instruction prefix) |
| passages | List[str] | train_group_size passages per query: 1 positive and train_group_size - 1 negatives |
| teacher_scores | Optional[List[float]] | One teacher score per passage for knowledge distillation, or None; EmbedCollator converts them to a torch.FloatTensor |
| pqloss_flag | bool | Whether to use bidirectional (parallel corpus) loss |
Data Format
Input JSONL Format
Each line of a training file is one JSON object of the following form (pretty-printed here for readability):
```json
{
  "query": "What is machine learning?",
  "pos": ["Machine learning is a subset of AI...", "ML uses algorithms..."],
  "neg": ["Deep learning is...", "Natural language processing..."]
}
```
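Or, with knowledge distillation scores, where pos_scores and neg_scores hold one teacher score per entry of pos and neg:
```json
{
  "query": "What is machine learning?",
  "pos": ["Machine learning is..."],
  "neg": ["Deep learning is...", "Natural language processing...", "Computer vision is..."],
  "pos_scores": [0.95],
  "neg_scores": [0.3, 0.25, 0.15]
}
```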
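Because a malformed line can fail deep inside training, a quick pre-check of the files can help. The sketch below uses a hypothetical helper, record_problems (not part of data.py), and assumes that pos_scores and neg_scores, when present, are aligned index-by-index with pos and neg.

```python
import json
from typing import List


def record_problems(line: str, require_scores: bool = False) -> List[str]:
    """Return the problems found in one JSONL training line (empty list if none)."""
    try:
        record = json.loads(line)
    except json.JSONDecodeError as exc:
        return [f"not valid JSON: {exc}"]
    if not isinstance(record, dict):
        return ["line is not a JSON object"]
    problems = []
    if not isinstance(record.get("query"), str):
        problems.append("'query' must be a string")
    for key in ("pos", "neg"):
        value = record.get(key)
        if not isinstance(value, list) or not all(isinstance(t, str) for t in value):
            problems.append(f"'{key}' must be a list of strings")
    if not record.get("pos"):
        problems.append("'pos' needs at least one passage")
    # Assumption: each score list is aligned index-by-index with its passage list.
    for texts, scores in (("pos", "pos_scores"), ("neg", "neg_scores")):
        if scores in record:
            paired = record.get(texts)
            if (not isinstance(record[scores], list) or not isinstance(paired, list)
                    or len(record[scores]) != len(paired)):
                problems.append(f"'{scores}' length does not match '{texts}'")
        elif require_scores:
            problems.append(f"missing '{scores}'")
    return problems
```

An empty result only means the line matches the layout above; the check does not verify that train_group_size - 1 negatives can be drawn from neg.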
Output Batch Format
```python
{
    "query": {
        "input_ids": torch.Tensor,       # (batch_size, L_q), L_q <= query_max_len
        "attention_mask": torch.Tensor,  # (batch_size, L_q)
    },
    "passage": {
        "input_ids": torch.Tensor,       # (batch_size * train_group_size, L_p), L_p <= passage_max_len
        "attention_mask": torch.Tensor,  # (batch_size * train_group_size, L_p)
    },
    "teacher_scores": Optional[torch.FloatTensor],  # (batch_size, train_group_size), or None
    "bi_directions": bool,  # True for parallel-corpus batches
}
```
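The (batch_size, train_group_size) shape of teacher_scores implies one row of scores per query. A minimal plain-Python sketch, assuming the dataset yields the scores as one flat list in passage order (each query's positive first, then its negatives) and that the collator only regroups them; reshape_teacher_scores is a hypothetical name, and the real collator works on tensors.

```python
from typing import List, Optional


def reshape_teacher_scores(flat: Optional[List[float]],
                           num_queries: int) -> Optional[List[List[float]]]:
    """Regroup flat per-passage scores into one row of group_size scores per query."""
    if flat is None:
        return None  # no distillation scores: hard labels only
    if num_queries <= 0 or len(flat) % num_queries != 0:
        raise ValueError("score count must be a multiple of the number of queries")
    group_size = len(flat) // num_queries
    # Row i holds query i's positive score followed by its negatives' scores.
    return [flat[i * group_size:(i + 1) * group_size] for i in range(num_queries)]
```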
Dynamic Batch Sizing
The batch size is chosen per file from the length bucket in its name (e.g., len-0-500.jsonl), so files of longer documents get smaller batches:
Batch Size Table (train_group_size=8, 80GB GPU)
| Document Length | Batch Size | Notes |
|---|---|---|
| 0-500 tokens | 48 | High throughput for short documents |
| 500-1000 tokens | 32 | Balanced processing |
| 1000-2000 tokens | 20 | Medium documents |
| 2000-3000 tokens | 18 | Longer documents |
| 3000-4000 tokens | 14 | Large documents |
| 4000-5000 tokens | 14 | Large documents |
| 5000-6000 tokens | 12 | Very large documents |
| 6000-7000 tokens | 10 | Near max length |
| 7000+ tokens | 8 | Maximum length documents |
Batch Size Table (train_group_size=1, 80GB GPU)
| Document Length | Batch Size |
|---|---|
| 0-500 tokens | 700 |
| 500-1000 tokens | 570 |
| 1000-2000 tokens | 388 |
| 2000-3000 tokens | 288 |
| 3000-4000 tokens | 224 |
| 4000-5000 tokens | 180 |
| 5000-6000 tokens | 157 |
| 6000-7000 tokens | 128 |
| 7000+ tokens | 104 |
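The two tables can be read as a lookup keyed on the length bucket in a file's name. A sketch under stated assumptions: lookup_batch_size is a hypothetical stand-in for get_file_batch_size, the 7000+ bucket is assumed to be named len-7000-inf.jsonl, and any file name or train_group_size not covered by the tables is assumed to fall back to the base batch_size.

```python
import os

# Values copied from the two tables above (80GB GPU).
# Assumption: the 7000+ bucket file is called "len-7000-inf.jsonl".
BATCH_SIZE_TABLE = {
    8: {"0-500": 48, "500-1000": 32, "1000-2000": 20, "2000-3000": 18,
        "3000-4000": 14, "4000-5000": 14, "5000-6000": 12, "6000-7000": 10,
        "7000-inf": 8},
    1: {"0-500": 700, "500-1000": 570, "1000-2000": 388, "2000-3000": 288,
        "3000-4000": 224, "4000-5000": 180, "5000-6000": 157, "6000-7000": 128,
        "7000-inf": 104},
}


def lookup_batch_size(file: str, batch_size: int, train_group_size: int) -> int:
    """Return the table batch size for a len-<bucket>.jsonl file, else the base size."""
    name = os.path.basename(file)
    table = BATCH_SIZE_TABLE.get(train_group_size)
    if table is None or not (name.startswith("len-") and name.endswith(".jsonl")):
        return batch_size  # assumption: fall back to the base batch size
    bucket = name[len("len-"):-len(".jsonl")]  # e.g. "0-500"
    return table.get(bucket, batch_size)
```

Keying on the exact bucket string avoids accidental substring matches between bucket names.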
Dataset Organization
Hierarchical Structure
1. Dataset level: multiple task directories listed in train_data
2. File level: JSONL files within each directory, named by length bucket (e.g., len-0-500.jsonl)
3. Batch level: samples from the same file (or the same merged group of small files) form one batch
4. Sample level: individual records, each a query with its positive and negative passages
Small Dataset Handling
- Files with < SMALL_THRESHOLD samples are combined
- Combined dataset must have >= DROP_THRESHOLD samples
- Otherwise dropped to avoid training instability
- Default: SMALL_THRESHOLD = 1000, DROP_THRESHOLD = 100
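The merge-or-drop rule can be sketched with file sizes alone. This assumes the rule is applied separately within each task directory and that empty files are skipped (see Error Handling); plan_datasets is a hypothetical helper that only plans the grouping and loads no data.

```python
from typing import Dict, List, Tuple


def plan_datasets(file_sizes: Dict[str, int], small_threshold: int = 1000,
                  drop_threshold: int = 100) -> Tuple[List[List[str]], List[str]]:
    """Group one directory's files into training datasets.

    Returns (datasets, dropped): each entry of datasets is a list of files
    trained as one dataset; dropped lists the files that are discarded.
    """
    kept: List[List[str]] = []
    small: List[str] = []
    for name, size in sorted(file_sizes.items()):
        if size == 0:
            continue  # empty files are skipped
        if size < small_threshold:
            small.append(name)  # candidate for merging
        else:
            kept.append([name])  # large enough to stand alone
    dropped: List[str] = []
    if small:
        if sum(file_sizes[name] for name in small) >= drop_threshold:
            kept.append(small)  # merged small files become one dataset
        else:
            dropped = small  # too few samples even after merging
    return kept, dropped
```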
Dataset Shuffling
1. Shuffle the order of datasets (task types)
2. Shuffle samples within each dataset
3. Create batches with task-specific batch sizes
4. Shuffle batches globally
5. Repeat every epoch (refresh_epoch) from a generator seeded with seed, so the order is reproducible
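Steps 1 to 4 can be sketched with NumPy. The sketch assumes that rows are addressed by global indices (as if all datasets were concatenated), that one global batch holds the per-GPU batch size * num_processes indices, and that a tail with fewer than 2 * num_processes rows is dropped, matching the Batch Filtering rule under Memory Optimization; build_epoch_batches is a hypothetical name.

```python
from typing import List, Tuple

import numpy as np


def build_epoch_batches(
    dataset_sizes: List[int],
    per_gpu_batch_sizes: List[int],
    num_processes: int,
    rng: np.random.Generator,
) -> List[Tuple[np.ndarray, int]]:
    """Return one epoch of (global_row_indices, dataset_id) batches."""
    # Give each dataset a contiguous block of row ids, as if all were concatenated.
    offsets = np.concatenate([[0], np.cumsum(dataset_sizes)[:-1]]).astype(int)
    order = np.arange(len(dataset_sizes))
    rng.shuffle(order)  # step 1: shuffle the order of datasets
    batches = []
    for d in order:
        rows = np.arange(dataset_sizes[d]) + offsets[d]
        rng.shuffle(rows)  # step 2: shuffle samples within the dataset
        step = per_gpu_batch_sizes[d] * num_processes  # one global batch covers all GPUs
        for start in range(0, len(rows), step):  # step 3: cut into batches
            if len(rows) - start < 2 * num_processes:
                break  # tail too small to split across GPUs
            batches.append((rows[start:start + step], int(d)))
    perm = rng.permutation(len(batches))  # step 4: shuffle batches globally
    return [batches[i] for i in perm]
```

For step 5, create the generator once with np.random.default_rng(seed) and pass the same object every epoch: each epoch then gets a new order, while the whole sequence of epochs stays reproducible for a given seed.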
Special Features
Parallel Corpus Support
- Directories with "parallel_" prefix are marked as parallel corpora
- These use bidirectional loss (pqloss_flag=True)
- Used for multilingual parallel text training
Text Shuffling Augmentation
Configurable text chunking and shuffling:
- Applied only to passages longer than 100 characters
- Applied with probability shuffle_ratio
- Splits the text into 3 chunks and shuffles their order
- Intended to help the model learn position-invariant representations
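A minimal sketch of the augmentation described above. The chunk size (len(text) // 3 + 1 characters) and rejoining the chunks with a single space are assumptions, and an explicit random.Random is passed in so the result is reproducible; this shuffle_text is a standalone function, not the class method.

```python
import random


def shuffle_text(text: str, shuffle_ratio: float, rng: random.Random) -> str:
    """With probability shuffle_ratio, cut a long passage into 3 chunks and reorder them."""
    if shuffle_ratio <= 0 or len(text) <= 100 or rng.random() >= shuffle_ratio:
        return text  # short passages and unlucky draws are left unchanged
    chunk_size = len(text) // 3 + 1  # assumption: near-equal character chunks
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    rng.shuffle(chunks)
    return " ".join(chunks)  # assumption: chunks rejoined with a single space
```

Chunks are cut on character positions, so a chunk boundary can fall inside a word.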
Knowledge Distillation
- Supports pos_scores and neg_scores fields
- Sampled negatives are reordered by descending teacher score
- Enables soft-label distillation during training
- Optional: falls back to hard labels if not present
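A sketch of how one query's group of train_group_size passages might be assembled, including the descending reorder of negatives by teacher score. Assumptions not confirmed by the text above: the positive is drawn uniformly at random, a negative pool smaller than train_group_size - 1 is repeated before sampling, and scores come back flattened with the positive first; build_group is a hypothetical name.

```python
import math
import random
from typing import Dict, List, Optional, Tuple


def build_group(record: Dict, train_group_size: int,
                rng: random.Random) -> Tuple[List[str], Optional[List[float]]]:
    """Pick 1 positive and train_group_size - 1 negatives for one query."""
    pos_idx = rng.randrange(len(record["pos"]))
    num_negs = train_group_size - 1
    neg_ids = list(range(len(record["neg"])))
    if num_negs > 0 and not neg_ids:
        raise ValueError("record has no negatives")
    if len(neg_ids) < num_negs:
        # Assumption: repeat the negative pool when it is too small, then sample.
        neg_ids = neg_ids * math.ceil(num_negs / len(neg_ids))
    chosen = rng.sample(neg_ids, num_negs)
    scores = None
    if "pos_scores" in record and "neg_scores" in record:
        # Reorder the sampled negatives by descending teacher score.
        chosen.sort(key=lambda i: record["neg_scores"][i], reverse=True)
        scores = [record["pos_scores"][pos_idx]] + [record["neg_scores"][i] for i in chosen]
    passages = [record["pos"][pos_idx]] + [record["neg"][i] for i in chosen]
    return passages, scores  # scores is None -> hard labels only
```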
Instruction Augmentation
- Optional query_instruction_for_retrieval prefix
- Optional passage_instruction_for_retrieval prefix
- Helps with task-specific fine-tuning
Usage Examples
```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from data import SameDatasetTrainDataset, EmbedCollator, DataArguments

# Define data arguments
data_args = DataArguments(
    train_data=[
        "/data/retrieval/msmarco",
        "/data/retrieval/nq",
        "/data/retrieval/parallel_en_zh",  # Parallel corpus
    ],
    train_group_size=8,
    query_instruction_for_retrieval="Represent this query for retrieval: ",
    passage_instruction_for_retrieval="",
    knowledge_distillation=True,
    shuffle_ratio=0.1,
    small_threshold=1000,
    drop_threshold=100,
    max_example_num_per_dataset=100000,
    cache_path="/tmp/cache",
)

# Create the dataset for one process. In a real multi-GPU run these values
# would come from torch.distributed.get_rank() / get_world_size().
dataset = SameDatasetTrainDataset(
    args=data_args,
    batch_size=32,
    seed=42,
    process_index=0,  # GPU 0
    num_processes=8,  # 8 GPUs total
)
print(f"Dataset size: {len(dataset)} batches")

# Inspect one batch. __getitem__ ignores its index argument: every call
# returns the next batch of the epoch and advances the dataset's position.
queries, passages, teacher_scores, pqloss_flag = dataset[0]
print(f"Queries: {len(queries)}")
print(f"Passages: {len(passages)}")
print(f"Teacher scores: {len(teacher_scores) if teacher_scores is not None else None}")
print(f"Parallel corpus: {pqloss_flag}")
dataset.refresh_epoch()  # reshuffle and restart from the first batch

# Create data collator
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
collator = EmbedCollator(
    tokenizer=tokenizer,
    query_max_len=512,
    passage_max_len=8192,
)

# Create DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=1,  # each dataset item is already a full batch
    collate_fn=collator,
    num_workers=0,  # __getitem__ is stateful, so worker copies would repeat batches
)

# Training loop
for batch in dataloader:
    queries = batch["query"]      # tokenized queries (input_ids, attention_mask)
    passages = batch["passage"]   # tokenized passages
    teacher_scores = batch["teacher_scores"]
    bi_directions = batch["bi_directions"]
    # Forward pass through model
    # model(queries, passages, teacher_scores, bi_directions)

# Refresh data for next epoch
dataset.refresh_epoch()
```
Configuration Parameters
DataArguments Fields
- train_data: List[str] - Paths to training data directories
- train_group_size: int - Number of passages per query (1 positive + train_group_size - 1 negatives)
- query_instruction_for_retrieval: Optional[str] - Query prefix
- passage_instruction_for_retrieval: Optional[str] - Passage prefix
- knowledge_distillation: bool - Use teacher scores
- shuffle_ratio: float - Probability of text shuffling (0.0-1.0)
- small_threshold: int - Threshold for combining small datasets
- drop_threshold: int - Minimum size for combined datasets
- max_example_num_per_dataset: Optional[int] - Sample limit per dataset
- cache_path: Optional[str] - HuggingFace datasets cache directory
EmbedCollator Parameters
- tokenizer: PreTrainedTokenizer - Tokenizer for text encoding
- query_max_len: int - Maximum query length (default: 32)
- passage_max_len: int - Maximum passage length (default: 128)
- padding: bool - Enable padding (default: True)
Distributed Training Support
The dataset handles multi-GPU training:
- Each GPU gets different subset of batches
- process_index determines which batches to return
- num_processes determines total GPU count
- Batch division: batch_indices[process_index * size : (process_index + 1) * size], where size = len(batch_indices) // num_processes
- Ensures no data overlap between GPUs
- All ranks shuffle with the same seed, so they agree on the global batch order
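The slicing rule above, as a standalone sketch. It assumes size = len(batch_indices) // num_processes, so any remainder is left unused, and that every rank builds the same global batch list from the same seed; shard_batch is a hypothetical name.

```python
from typing import List, Sequence


def shard_batch(global_batch: Sequence[int], process_index: int,
                num_processes: int) -> List[int]:
    """Return the slice of one global batch that a given rank trains on."""
    size = len(global_batch) // num_processes  # per-GPU share; remainder unused
    return list(global_batch[process_index * size:(process_index + 1) * size])
```

Because the slices are disjoint and every rank sees the same global batch, no sample is trained twice within that batch.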
Memory Optimization
Techniques Used
1. Dynamic batching: adjust batch size to document length
2. Lazy loading: load data only when a batch needs it
3. Efficient caching: use HuggingFace datasets caching
4. Batch filtering: drop a dataset's last batch if fewer than 2 * num_processes samples remain
5. Small-dataset merging: combine small files into one dataset to reduce overhead
Memory Footprint
- Dataset metadata: O(number of datasets)
- Batch indices: O(number of batches)
- Actual data: Loaded on-demand per batch
- Cache: Arrow files on disk under cache_path, memory-mapped rather than held in RAM
Error Handling
- Validates train_data directories exist
- Handles missing pos_scores/neg_scores gracefully
- Skips empty datasets automatically
- Drops datasets below drop_threshold
- Robust to malformed JSONL entries
Performance Considerations
- Deterministic shuffling ensures reproducibility
- Efficient indexing with NumPy arrays
- Minimal memory overhead with lazy loading
- Optimized for large-scale training (100M+ examples)
- Compatible with HuggingFace Trainer