Implementation:Microsoft LoRA Run MLM Flax

From Leeroopedia


Knowledge Sources
Domains: NLP, Language_Modeling, JAX_Flax
Last Updated: 2026-02-10 06:00 GMT

Overview

HuggingFace Transformers example script for fine-tuning masked language models using the JAX/Flax backend with whole word masking support.

Description

run_mlm_flax.py fine-tunes masked language models using JAX/Flax instead of PyTorch. Unlike run_mlm.py, which uses the HuggingFace Trainer API, this script implements a custom JAX training loop with explicit gradient computation, jax.pmap for multi-device parallelism, and a configurable learning rate scheduler. It uses FlaxBertForMaskedLM as its model class and includes a custom data collator, FlaxDataCollatorForLanguageModeling, implemented with NumPy operations. The script supports whole word masking for Chinese text via optional reference files (train_ref_file, validation_ref_file). TensorBoard logging is supported via flax.metrics.tensorboard.SummaryWriter. This script is part of the modified Transformers fork that Microsoft LoRA uses for its NLU experiments.

Usage

Use this script to fine-tune a masked language model with JAX/Flax on TPUs, or on GPUs supported by your JAX installation. It accepts both local files (CSV, JSON, TXT) and datasets from the HuggingFace dataset hub. The script is designed for multi-device training via jax.pmap, with data sharded automatically across devices. It is integrated with the LoRA-modified Transformers fork.

Code Reference

Source Location

Signature

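# Script entry point via __main__ block with HfArgumentParser
from dataclasses import dataclass, field
from typing import Optional

from transformers import PreTrainedTokenizerBase

# Key dataclasses: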
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default=None)
    model_type: Optional[str] = field(default=None)
    config_name: Optional[str] = field(default=None)
    tokenizer_name: Optional[str] = field(default=None)
    cache_dir: Optional[str] = field(default=None)
    use_fast_tokenizer: bool = field(default=True)

@dataclass
class DataTrainingArguments:
    dataset_name: Optional[str] = field(default=None)
    dataset_config_name: Optional[str] = field(default=None)
    train_file: Optional[str] = field(default=None)
    validation_file: Optional[str] = field(default=None)
    train_ref_file: Optional[str] = field(default=None)
    validation_ref_file: Optional[str] = field(default=None)
    overwrite_cache: bool = field(default=False)
    validation_split_percentage: Optional[int] = field(default=5)
    max_seq_length: Optional[int] = field(default=None)
    preprocessing_num_workers: Optional[int] = field(default=None)
    mlm_probability: float = field(default=0.15)
    pad_to_max_length: bool = field(default=False)

# Custom Flax data collator:
@dataclass
class FlaxDataCollatorForLanguageModeling:
    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15

# Custom helper functions (bodies omitted; the parameter list of
# create_learning_rate_scheduler is not reproduced in this listing):
# def create_learning_rate_scheduler(...)
def compute_metrics(logits, labels, weights, label_smoothing=0.0): ...
def accuracy(logits, targets, weights=None): ...
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0): ...
def training_step(optimizer, batch, dropout_rng): ...
def eval_step(params, batch): ...
def generate_batch_splits(samples_idx, batch_size): ...

Import

# Script is run directly, not imported
python examples/NLU/examples/language-modeling/run_mlm_flax.py \
    --model_name_or_path bert-base-cased \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --output_dir /tmp/test-mlm-flax

Key Components

Model Loading

The script uses FlaxBertForMaskedLM (not AutoModelForMaskedLM) to load a BERT model in the Flax framework. The model is instantiated with explicit parameters: dtype=jnp.float32, an input_shape derived from train_batch_size and config.max_position_embeddings, a seed for reproducibility, and dropout_rate=0.1. Note that the pretrained weights are hardcoded to "bert-base-cased": model_name_or_path is used only to load the config and tokenizer, not the model weights.

FlaxDataCollatorForLanguageModeling

A custom data collator class that operates with NumPy arrays instead of PyTorch tensors. It implements the standard BERT masking strategy:

  • 80% of selected tokens are replaced with the [MASK] token
  • 10% are replaced with a random token from the vocabulary
  • 10% are kept unchanged
  • Special tokens are never masked (protected by special_tokens_mask)
  • The collator pads examples using tokenizer.pad with pad_to_multiple_of support
  • Masking uses np.random.binomial for probabilistic token selection
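The masking rules above can be sketched in NumPy as follows. This is a minimal illustration, not the script's collator: the function name mask_tokens, its parameters, the seeded np.random.Generator, and the convention of setting unselected label positions to -100 (borrowed from the PyTorch collator) are all assumptions. Padding via tokenizer.pad is left out.

```python
import numpy as np


def mask_tokens(inputs, special_tokens_mask, mask_token_id, vocab_size,
                mlm_probability=0.15, rng=None):
    """Hypothetical sketch of BERT-style 80/10/10 masking on NumPy arrays."""
    rng = np.random.default_rng() if rng is None else rng
    inputs = inputs.copy()
    labels = inputs.copy()

    # Select positions for prediction; special tokens get probability 0.
    probability_matrix = np.full(labels.shape, mlm_probability)
    probability_matrix[special_tokens_mask.astype(bool)] = 0.0
    selected = rng.binomial(1, probability_matrix).astype(bool)
    labels[~selected] = -100  # assumption: -100 marks positions ignored by the loss

    # About 80% of selected positions become the mask token.
    replaced = rng.binomial(1, np.full(labels.shape, 0.8)).astype(bool) & selected
    inputs[replaced] = mask_token_id

    # Half of the rest (about 10% overall) get a random token; the remainder stay unchanged.
    randomized = rng.binomial(1, np.full(labels.shape, 0.5)).astype(bool) & selected & ~replaced
    random_tokens = rng.integers(0, vocab_size, size=labels.shape)
    inputs[randomized] = random_tokens[randomized]
    return inputs, labels


# Usage with toy ids: the first and last positions play the role of [CLS]/[SEP].
rng = np.random.default_rng(0)
ids = rng.integers(5, 100, size=(4, 64))
special = np.zeros_like(ids)
special[:, 0] = special[:, -1] = 1
masked_ids, labels = mask_tokens(ids, special, mask_token_id=3, vocab_size=100, rng=rng)
```

Drawing the three decisions as independent Bernoulli masks and intersecting them yields the 80/10/10 split in expectation rather than exactly per batch.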

Learning Rate Scheduler

create_learning_rate_scheduler implements a composable learning rate schedule specified as a string of factors separated by *. Supported factors:

  • constant: base learning rate
  • linear_warmup: linear warmup until warmup_steps
  • rsqrt_decay: divide by sqrt(max(step, warmup_steps))
  • rsqrt_normalized_decay: rsqrt decay rescaled by sqrt(warmup_steps), so the factor equals 1 at the end of warmup
  • decay_every: step decay every k steps
  • cosine_decay: cyclic cosine decay

The default schedule uses "constant * linear_warmup * rsqrt_decay".
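A factor-product schedule of this kind can be sketched in plain Python. The function name create_lr_schedule, its keyword parameters (base_learning_rate, warmup_steps, decay_factor, steps_per_decay, steps_per_cycle), and their defaults are assumptions made for illustration; the script's own create_learning_rate_scheduler may differ in names and details.

```python
import math


def create_lr_schedule(factors="constant * linear_warmup * rsqrt_decay",
                       base_learning_rate=0.5, warmup_steps=1000,
                       decay_factor=0.5, steps_per_decay=20000,
                       steps_per_cycle=100000):
    """Hypothetical sketch: the learning rate is the product of named factors."""
    names = [name.strip() for name in factors.split("*")]

    def step_fn(step):
        lr = 1.0
        for name in names:
            if name == "constant":
                lr *= base_learning_rate
            elif name == "linear_warmup":
                lr *= min(1.0, step / warmup_steps)
            elif name == "rsqrt_decay":
                lr /= math.sqrt(max(step, warmup_steps))
            elif name == "rsqrt_normalized_decay":
                lr *= math.sqrt(warmup_steps) / math.sqrt(max(step, warmup_steps))
            elif name == "decay_every":
                lr *= decay_factor ** (step // steps_per_decay)
            elif name == "cosine_decay":
                progress = max(0.0, (step - warmup_steps) / steps_per_cycle)
                lr *= max(0.0, 0.5 * (1.0 + math.cos(math.pi * (progress % 1.0))))
            else:
                raise ValueError(f"Unknown factor: {name}")
        return lr

    return step_fn


schedule = create_lr_schedule(base_learning_rate=1.0, warmup_steps=100)
print([round(schedule(s), 4) for s in (0, 50, 100, 400)])  # [0.0, 0.05, 0.1, 0.05]
```

With the default factor string, the rate peaks at base_learning_rate / sqrt(warmup_steps) at the end of warmup, so the base value passed in is not itself the peak learning rate.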

Custom Training Loop

Instead of using the HuggingFace Trainer, this script implements a JAX-native training loop:

  1. training_step: computes loss via cross_entropy with label smoothing support, computes gradients via jax.value_and_grad, averages gradients across devices with jax.lax.pmean, and applies the gradient update with the learning rate scheduler
  2. eval_step: computes loss and accuracy metrics via compute_metrics, aggregated across devices with jax.lax.psum
  3. Both steps are parallelized across devices using jax.pmap with the "batch" axis name
  4. The optimizer is replicated across all devices with jax_utils.replicate
  5. Each epoch shuffles training indices with jax.random.permutation and generates batch splits
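The core of steps 1, 3, and 4 can be sketched with JAX on a toy problem. This is not the script's code: a linear-regression loss stands in for the masked-LM cross-entropy, plain SGD with a fixed learning rate stands in for flax.optim.Adam and the scheduler, parameters are replicated by stacking one copy per device (the role jax_utils.replicate plays in the script), and the global batch is reshaped to a leading device axis (roughly what common_utils.shard does). It runs on however many local devices JAX sees, including a single CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np

LEARNING_RATE = 0.1  # hypothetical fixed rate; the script uses its LR scheduler


def loss_fn(params, batch):
    # Toy mean-squared-error loss standing in for the masked-LM cross-entropy.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)


def train_step(params, batch):
    # Loss and gradients on this device's shard of the batch.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average gradients (and the reported loss) across devices on the "batch" axis.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    # Plain SGD update (the script applies Adam instead).
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


p_train_step = jax.pmap(train_step, axis_name="batch")

n_dev = jax.local_device_count()
data_rng = np.random.default_rng(0)
x = data_rng.normal(size=(16 * n_dev, 4)).astype(np.float32)
true_w = np.array([[1.0], [-2.0], [0.5], [3.0]], dtype=np.float32)
y = x @ true_w

# Shard: the leading axis becomes (n_dev, per_device_batch).
batch = {"x": x.reshape(n_dev, -1, 4), "y": y.reshape(n_dev, -1, 1)}
# Replicate: one identical copy of the parameters per device.
params = {"w": jnp.zeros((4, 1), dtype=jnp.float32)}
replicated = jax.tree_util.tree_map(lambda a: jnp.stack([a] * n_dev), params)

losses = []
for _ in range(50):
    replicated, loss = p_train_step(replicated, batch)
    losses.append(float(loss[0]))  # every device reports the same averaged loss
```

Because every device applies the same averaged gradient, the replicated parameter copies stay identical without any explicit synchronization step.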

Optimizer

The script uses Flax's Adam optimizer (flax.optim.Adam), configured with learning_rate, weight_decay, adam_beta1, and adam_beta2 from TrainingArguments. The optimizer is created from the model params and replicated across devices.

Loss Functions

cross_entropy: computes cross-entropy loss with optional label smoothing. Uses common_utils.onehot for soft target construction and jax.nn.log_softmax for numerically stable computation. Only tokens with labels > 0 (non-padding, non-special) contribute to the loss.

accuracy: computes weighted accuracy by comparing argmax predictions against targets, weighted by a token mask.
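The two loss helpers and compute_metrics can be approximated in NumPy. This sketch assumes that each function returns a sum together with the total weight (leaving the division to the caller), builds the smoothed targets by comparing labels with the vocabulary indices instead of calling common_utils.onehot, uses a plain label-smoothed cross-entropy, and omits the cross-device jax.lax.psum aggregation.

```python
import numpy as np


def log_softmax(logits):
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = label_smoothing / (vocab_size - 1)
    # Smoothed one-hot targets; a negative (ignored) label matches no class.
    hits = targets[..., None] == np.arange(vocab_size)
    soft_targets = np.where(hits, confidence, low_confidence)
    loss = -(soft_targets * log_softmax(logits)).sum(axis=-1)
    if weights is None:
        return loss.sum(), float(targets.size)
    return (loss * weights).sum(), weights.sum()


def accuracy(logits, targets, weights):
    # Count argmax hits, counting only positions with nonzero weight.
    correct = (logits.argmax(axis=-1) == targets) * weights
    return correct.sum(), weights.sum()


def compute_metrics(logits, labels, weights, label_smoothing=0.0):
    loss, normalizer = cross_entropy(logits, labels, weights, label_smoothing)
    acc, _ = accuracy(logits, labels, weights)
    return {"loss": loss, "accuracy": acc, "normalizer": normalizer}


# Usage: the token mask keeps only positions whose label is > 0.
labels = np.array([[5, -100, 2], [7, -100, -100]])
weights = (labels > 0).astype(np.float64)
uniform_logits = np.zeros((2, 3, 10))
metrics = compute_metrics(uniform_logits, labels, weights)
mean_loss = metrics["loss"] / metrics["normalizer"]  # log(10) for uniform logits
```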

Batch Generation

generate_batch_splits splits the shuffled sample indices into batch-sized chunks. Trailing samples that do not fill a complete batch are discarded.
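A NumPy sketch of the drop-remainder behavior described above follows. It is illustrative only: it shuffles with np.random.Generator.permutation instead of jax.random.permutation to stay dependency-free, and the empty-list return for fewer samples than one batch is an assumption.

```python
import numpy as np


def generate_batch_splits(samples_idx, batch_size):
    # Keep only as many indices as fill complete batches; drop the remainder.
    num_batches = len(samples_idx) // batch_size
    if num_batches == 0:
        return []
    trimmed = samples_idx[: num_batches * batch_size]
    return np.split(trimmed, num_batches)


# Usage: shuffle 10 indices, then split into batches of 3 (one index is dropped).
perm = np.random.default_rng(0).permutation(10)
batches = generate_batch_splits(perm, batch_size=3)
```

Since the permutation changes every epoch, a different subset of samples is dropped each time, so no sample is permanently excluded.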

I/O Contract

Inputs

| Name | Type | Required | Description |
| --- | --- | --- | --- |
| model_name_or_path | str | No* | Pretrained model name or path for config/tokenizer loading |
| model_type | str | No* | Model type for training from scratch (e.g., bert) |
| dataset_name | str | No** | HuggingFace dataset name (alternative to train_file) |
| dataset_config_name | str | No | Configuration name for the HuggingFace dataset |
| train_file | str | No** | Path to training text file (CSV, JSON, or TXT; alternative to dataset_name) |
| validation_file | str | No | Path to validation text file |
| train_ref_file | str | No | Reference file for whole word masking in Chinese (training) |
| validation_ref_file | str | No | Reference file for whole word masking in Chinese (validation) |
| max_seq_length | int | No | Maximum sequence length after tokenization (defaults to model max) |
| mlm_probability | float | No | Ratio of tokens to mask (default: 0.15) |
| pad_to_max_length | bool | No | Pad all samples to max_seq_length (default: False) |
| output_dir | str | Yes | Directory to save logs and TensorBoard events |
| per_device_train_batch_size | int | No | Per-device training batch size (exposed in the script as training_args.train_batch_size) |
| per_device_eval_batch_size | int | No | Per-device evaluation batch size |
| num_train_epochs | int | No | Number of training epochs |
| learning_rate | float | No | Base learning rate for the Adam optimizer |
| warmup_steps | int | No | Number of warmup steps for the learning rate scheduler |
| validation_split_percentage | int | No | Percentage of the train set used for validation if no validation split exists (default: 5) |
| preprocessing_num_workers | int | No | Number of processes for data preprocessing |
| overwrite_cache | bool | No | Whether to overwrite cached preprocessed datasets (default: False) |

* Either model_name_or_path or model_type must be provided.

** Either dataset_name or train_file/validation_file must be provided.

Outputs

| Name | Type | Description |
| --- | --- | --- |
| TensorBoard logs | Events | Training loss, eval loss, and accuracy written to output_dir/logs/ (if TensorBoard is available) |
| Console metrics | Text | Per-epoch loss and accuracy printed to stdout via tqdm progress bars |

Usage Examples

Fine-tune BERT on WikiText with Flax

python examples/NLU/examples/language-modeling/run_mlm_flax.py \
    --model_name_or_path bert-base-cased \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --do_eval \
    --output_dir ./output/mlm-flax-bert \
    --per_device_train_batch_size 32 \
    --num_train_epochs 3 \
    --learning_rate 5e-5

Fine-tune on Custom Text Files

python examples/NLU/examples/language-modeling/run_mlm_flax.py \
    --model_name_or_path bert-base-cased \
    --train_file ./data/train.txt \
    --validation_file ./data/valid.txt \
    --do_train \
    --do_eval \
    --max_seq_length 128 \
    --mlm_probability 0.15 \
    --output_dir ./output/mlm-flax-custom \
    --overwrite_output_dir

Load Arguments from JSON

python examples/NLU/examples/language-modeling/run_mlm_flax.py config.json
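The JSON file holds the same keys as the command-line flags. The snippet below writes a config.json mirroring the WikiText example above; it assumes the usual HuggingFace example-script convention, in which a single .json argument is parsed with HfArgumentParser.parse_json_file, and uses per_device_train_batch_size, the TrainingArguments field for the per-device batch size.

```python
import json
from pathlib import Path

# Keys mirror the command-line arguments listed in the I/O contract.
config = {
    "model_name_or_path": "bert-base-cased",
    "dataset_name": "wikitext",
    "dataset_config_name": "wikitext-2-raw-v1",
    "do_train": True,
    "do_eval": True,
    "output_dir": "./output/mlm-flax-bert",
    "per_device_train_batch_size": 32,
    "num_train_epochs": 3,
    "learning_rate": 5e-5,
}

config_path = Path("config.json")
config_path.write_text(json.dumps(config, indent=2))
```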

Internal Details

Differences from run_mlm.py

| Feature | run_mlm.py | run_mlm_flax.py |
| --- | --- | --- |
| Framework | PyTorch | JAX/Flax |
| Training loop | HuggingFace Trainer | Custom JAX loop with jax.pmap |
| Model class | AutoModelForMaskedLM | FlaxBertForMaskedLM |
| Optimizer | AdamW (PyTorch) | flax.optim.Adam |
| Data collator | DataCollatorForLanguageModeling | FlaxDataCollatorForLanguageModeling (custom, NumPy-based) |
| LR scheduler | Transformers built-in | Custom create_learning_rate_scheduler (rsqrt_decay by default) |
| Multi-device | PyTorch DDP | jax.pmap with "batch" axis |
| Metrics | Loss, perplexity | Loss, accuracy (via compute_metrics) |
| Checkpointing | Automatic via Trainer | Not implemented (TensorBoard logging only) |
| Data modes | line_by_line or concatenated blocks (group_texts) | Line-by-line tokenization only (no group_texts) |

JAX Random Number Handling

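The script uses explicit JAX PRNG key management. A root key is created from training_args.seed and split into one dropout key per local device (dropout_rngs). At every training step, each device splits its dropout key, using one half for the current step's dropout and carrying the other half forward. Devices therefore draw different dropout masks, while the whole run stays reproducible for a given seed.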
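The key handling can be sketched with jax.random. The seed value and the helper name next_dropout_key are hypothetical; in the script the per-step split happens inside the pmapped training_step, which returns the new key, whereas this sketch runs the split in plain Python for a single device's key.

```python
import jax
import numpy as np

seed = 42  # hypothetical stand-in for training_args.seed
num_devices = jax.local_device_count()
dropout_rngs = jax.random.split(jax.random.PRNGKey(seed), num_devices)  # one key per device


def next_dropout_key(rng):
    # Split once per step: use one half now, carry the other half forward.
    dropout_rng, new_rng = jax.random.split(rng)
    return dropout_rng, new_rng


def keys_for_steps(rng, num_steps):
    # Collect the dropout key used at each step for one device.
    keys = []
    for _ in range(num_steps):
        key, rng = next_dropout_key(rng)
        keys.append(np.asarray(key))
    return keys


step_keys = keys_for_steps(dropout_rngs[0], 3)
# Re-deriving from the same seed reproduces exactly the same sequence of keys.
replay = keys_for_steps(jax.random.split(jax.random.PRNGKey(seed), num_devices)[0], 3)
```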

Multi-device Parallelism

Training and evaluation steps are parallelized using jax.pmap with the "batch" axis name. The optimizer state is replicated across all local devices using jax_utils.replicate. Data batches are sharded across devices using common_utils.shard. Gradients are averaged across devices using jax.lax.pmean and metrics are summed using jax.lax.psum.
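The data layout and the effect of summing metrics across devices can be illustrated in NumPy. The shard helper below is a hypothetical stand-in that reshapes each array's leading batch axis to (num_devices, per_device_batch, ...), roughly what common_utils.shard does with the local device count; the per-device sums followed by a total emulate jax.lax.psum. Summing loss totals and token counts before dividing weights every counted token equally, even when devices hold different numbers of masked tokens.

```python
import numpy as np


def shard(batch, num_devices):
    # Split the leading (global batch) axis into (num_devices, per_device_batch).
    # reshape raises ValueError if the batch size is not divisible by num_devices.
    return {k: v.reshape((num_devices, -1) + v.shape[1:]) for k, v in batch.items()}


rng = np.random.default_rng(0)
batch = {
    "input_ids": rng.integers(0, 100, size=(8, 16)),
    "token_loss": rng.random((8, 16)),  # toy per-token losses
    "token_mask": rng.binomial(1, 0.15, size=(8, 16)).astype(np.float64),
}
batch["token_mask"][:, 0] = 1.0  # guarantee at least one counted token per row
sharded = shard(batch, num_devices=4)

# Per-device sums, then a sum over devices (the role psum plays on the "batch" axis).
loss_sums = (sharded["token_loss"] * sharded["token_mask"]).sum(axis=(1, 2))
token_counts = sharded["token_mask"].sum(axis=(1, 2))
global_mean = loss_sums.sum() / token_counts.sum()
```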

TensorBoard Integration

TensorBoard logging is enabled conditionally: the script checks is_tensorboard_available() at import time and logs only on the primary host (jax.host_id() == 0). Evaluation metrics (loss and accuracy) are written once per epoch.
