Implementation:Microsoft LoRA Run MLM Flax
| Knowledge Sources | |
|---|---|
| Domains | NLP, Language_Modeling, JAX_Flax |
| Last Updated | 2026-02-10 06:00 GMT |
Overview
HuggingFace Transformers example script for fine-tuning masked language models using the JAX/Flax backend with whole word masking support.
Description
run_mlm_flax.py fine-tunes masked language models using JAX/Flax instead of PyTorch. Unlike run_mlm.py, which uses the HuggingFace Trainer API, this script implements a custom JAX training loop with explicit gradient computation, jax.pmap for multi-device parallelism, and a configurable learning rate scheduler. It uses FlaxBertForMaskedLM as its model class and includes a custom data collator, FlaxDataCollatorForLanguageModeling, implemented with NumPy operations. The script supports whole word masking for Chinese text via optional reference files (train_ref_file, validation_ref_file). TensorBoard logging is supported via flax.metrics.tensorboard.SummaryWriter. The script is part of the modified Transformers fork that Microsoft LoRA uses for its NLU experiments.
Usage
Use this script to fine-tune a masked language model with JAX/Flax on TPUs or on GPUs with JAX support. It accepts both local files (CSV, JSON, TXT) and datasets from the HuggingFace dataset hub. It is designed for multi-device training via jax.pmap, with batches sharded automatically across devices, and it is integrated with the LoRA-modified Transformers fork.
Code Reference
Source Location
- Repository: Microsoft_LoRA
- File: examples/NLU/examples/language-modeling/run_mlm_flax.py
- Lines: 1-661
Signature
# Script entry point via __main__ block with HfArgumentParser
# Key dataclasses:
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default=None)
    model_type: Optional[str] = field(default=None)
    config_name: Optional[str] = field(default=None)
    tokenizer_name: Optional[str] = field(default=None)
    cache_dir: Optional[str] = field(default=None)
    use_fast_tokenizer: bool = field(default=True)

@dataclass
class DataTrainingArguments:
    dataset_name: Optional[str] = field(default=None)
    dataset_config_name: Optional[str] = field(default=None)
    train_file: Optional[str] = field(default=None)
    validation_file: Optional[str] = field(default=None)
    train_ref_file: Optional[str] = field(default=None)
    validation_ref_file: Optional[str] = field(default=None)
    overwrite_cache: bool = field(default=False)
    validation_split_percentage: Optional[int] = field(default=5)
    max_seq_length: Optional[int] = field(default=None)
    preprocessing_num_workers: Optional[int] = field(default=None)
    mlm_probability: float = field(default=0.15)
    pad_to_max_length: bool = field(default=False)

# Custom Flax data collator:
@dataclass
class FlaxDataCollatorForLanguageModeling:
    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15

# Custom helper functions:
def create_learning_rate_scheduler(...)
def compute_metrics(logits, labels, weights, label_smoothing=0.0)
def accuracy(logits, targets, weights=None)
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0)
def training_step(optimizer, batch, dropout_rng)
def eval_step(params, batch)
def generate_batch_splits(samples_idx, batch_size)
Import
# Script is run directly, not imported
python examples/NLU/examples/language-modeling/run_mlm_flax.py \
--model_name_or_path bert-base-cased \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--output_dir /tmp/test-mlm-flax
Key Components
Model Loading
The script uses FlaxBertForMaskedLM (not AutoModelForMaskedLM) to load a BERT model in the Flax framework. The model is instantiated with explicit parameters: dtype=jnp.float32, an input_shape derived from train_batch_size and config.max_position_embeddings, a seed for reproducibility, and dropout_rate=0.1. Note that the model weights are always loaded from the hardcoded "bert-base-cased" checkpoint; the model_name_or_path argument affects only config and tokenizer loading.
FlaxDataCollatorForLanguageModeling
A custom data collator class that operates with NumPy arrays instead of PyTorch tensors. It implements the standard BERT masking strategy:
- 80% of selected tokens are replaced with the [MASK] token
- 10% are replaced with a random token from the vocabulary
- 10% are kept unchanged
- Special tokens are never masked (protected by special_tokens_mask)
- The collator pads examples using tokenizer.pad with pad_to_multiple_of support
- Masking uses np.random.binomial for probabilistic token selection
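The masking rules above can be sketched in plain NumPy. This is a minimal illustration rather than the script's collator: the function name mask_tokens, its arguments, and the -100 sentinel for unselected positions are assumptions made for this example (any non-positive label would be excluded by a labels > 0 filter).

```python
import numpy as np

def mask_tokens(input_ids, special_tokens_mask, mask_token_id, vocab_size,
                mlm_probability=0.15, rng=None):
    """Return (masked_inputs, labels) following the 80/10/10 rule."""
    if rng is None:
        rng = np.random.default_rng()
    inputs = input_ids.copy()
    labels = input_ids.copy()

    # Select positions with probability mlm_probability; special tokens get 0.
    probs = np.full(labels.shape, mlm_probability)
    probs[special_tokens_mask.astype(bool)] = 0.0
    selected = rng.binomial(1, probs).astype(bool)
    labels[~selected] = -100  # assumed sentinel: excluded from the loss

    # 80% of the selected positions are replaced with the mask token.
    replace = rng.binomial(1, np.full(labels.shape, 0.8)).astype(bool) & selected
    inputs[replace] = mask_token_id

    # Half of the remaining selected positions (10% overall) get a random token.
    randomize = (rng.binomial(1, np.full(labels.shape, 0.5)).astype(bool)
                 & selected & ~replace)
    random_tokens = rng.integers(0, vocab_size, size=labels.shape)
    inputs[randomize] = random_tokens[randomize]

    # The last 10% of selected positions keep their original token.
    return inputs, labels

# Toy batch: ids 101 and 102 play the role of special tokens.
ids = np.array([[101, 7, 8, 9, 10, 102]])
special = np.array([[1, 0, 0, 0, 0, 1]])
masked, labels = mask_tokens(ids, special, mask_token_id=103, vocab_size=1000,
                             mlm_probability=0.5, rng=np.random.default_rng(0))
```

Positions whose label is -100 are guaranteed to be unchanged in the returned inputs, and special-token positions are never selected.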
Learning Rate Scheduler
create_learning_rate_scheduler implements a composable learning rate schedule specified as a string of factors separated by *. Supported factors:
- constant: base learning rate
- linear_warmup: linear warmup until warmup_steps
- rsqrt_decay: divide by sqrt(max(step, warmup_steps))
- rsqrt_normalized_decay: divide by sqrt(max(step / warmup_steps, 1))
- decay_every: step decay every k steps
- cosine_decay: cyclic cosine decay
The default schedule uses "constant * linear_warmup * rsqrt_decay".
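To make the factor composition concrete, here is a small NumPy sketch of the default "constant * linear_warmup * rsqrt_decay" schedule. It covers only those three factors, and the name make_schedule as well as the base_learning_rate and warmup_steps values are illustrative assumptions, not the script's defaults.

```python
import numpy as np

def make_schedule(factors="constant * linear_warmup * rsqrt_decay",
                  base_learning_rate=1e-3, warmup_steps=100):
    """Build step -> learning rate from a '*'-separated factor string."""
    names = [name.strip() for name in factors.split("*")]

    def step_fn(step):
        lr = 1.0
        for name in names:
            if name == "constant":
                lr *= base_learning_rate
            elif name == "linear_warmup":
                # Ramp linearly from 0 to 1 over warmup_steps.
                lr *= min(1.0, step / warmup_steps)
            elif name == "rsqrt_decay":
                # Flat until warmup_steps, then decays as 1 / sqrt(step).
                lr /= np.sqrt(max(step, warmup_steps))
            else:
                raise ValueError(f"Unknown factor {name!r}")
        return float(lr)

    return step_fn

schedule = make_schedule()
lr_mid_warmup = schedule(50)    # half of the warmup ramp
lr_peak = schedule(100)         # end of warmup
lr_later = schedule(400)        # decay phase
```

With this combination the rate peaks at the end of warmup and then falls off with the inverse square root of the step count.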
Custom Training Loop
Instead of using the HuggingFace Trainer, this script implements a JAX-native training loop:
- training_step: computes loss via cross_entropy with label smoothing support, computes gradients via jax.value_and_grad, averages gradients across devices with jax.lax.pmean, and applies the gradient update with the learning rate scheduler
- eval_step: computes loss and accuracy metrics via compute_metrics, aggregated across devices with jax.lax.psum
- Both steps are parallelized across devices using jax.pmap with the "batch" axis name
- The optimizer is replicated across all devices with jax_utils.replicate
- Each epoch shuffles training indices with jax.random.permutation and generates batch splits
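The shape of such a step can be sketched with plain JAX. To stay self-contained, this example swaps in a toy linear regression model and a plain SGD update for the masked LM and the Adam optimizer; loss_fn, LEARNING_RATE, and the data are hypothetical. Only the jax.value_and_grad, jax.lax.pmean, and jax.pmap usage mirrors the description above.

```python
import jax
import jax.numpy as jnp
import numpy as np

LEARNING_RATE = 0.05  # illustrative constant; the script uses a scheduler

def loss_fn(params, batch):
    # Toy stand-in for the masked LM loss: mean squared error.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

def training_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average gradients (and the reported loss) across devices.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads)
    return loss, params

p_training_step = jax.pmap(training_step, axis_name="batch")

n_devices = jax.local_device_count()

# Replicate parameters by adding a leading device axis.
params = {"w": jnp.zeros((3,))}
replicated = jax.tree_util.tree_map(
    lambda x: jnp.stack([x] * n_devices), params)

# Build a batch with a leading device axis: (n_devices, per_device, ...).
data_rng = np.random.default_rng(0)
x = data_rng.normal(size=(n_devices * 8, 3)).astype(np.float32)
y = x @ np.array([1.0, -2.0, 0.5], dtype=np.float32)
batch = {"x": x.reshape(n_devices, 8, 3), "y": y.reshape(n_devices, 8)}

losses = []
for _ in range(50):
    loss, replicated = p_training_step(replicated, batch)
    losses.append(float(loss[0]))  # every device holds the same mean loss
```

Because the gradients are averaged with pmean before the update, every device applies the same update and the replicated parameters stay in sync.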
Optimizer
The script uses Flax's Adam optimizer (flax.optim.Adam), configured with learning_rate, weight_decay, adam_beta1, and adam_beta2 from TrainingArguments. The optimizer is created from the model params and replicated across devices.
Loss Functions
cross_entropy: computes cross-entropy loss with optional label smoothing. Uses common_utils.onehot for soft target construction and jax.nn.log_softmax for numerically stable computation. Only tokens with labels > 0 (non-padding, non-special) contribute to the loss.
accuracy: computes weighted accuracy by comparing argmax predictions against targets, weighted by a token mask.
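A simplified NumPy sketch of both functions follows. It assumes a weighted mean as the return value and builds the soft targets by hand instead of calling common_utils.onehot, so treat it as an illustration of the idea rather than the script's exact arithmetic.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row maximum first for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = label_smoothing / (vocab_size - 1)
    # Ignored positions may hold negative labels; map them to index 0 so the
    # one-hot construction is valid. Their weight is zero anyway.
    safe_targets = np.where(targets > 0, targets, 0)
    soft_targets = np.full(logits.shape, low_confidence)
    np.put_along_axis(soft_targets, safe_targets[..., None], confidence, axis=-1)
    per_token = -(soft_targets * log_softmax(logits)).sum(axis=-1)
    if weights is None:
        weights = np.ones_like(per_token)
    return (per_token * weights).sum() / max(weights.sum(), 1.0)

def accuracy(logits, targets, weights=None):
    correct = (logits.argmax(axis=-1) == targets).astype(np.float64)
    if weights is None:
        weights = np.ones_like(correct)
    return (correct * weights).sum() / max(weights.sum(), 1.0)

# One sequence of three tokens over a vocabulary of four.
logits = np.array([[[5.0, 0, 0, 0], [0, 5.0, 0, 0], [0, 0, 0, 5.0]]])
targets = np.array([[-100, 1, 2]])
weights = (targets > 0).astype(np.float64)  # only labels > 0 count

loss = cross_entropy(logits, targets, weights)
acc = accuracy(logits, targets, weights)
```

In this toy batch the first position is ignored, the second is predicted correctly, and the third is wrong, so the weighted accuracy is one half.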
Batch Generation
generate_batch_splits splits the shuffled sample indices into chunks of batch_size. Trailing samples that do not fill a complete batch are discarded.
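A minimal NumPy sketch of this behaviour is shown below. It returns a 2-D array of index batches, and it uses a NumPy permutation as a stand-in for jax.random.permutation; both choices are assumptions made for this example.

```python
import numpy as np

def generate_batch_splits(samples_idx, batch_size):
    """Split indices into full batches, dropping the incomplete remainder."""
    n_full = len(samples_idx) // batch_size
    usable = samples_idx[: n_full * batch_size]
    return usable.reshape(n_full, batch_size)

# Ten shuffled sample indices, batch size four: two samples are dropped.
samples_idx = np.random.default_rng(0).permutation(10)
batches = generate_batch_splits(samples_idx, batch_size=4)
```

Since the indices are reshuffled every epoch, a different set of trailing samples is dropped each time.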
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_name_or_path | str | No* | Pretrained model name or path for config/tokenizer loading |
| model_type | str | No* | Model type for training from scratch (e.g., bert) |
| dataset_name | str | No** | HuggingFace dataset name (alternative to train_file) |
| dataset_config_name | str | No | Configuration name for the HuggingFace dataset |
| train_file | str | No** | Path to training text file (CSV, JSON, or TXT; alternative to dataset_name) |
| validation_file | str | No | Path to validation text file |
| train_ref_file | str | No | Reference file for whole word masking in Chinese (training) |
| validation_ref_file | str | No | Reference file for whole word masking in Chinese (validation) |
| max_seq_length | int | No | Maximum sequence length after tokenization (defaults to model max) |
| mlm_probability | float | No | Ratio of tokens to mask (default: 0.15) |
| pad_to_max_length | bool | No | Pad all samples to max_seq_length (default: False) |
| output_dir | str | Yes | Directory to save logs and TensorBoard events |
| train_batch_size | int | No | Per-device training batch size |
| eval_batch_size | int | No | Per-device evaluation batch size |
| num_train_epochs | int | No | Number of training epochs |
| learning_rate | float | No | Base learning rate for Adam optimizer |
| warmup_steps | int | No | Number of warmup steps for learning rate scheduler |
| validation_split_percentage | int | No | Percentage of train set used as validation if no validation split exists (default: 5) |
| preprocessing_num_workers | int | No | Number of processes for data preprocessing |
| overwrite_cache | bool | No | Whether to overwrite cached preprocessed datasets (default: False) |
* Either model_name_or_path or model_type must be provided.
** Either dataset_name or train_file/validation_file must be provided.
Outputs
| Name | Type | Description |
|---|---|---|
| TensorBoard logs | Events | Training loss, eval loss, and accuracy written to output_dir/logs/ (if TensorBoard available) |
| console metrics | Text | Per-epoch loss and accuracy printed to stdout via tqdm progress bars |
Usage Examples
Fine-tune BERT on WikiText with Flax
python examples/NLU/examples/language-modeling/run_mlm_flax.py \
--model_name_or_path bert-base-cased \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--do_train \
--do_eval \
--output_dir ./output/mlm-flax-bert \
--train_batch_size 32 \
--num_train_epochs 3 \
--learning_rate 5e-5
Fine-tune on Custom Text Files
python examples/NLU/examples/language-modeling/run_mlm_flax.py \
--model_name_or_path bert-base-cased \
--train_file ./data/train.txt \
--validation_file ./data/valid.txt \
--do_train \
--do_eval \
--max_seq_length 128 \
--mlm_probability 0.15 \
--output_dir ./output/mlm-flax-custom \
--overwrite_output_dir
Load Arguments from JSON
python examples/NLU/examples/language-modeling/run_mlm_flax.py config.json
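A config file for this mode can be generated with a few lines of Python. The sketch below assumes the JSON keys mirror the command-line flag names without the leading dashes, and it reuses the values from the WikiText example; the temporary directory is only there to avoid overwriting an existing file.

```python
import json
import os
import tempfile

# Keys mirror the CLI flags used in the WikiText example (assumption).
config = {
    "model_name_or_path": "bert-base-cased",
    "dataset_name": "wikitext",
    "dataset_config_name": "wikitext-2-raw-v1",
    "do_train": True,
    "do_eval": True,
    "output_dir": "./output/mlm-flax-bert",
    "train_batch_size": 32,
    "num_train_epochs": 3,
    "learning_rate": 5e-5,
}

config_path = os.path.join(tempfile.mkdtemp(), "config.json")
with open(config_path, "w") as f:
    json.dump(config, f, indent=2)

# Read it back to confirm the file round-trips.
with open(config_path) as f:
    loaded = json.load(f)
```

The resulting path is then passed to the script as its only argument.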
Internal Details
Differences from run_mlm.py
| Feature | run_mlm.py | run_mlm_flax.py |
|---|---|---|
| Framework | PyTorch | JAX/Flax |
| Training loop | HuggingFace Trainer | Custom JAX loop with jax.pmap |
| Model class | AutoModelForMaskedLM | FlaxBertForMaskedLM |
| Optimizer | AdamW (PyTorch) | flax.optim.Adam |
| Data collator | DataCollatorForLanguageModeling | FlaxDataCollatorForLanguageModeling (custom, NumPy-based) |
| LR scheduler | Transformers built-in | Custom create_learning_rate_scheduler (rsqrt_decay default) |
| Multi-device | PyTorch DDP | jax.pmap with "batch" axis |
| Metrics | Loss, perplexity | Loss, accuracy (via compute_metrics) |
| Checkpointing | Automatic via Trainer | Not implemented (TensorBoard logging only) |
| Data modes | line_by_line + concatenation | Line-by-line tokenization only (no group_texts) |
JAX Random Number Handling
The script uses explicit JAX PRNG key management. A root key is created from training_args.seed and split into one dropout key per device (dropout_rngs). During training, each step splits its dropout key to produce a new key for the next step, so no key is reused across steps while the whole run remains reproducible from the seed.
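The key handling can be sketched as follows. The seed value and the helper name advance are hypothetical; the sketch only shows the split pattern, not the dropout call itself.

```python
import jax
import numpy as np

seed = 0  # stands in for training_args.seed
rng = jax.random.PRNGKey(seed)

# One dropout key per local device.
dropout_rngs = jax.random.split(rng, jax.local_device_count())

def advance(dropout_rng):
    """Split a device's key into (key for this step, key to carry forward)."""
    use_now, carry = jax.random.split(dropout_rng)
    return use_now, carry

use_now, carry = advance(dropout_rngs[0])

# Re-deriving the keys from the same seed gives identical values.
dropout_rngs_again = jax.random.split(
    jax.random.PRNGKey(seed), jax.local_device_count())
```

Carrying the second key forward, rather than reusing the first, is what keeps successive steps from drawing identical dropout masks.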
Multi-device Parallelism
Training and evaluation steps are parallelized using jax.pmap with the "batch" axis name. The optimizer state is replicated across all local devices using jax_utils.replicate. Data batches are sharded across devices using common_utils.shard. Gradients are averaged across devices using jax.lax.pmean and metrics are summed using jax.lax.psum.
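Sharding a batch amounts to reshaping each array so that its leading axis indexes devices. The NumPy sketch below illustrates that reshape; the shard helper here is a stand-in written for this example, and n_devices=4 is a hypothetical device count.

```python
import numpy as np

def shard(batch, n_devices):
    """Reshape each array from (B, ...) to (n_devices, B // n_devices, ...)."""
    return {
        name: value.reshape((n_devices, -1) + value.shape[1:])
        for name, value in batch.items()
    }

# A global batch of 8 sequences of length 16.
batch = {
    "input_ids": np.arange(8 * 16).reshape(8, 16),
    "labels": np.zeros((8, 16), dtype=np.int64),
}
sharded = shard(batch, n_devices=4)
```

Each device then receives a contiguous slice of the global batch, which is why the global batch size must be divisible by the number of devices.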
TensorBoard Integration
TensorBoard logging is conditionally enabled: the script checks is_tensorboard_available() at import time and logs only on the master host (jax.host_id() == 0). Evaluation metrics (loss, accuracy) are written once per epoch.