Implementation:FlagOpen FlagEmbedding BGE M3 Data

From Leeroopedia


Knowledge Sources
Domains Machine Learning, Data Loading, Multi-Modal Retrieval, Information Retrieval
Last Updated 2026-02-09 00:00 GMT

Overview

A specialized data loading system for BGE-M3 model training that implements dynamic batch sizing based on document length to optimize GPU memory usage.

Description

This module implements SameDatasetTrainDataset, a PyTorch Dataset class for training the BGE-M3 (Multi-Functionality, Multi-Linguality, Multi-Granularity) embedding model. Its distinguishing feature is that the batch size is adjusted to document length: short documents are processed in large batches and long documents in small ones, which keeps GPU utilization high while avoiding out-of-memory errors.

The dataset ensures that all samples in a batch come from the same task, which helps keep training stable. It also supports knowledge distillation with teacher scores, parallel corpora trained with a bidirectional loss, configurable text shuffling for data augmentation, merging and filtering of small datasets, and instruction-augmented queries and passages.

The module integrates with HuggingFace's datasets library and implements intelligent caching, deterministic shuffling with seed control, and per-GPU data distribution for distributed training.

Usage

Use this dataset class when training the BGE-M3 model to automatically handle varying document lengths, batch construction, and distributed training scenarios with optimal memory efficiency.

Code Reference

Source Location

Signature

class SameDatasetTrainDataset(Dataset):
    def __init__(
        self,
        args: DataArguments,
        batch_size: int,
        seed: int,
        process_index: int = 0,
        num_processes: int = 1
    )

    def __getitem__(self, _)
    def __len__(self) -> int
    def refresh_epoch(self)
    def create_batch_data(self, batch_raw_data)
    def shuffle_text(self, text: str) -> str

    @staticmethod
    def get_file_batch_size(file: str, batch_size: int, train_group_size: int) -> int

@dataclass
class EmbedCollator(DataCollatorWithPadding):
    query_max_len: int = 32
    passage_max_len: int = 128

    def __call__(self, features)

Import

from data import SameDatasetTrainDataset, EmbedCollator

I/O Contract

Inputs

Name | Type | Required | Description
args | DataArguments | Yes | Training data arguments with paths and configurations
batch_size | int | Yes | Base batch size (will be adjusted per document length)
seed | int | Yes | Random seed for deterministic shuffling
process_index | int | No | Index of current GPU process (default: 0)
num_processes | int | No | Total number of GPU processes (default: 1)

Outputs

Name | Type | Description
queries | List[str] | List of query strings (with optional instruction prefix)
passages | List[str] | List of passage strings (1 pos + N-1 neg per query)
teacher_scores | Optional[torch.FloatTensor] | Teacher model scores for knowledge distillation
pqloss_flag | bool | Whether to use bidirectional (parallel corpus) loss

Data Format

Input JSONL Format

Each line in training data files should be:

{
    "query": "What is machine learning?",
    "pos": ["Machine learning is a subset of AI...", "ML uses algorithms..."],
    "neg": ["Deep learning is...", "Natural language processing..."]
}

Or with knowledge distillation scores:

{
    "query": "What is machine learning?",
    "pos": ["Machine learning is..."],
    "neg": ["Deep learning is..."],
    "pos_scores": [0.95],
    "neg_scores": [0.3]
}
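As an illustration of the record layout above, the following is a minimal sketch of parsing and checking one JSONL line. The helper name parse_training_line is hypothetical and not part of the module, and the rule that each score list must have one score per text is an assumption; the actual loader may validate differently or not at all.

```python
import json


def parse_training_line(line: str) -> dict:
    """Parse one JSONL line and check the fields described above.

    Hypothetical helper for illustration only.
    """
    record = json.loads(line)
    for key in ("query", "pos", "neg"):
        if key not in record:
            raise ValueError(f"missing required field: {key}")
    # Score lists are optional; when present, assume one score per text.
    for text_key, score_key in (("pos", "pos_scores"), ("neg", "neg_scores")):
        if score_key in record and len(record[score_key]) != len(record[text_key]):
            raise ValueError(f"{score_key} must have one score per {text_key} entry")
    return record


line = (
    '{"query": "q", "pos": ["p"], "neg": ["n1", "n2"], '
    '"pos_scores": [0.9], "neg_scores": [0.3, 0.1]}'
)
record = parse_training_line(line)
print(record["query"], len(record["neg"]))
```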

Output Batch Format

{
    "query": {
        "input_ids": torch.Tensor,      # (batch_size, query_max_len)
        "attention_mask": torch.Tensor   # (batch_size, query_max_len)
    },
    "passage": {
        "input_ids": torch.Tensor,       # (batch_size * group_size, passage_max_len)
        "attention_mask": torch.Tensor   # (batch_size * group_size, passage_max_len)
    },
    "teacher_scores": Optional[torch.FloatTensor],  # (batch_size, group_size)
    "bi_directions": bool                            # Flag for parallel corpus
}

Dynamic Batch Sizing

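The batch size is chosen per file by get_file_batch_size, using the length range encoded in the file name (e.g., len-0-500.jsonl) and the configured train_group_size. The tables below list the values for an 80GB GPU.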

Batch Size Table (train_group_size=8, 80GB GPU)

Document Length | Batch Size | Memory Optimization
0-500 tokens | 48 | High throughput for short documents
500-1000 tokens | 32 | Balanced processing
1000-2000 tokens | 20 | Medium documents
2000-3000 tokens | 18 | Longer documents
3000-4000 tokens | 14 | Large documents
4000-5000 tokens | 14 | Large documents
5000-6000 tokens | 12 | Very large documents
6000-7000 tokens | 10 | Near max length
7000+ tokens | 8 | Maximum length documents

Batch Size Table (train_group_size=1, 80GB GPU)

Document Length | Batch Size
0-500 tokens | 700
500-1000 tokens | 570
1000-2000 tokens | 388
2000-3000 tokens | 288
3000-4000 tokens | 224
4000-5000 tokens | 180
5000-6000 tokens | 157
6000-7000 tokens | 128
7000+ tokens | 104
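The two tables can be expressed as a simple lookup. The sketch below is not the module's get_file_batch_size; it only mirrors its signature. The function name lookup_batch_size, the file name len-7000-inf.jsonl for the last bucket, and the fallback to the base batch_size for unmatched files or group sizes are all assumptions made for illustration.

```python
# Length-range file stems, following the len-0-500.jsonl naming pattern.
# "len-7000-inf" is an assumed name for the 7000+ bucket.
LENGTH_RANGES = [
    "len-0-500", "len-500-1000", "len-1000-2000", "len-2000-3000",
    "len-3000-4000", "len-4000-5000", "len-5000-6000", "len-6000-7000",
    "len-7000-inf",
]

# Batch sizes copied from the tables above, keyed by train_group_size.
BATCH_SIZES = {
    8: [48, 32, 20, 18, 14, 14, 12, 10, 8],
    1: [700, 570, 388, 288, 224, 180, 157, 128, 104],
}


def lookup_batch_size(file: str, batch_size: int, train_group_size: int) -> int:
    """Return the table value for a file, or the base batch_size if nothing matches."""
    sizes = BATCH_SIZES.get(train_group_size)
    if sizes is None:
        return batch_size  # assumed fallback for other group sizes
    for name, size in zip(LENGTH_RANGES, sizes):
        if f"{name}.jsonl" in file:
            return size
    return batch_size  # assumed fallback for files without a length range


print(lookup_batch_size("/data/retrieval/msmarco/len-0-500.jsonl", 32, 8))
```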

Dataset Organization

Hierarchical Structure

  1. Dataset Level: Multiple task directories under train_data paths
  2. File Level: JSONL files within each directory, named by length (e.g., len-0-500.jsonl)
  3. Batch Level: Samples from same file/task grouped into batches
  4. Sample Level: Individual query-positive-negative triplets

Small Dataset Handling

  • Files with < SMALL_THRESHOLD samples are combined
  • Combined dataset must have >= DROP_THRESHOLD samples
  • Otherwise dropped to avoid training instability
  • Default: SMALL_THRESHOLD = 1000, DROP_THRESHOLD = 100
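The combine-or-drop rule can be sketched as follows. This is a simplified illustration that works on sample counts only; the function name merge_small_datasets and its return shape are hypothetical, and the real module operates on loaded datasets rather than a dict of sizes.

```python
from typing import Dict, List, Tuple


def merge_small_datasets(
    sizes: Dict[str, int], small_threshold: int = 1000, drop_threshold: int = 100
) -> Tuple[List[str], List[str], bool]:
    """Return (files kept as-is, files merged into one set, whether the merged set is kept)."""
    large = [name for name, n in sizes.items() if n >= small_threshold]
    small = [name for name, n in sizes.items() if n < small_threshold]
    combined_size = sum(sizes[name] for name in small)
    # The combined set is kept only if it reaches drop_threshold samples.
    combined_kept = bool(small) and combined_size >= drop_threshold
    return large, small, combined_kept


large, small, kept = merge_small_datasets({"a.jsonl": 5000, "b.jsonl": 60, "c.jsonl": 70})
print(large, small, kept)
```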

Dataset Shuffling

  1. Shuffle order of datasets (task types)
  2. Shuffle samples within each dataset
  3. Create batches with task-specific batch sizes
  4. Shuffle batches globally
  5. Repeat every epoch with deterministic seed
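The shuffling steps above can be sketched with the standard library. The function build_epoch_batches is hypothetical, and deriving the per-epoch generator from seed + epoch is an assumption; the module may track epochs differently. The sketch shows that every batch stays within one dataset and that the result is reproducible for a given seed and epoch.

```python
import random
from typing import Dict, List, Tuple


def build_epoch_batches(
    dataset_sizes: Dict[str, int], batch_sizes: Dict[str, int], seed: int, epoch: int
) -> List[List[Tuple[str, int]]]:
    """Return globally shuffled batches; each batch holds (dataset, index) pairs."""
    rng = random.Random(seed + epoch)  # assumption: the epoch number offsets the seed
    names = list(dataset_sizes)
    rng.shuffle(names)  # 1. shuffle dataset order
    batches = []
    for name in names:
        indices = list(range(dataset_sizes[name]))
        rng.shuffle(indices)  # 2. shuffle samples within the dataset
        size = batch_sizes[name]
        for start in range(0, len(indices), size):  # 3. task-specific batch size
            batches.append([(name, i) for i in indices[start:start + size]])
    rng.shuffle(batches)  # 4. shuffle batches globally
    return batches


sizes = {"msmarco": 10, "nq": 6}
per_task = {"msmarco": 4, "nq": 3}
epoch0 = build_epoch_batches(sizes, per_task, seed=42, epoch=0)
print(len(epoch0))
```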

Special Features

Parallel Corpus Support

  • Directories with "parallel_" prefix are marked as parallel corpora
  • These use bidirectional loss (pqloss_flag=True)
  • Used for multilingual parallel text training

Text Shuffling Augmentation

Configurable text chunking and shuffling:

  • Only applied to passages > 100 characters
  • Controlled by shuffle_ratio parameter
  • Splits text into 3 chunks and randomly shuffles
  • Helps model learn position-invariant representations
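The augmentation described in the list above might look like the sketch below. It is a standalone function rather than the module's shuffle_text method; the explicit rng argument and the choice to rejoin chunks with a single space are assumptions made so the example is self-contained and testable.

```python
import random


def shuffle_text(text: str, shuffle_ratio: float, rng: random.Random) -> str:
    """With probability shuffle_ratio, split a long text into 3 chunks and shuffle them."""
    if shuffle_ratio > 0 and len(text) > 100 and rng.random() < shuffle_ratio:
        chunk_size = len(text) // 3 + 1  # gives 3 chunks for texts over 100 characters
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
        rng.shuffle(chunks)
        return " ".join(chunks)  # assumption: chunks are rejoined with spaces
    return text


rng = random.Random(0)
long_text = "a" * 50 + "b" * 50 + "c" * 50
shuffled = shuffle_text(long_text, 1.0, rng)
print(len(shuffled))
```

Texts of 100 characters or fewer, and any text when shuffle_ratio is 0, pass through unchanged.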

Knowledge Distillation

  • Supports pos_scores and neg_scores fields
  • Scores sorted in descending order for negatives
  • Enables soft-label distillation during training
  • Optional: falls back to hard labels if not present
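To make the score handling concrete, here is a sketch of assembling one training group (1 positive plus train_group_size - 1 negatives) with aligned teacher scores. The function build_group is hypothetical, and sampling negatives with replacement when too few are available is an assumption; it is not taken from the module.

```python
import random
from typing import List, Optional, Tuple


def build_group(
    record: dict, train_group_size: int, rng: random.Random
) -> Tuple[List[str], Optional[List[float]]]:
    """Pick 1 positive and train_group_size - 1 negatives, with aligned teacher scores."""
    has_scores = "pos_scores" in record and "neg_scores" in record
    pos_idx = rng.randrange(len(record["pos"]))
    num_neg = train_group_size - 1
    if len(record["neg"]) >= num_neg:
        neg_idx = rng.sample(range(len(record["neg"])), num_neg)
    else:
        # Assumption: sample with replacement when there are too few negatives.
        neg_idx = [rng.randrange(len(record["neg"])) for _ in range(num_neg)]
    if has_scores:
        # Order negatives by teacher score, highest first.
        neg_idx.sort(key=lambda i: record["neg_scores"][i], reverse=True)
    passages = [record["pos"][pos_idx]] + [record["neg"][i] for i in neg_idx]
    if not has_scores:
        return passages, None  # hard labels: no teacher scores available
    scores = [record["pos_scores"][pos_idx]] + [record["neg_scores"][i] for i in neg_idx]
    return passages, scores


record = {
    "query": "q", "pos": ["p"], "neg": ["n1", "n2", "n3"],
    "pos_scores": [0.95], "neg_scores": [0.15, 0.3, 0.25],
}
passages, scores = build_group(record, 4, random.Random(0))
print(passages, scores)
```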

Instruction Augmentation

  • Optional query_instruction_for_retrieval prefix
  • Optional passage_instruction_for_retrieval prefix
  • Helps with task-specific fine-tuning
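Instruction augmentation amounts to prepending a fixed string. The sketch below uses a hypothetical helper, add_instruction, and assumes that None or an empty string means no prefix is applied.

```python
from typing import List, Optional


def add_instruction(texts: List[str], instruction: Optional[str]) -> List[str]:
    """Prepend the instruction to every text; None or "" leaves the texts unchanged."""
    if not instruction:
        return list(texts)
    return [instruction + text for text in texts]


queries = add_instruction(
    ["What is machine learning?"], "Represent this query for retrieval: "
)
print(queries[0])
```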

Usage Examples

from transformers import AutoTokenizer
from data import SameDatasetTrainDataset, EmbedCollator, DataArguments

# Define data arguments
data_args = DataArguments(
    train_data=[
        "/data/retrieval/msmarco",
        "/data/retrieval/nq",
        "/data/retrieval/parallel_en_zh"  # Parallel corpus
    ],
    train_group_size=8,
    query_instruction_for_retrieval="Represent this query for retrieval: ",
    passage_instruction_for_retrieval="",
    knowledge_distillation=True,
    shuffle_ratio=0.1,
    small_threshold=1000,
    drop_threshold=100,
    max_example_num_per_dataset=100000,
    cache_path="/tmp/cache"
)

# Create dataset
dataset = SameDatasetTrainDataset(
    args=data_args,
    batch_size=32,
    seed=42,
    process_index=0,  # GPU 0
    num_processes=8   # 8 GPUs total
)

print(f"Dataset size: {len(dataset)} batches")

# Create data collator
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
collator = EmbedCollator(
    tokenizer=tokenizer,
    query_max_len=512,
    passage_max_len=8192
)

# Create DataLoader
from torch.utils.data import DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=1,  # Dataset already returns batches
    collate_fn=collator,
    num_workers=0
)

# Training loop
for batch in dataloader:
    queries = batch["query"]
    passages = batch["passage"]
    teacher_scores = batch["teacher_scores"]
    bi_directions = batch["bi_directions"]

    # Forward pass through model
    # model(queries, passages, teacher_scores, bi_directions)

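# Access an individual batch directly.
# Note: per the signature __getitem__(self, _), the index argument is not used.
batch_idx = 0
queries, passages, teacher_scores, pqloss_flag = dataset[batch_idx]

print(f"Queries: {len(queries)}")
print(f"Passages: {len(passages)}")
# Compare against None explicitly: truth-testing a multi-element tensor raises an error.
print(f"Teacher scores: {teacher_scores.shape if teacher_scores is not None else None}")
print(f"Parallel corpus: {pqloss_flag}")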

# Refresh data for next epoch
dataset.refresh_epoch()

Configuration Parameters

DataArguments Fields

  • train_data: List[str] - Paths to training data directories
  • train_group_size: int - Number of passages per query (1 pos + N-1 neg)
  • query_instruction_for_retrieval: Optional[str] - Query prefix
  • passage_instruction_for_retrieval: Optional[str] - Passage prefix
  • knowledge_distillation: bool - Use teacher scores
  • shuffle_ratio: float - Probability of text shuffling (0.0-1.0)
  • small_threshold: int - Threshold for combining small datasets
  • drop_threshold: int - Minimum size for combined datasets
  • max_example_num_per_dataset: Optional[int] - Sample limit per dataset
  • cache_path: Optional[str] - HuggingFace datasets cache directory

EmbedCollator Parameters

  • tokenizer: PreTrainedTokenizer - Tokenizer for text encoding
  • query_max_len: int - Maximum query length (default: 32)
  • passage_max_len: int - Maximum passage length (default: 128)
  • padding: bool - Enable padding (default: True)

Distributed Training Support

The dataset handles multi-GPU training:

  • Each GPU gets different subset of batches
  • process_index determines which batches to return
  • num_processes determines total GPU count
  • Batch division: batch_indices[process_index * size : (process_index + 1) * size]
  • Ensures no data overlap between GPUs
  • Maintains deterministic shuffling across GPUs
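The batch division formula above can be tried in isolation. The sketch assumes size is the global batch length divided by num_processes with integer division, so any remainder is left unassigned; slice_for_process is a hypothetical helper name.

```python
from typing import List


def slice_for_process(
    batch_indices: List[int], process_index: int, num_processes: int
) -> List[int]:
    """Return the contiguous share of a global batch that belongs to one process."""
    size = len(batch_indices) // num_processes  # per-process batch size
    return batch_indices[process_index * size:(process_index + 1) * size]


global_batch = list(range(16))
shares = [slice_for_process(global_batch, rank, 4) for rank in range(4)]
print(shares)
```

Because every process builds the same global batch from the same seed and then takes a different slice, the shares do not overlap.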

Memory Optimization

Techniques Used

  1. Dynamic Batching: Adjust batch size based on document length
  2. Lazy Loading: Load data only when needed
  3. Efficient Caching: Use HuggingFace's dataset caching
  4. Batch Filtering: Drop last batch if too small (< 2 * num_processes)
  5. Sample Deduplication: Merge small datasets to reduce overhead
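The batch filtering rule (drop the last batch when it has fewer than 2 * num_processes samples) can be sketched as below. The helper make_batches is hypothetical and works on plain index lists for illustration.

```python
from typing import List


def make_batches(indices: List[int], batch_size: int, num_processes: int) -> List[List[int]]:
    """Chunk indices into batches and drop a trailing batch that is too small to split."""
    batches = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
    # A batch smaller than 2 * num_processes would leave fewer than 2 samples per GPU.
    if batches and len(batches[-1]) < 2 * num_processes:
        batches.pop()
    return batches


print(make_batches(list(range(10)), batch_size=4, num_processes=2))
```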

Memory Footprint

  • Dataset metadata: O(number of datasets)
  • Batch indices: O(number of batches)
  • Actual data: Loaded on-demand per batch
  • Cache: Disk-based, not in RAM

Error Handling

  • Validates train_data directories exist
  • Handles missing pos_scores/neg_scores gracefully
  • Skips empty datasets automatically
  • Drops datasets below drop_threshold
  • Robust to malformed JSONL entries

Performance Considerations

  • Deterministic shuffling ensures reproducibility
  • Efficient indexing with NumPy arrays
  • Minimal memory overhead with lazy loading
  • Optimized for large-scale training (100M+ examples)
  • Compatible with HuggingFace Trainer
