Principle:Microsoft DeepSpeedExamples Multimodal Distributed Training
- Principle: Multimodal_Distributed_Training
Metadata
| Field | Value |
|---|---|
| Page Type | Principle |
| Title | Multimodal_Distributed_Training |
| Sources | Paper: ZeRO (https://arxiv.org/abs/1910.02054), Paper: DeepSpeed-VisualChat (https://arxiv.org/abs/2309.14327) |
| Domains | Distributed_Training, Multimodal |
| Repository | Microsoft/DeepSpeedExamples |
| Application | DeepSpeed-VisualChat |
| Status | Active |
Overview
A distributed training technique combining ZeRO optimization with LoRA for efficient multimodal model fine-tuning across multiple GPUs.
Description
Training a multimodal model with billions of parameters (e.g., a LLaMA-2-7B language decoder combined with a ViT-bigG vision encoder) requires careful orchestration of what to train, how to distribute, and how to optimize. DeepSpeed-VisualChat uses a combination of:
- Selective parameter training -- Only the projection layer, language embeddings, and optional LoRA adapter weights are trained, while the vision encoder and base language decoder remain frozen.
- ZeRO optimization -- DeepSpeed's ZeRO (Zero Redundancy Optimizer) partitions optimizer states, gradients, and optionally parameters across GPUs to reduce per-GPU memory consumption.
- Multi-group optimizer -- Different parameter groups receive different learning rates and weight decay settings, allowing fine-grained control over the training dynamics.
- LoRA (Low-Rank Adaptation) -- Optional low-rank adapters are injected into the language decoder and/or vision encoder for parameter-efficient fine-tuning.
Trainable Parameter Budget
The total trainable parameter count is a small fraction of the full model:
trainable_params = projection_layer_params + lang_embed_params + lora_params
total_params = vis_encoder_params + lang_decoder_params + projection_params + lang_embed_params
trainable_params << total_params
For example, with LLaMA-2-7B and a Perceiver projection:
- Projection: ~50M parameters (trainable)
- Language embedding: ~130K parameters per new token (trainable)
- LoRA adapters (rank 16): ~10M parameters (trainable)
- Vision encoder (frozen): ~1.8B parameters
- Language decoder (frozen base): ~7B parameters
This means only ~1% of the total model parameters are updated during training, enabling fine-tuning on modest hardware.
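The figures above can be sanity-checked with simple arithmetic. This sketch assumes a single new language-embedding token for simplicity; the exact counts are the approximate values quoted in the list, not measured numbers:

```python
# Back-of-envelope trainable fraction using the figures quoted above
# (LLaMA-2-7B decoder + ViT-bigG encoder, Perceiver projection, rank-16 LoRA).
projection = 50e6      # projection layer (trainable)
lora = 10e6            # rank-16 LoRA adapters (trainable)
lang_embed = 130e3     # per new token; assuming one new token here
trainable = projection + lora + lang_embed
total = 1.8e9 + 7e9 + projection + lang_embed  # frozen encoder + decoder + trainable parts
print(f"trainable fraction: {trainable / total:.2%}")
```

With these numbers the fraction comes out below 1%, consistent with the "~1%" claim.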
Theoretical Basis
ZeRO Optimization Stages
DeepSpeed-VisualChat supports ZeRO stages 0-3:
| Stage | What is Partitioned | Memory Savings |
|---|---|---|
| Stage 0 | Nothing (data parallel) | Baseline |
| Stage 1 | Optimizer states | ~4x reduction in optimizer memory |
| Stage 2 | Optimizer states + gradients | ~8x reduction |
| Stage 3 | Optimizer states + gradients + parameters | Linear scaling with number of GPUs |
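The savings in the table follow from the ZeRO paper's memory accounting for mixed-precision Adam: 2 bytes/param each for FP16 weights and gradients, plus K = 12 bytes/param of optimizer state (FP32 master weights, momentum, variance). A rough per-GPU estimate, treating these formulas as the paper's idealized model rather than measured numbers:

```python
# Approximate per-GPU model-state memory (GB) for each ZeRO stage, following
# the ZeRO paper's accounting for mixed-precision Adam (K = 12 bytes/param of
# optimizer state, 2 bytes each for fp16 params and fp16 grads).
def zero_memory_gb(num_params: float, num_gpus: int, stage: int) -> float:
    P, K = num_params, 12
    if stage == 0:
        per_gpu = (2 + 2 + K) * P                   # everything replicated
    elif stage == 1:
        per_gpu = (2 + 2) * P + K * P / num_gpus    # optimizer states sharded
    elif stage == 2:
        per_gpu = 2 * P + (2 + K) * P / num_gpus    # + gradients sharded
    else:
        per_gpu = (2 + 2 + K) * P / num_gpus        # + parameters sharded
    return per_gpu / 1e9

# 7B parameters on 64 GPUs; stage 0 gives 16 * 7 = 112.0 GB per GPU
for stage in range(4):
    print(stage, round(zero_memory_gb(7e9, 64, stage), 1))
```

Activations, buffers, and fragmentation add to these figures in practice, so they are lower bounds.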
For ZeRO Stage 3, special handling is required:
- `HfDeepSpeedConfig` must be initialized before model loading to enable parameter sharding during `from_pretrained()`
- Parameters must be gathered across ranks before saving or fusing LoRA weights
- The `stage3_param_persistence_threshold` (1e4) controls which small parameters remain replicated
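These settings live in the DeepSpeed JSON config. An illustrative fragment, with placeholder batch-size and precision values rather than the repository's defaults:

```json
{
  "train_micro_batch_size_per_gpu": 4,
  "gradient_clipping": 1.0,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 3,
    "stage3_param_persistence_threshold": 1e4
  }
}
```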
Multi-Group Optimizer Configuration
The optimizer uses four parameter groups organized along two axes:
| Group | Weight Decay | Learning Rate | Parameters |
|---|---|---|---|
| Group 1 | weight_decay | Normal LR | Non-embedding trainable params without "bias" or "LayerNorm" in name |
| Group 2 | 0.0 | Normal LR | Non-embedding trainable params with "bias" or "LayerNorm" in name |
| Group 3 | weight_decay | Small LR | Embedding-related trainable params without "bias" or "LayerNorm" |
| Group 4 | 0.0 | Small LR | Embedding-related trainable params with "bias" or "LayerNorm" |
The small learning rate group (controlled by --learning_rate_pretraining_components) is applied to parameters containing "embed" in their name. This provides a lower learning rate for pre-trained embedding weights to prevent catastrophic forgetting, while the projection layer and LoRA weights receive the full learning rate.
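The two-axis grouping can be sketched as a plain name-matching partition. The exact substrings and group ordering here are assumptions read off the table, not the repository's code:

```python
def split_param_groups(named_params, weight_decay, lr, small_lr):
    """Partition trainable parameters into the four groups described above.

    `named_params` is an iterable of (name, param) pairs. The rules follow the
    table: "bias"/"LayerNorm" in the name -> no weight decay; "embed" in the
    name -> small learning rate. Substring choices are this sketch's assumption.
    """
    groups = {1: [], 2: [], 3: [], 4: []}
    for name, p in named_params:
        if not getattr(p, "requires_grad", True):
            continue  # frozen params never enter the optimizer
        no_decay = ("bias" in name) or ("LayerNorm" in name)
        is_embed = "embed" in name
        idx = {(False, False): 1, (True, False): 2,
               (False, True): 3, (True, True): 4}[(no_decay, is_embed)]
        groups[idx].append(p)
    return [
        {"params": groups[1], "weight_decay": weight_decay, "lr": lr},
        {"params": groups[2], "weight_decay": 0.0, "lr": lr},
        {"params": groups[3], "weight_decay": weight_decay, "lr": small_lr},
        {"params": groups[4], "weight_decay": 0.0, "lr": small_lr},
    ]
```

The returned list has the shape expected by PyTorch optimizers, whose constructors accept a list of per-group option dicts.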
LoRA Integration
LoRA adapters are optionally applied to specific layers:
Language decoder LoRA:
--lang_lora_dim 16 # rank of LoRA decomposition
--lang_lora_module_name model.layers. # target module scope
Vision encoder LoRA:
--vis_lora_dim 16
--vis_lora_module_name encoder.layers.
When --only_optimize_lora is set, all parameters except LoRA weights are frozen in the target module, providing maximum parameter efficiency.
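The parameter cost of such adapters is easy to compute: a rank-r decomposition of a d_out x d_in matrix adds an r x d_in matrix A and a d_out x r matrix B. The 4096-dimension example below is a generic LLaMA-2-7B projection size used for illustration:

```python
# Parameter cost of a rank-r LoRA adapter on a d_out x d_in linear layer:
# A is (r x d_in), B is (d_out x r), so the adapter adds r * (d_in + d_out) weights.
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    return rank * (d_in + d_out)

# e.g. a rank-16 adapter on a 4096 x 4096 projection matrix
print(lora_params(4096, 4096, 16))  # 131072 weights per adapted matrix
```

Summed over the adapted matrices in every targeted layer, this is where the ~10M LoRA parameters quoted earlier come from.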
Training Loop
The training loop follows the standard DeepSpeed pattern:
for epoch in range(num_train_epochs):
    for step, batch in enumerate(train_dataloader):
        batch = to_device(batch, device)
        loss = model(images, input_ids, attention_mask, labels, image_num)[0]
        model.backward(loss)  # DeepSpeed handles gradient accumulation
        model.step()          # DeepSpeed handles optimizer step + ZeRO sync

    # Epoch-end: fuse LoRA, save, unfuse LoRA
    model = fuse_lora(model)
    save_hf_format(model, tokenizer, args, f'epoch-{epoch}')
    if args.zero_stage == 3:
        save_zero_three_model(model, global_rank, output_dir, ...)
    model = unfuse_lora(model)
Learning Rate Schedule
The learning rate follows a cosine schedule with warmup:
lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_epochs * steps_per_epoch,
)
Warmup can be specified as either:
- An absolute step count (if > 1): --num_warmup_steps 100
- A ratio of total steps (if <= 1): --num_warmup_steps 0.03
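The resulting schedule, including the absolute-vs-ratio warmup convention, can be written out explicitly. This is a minimal closed-form sketch of the standard linear-warmup/cosine-decay shape, not the scheduler's actual source:

```python
import math

def lr_at_step(step, base_lr, total_steps, num_warmup_steps):
    """Cosine schedule with linear warmup.

    num_warmup_steps may be an absolute count (> 1) or a ratio of total
    steps (<= 1), mirroring the --num_warmup_steps convention above.
    """
    warmup = int(num_warmup_steps if num_warmup_steps > 1
                 else num_warmup_steps * total_steps)
    if step < warmup:
        return base_lr * step / max(1, warmup)          # linear warmup to base_lr
    progress = (step - warmup) / max(1, total_steps - warmup)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))  # decay to ~0
```

So with --num_warmup_steps 0.03 and 1000 total steps, the first 30 steps ramp linearly to the full learning rate, then cosine decay takes over.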
Checkpoint Resumption
The training state is fully recoverable via DeepSpeed checkpoints:
client_state = {
    'random_rng_state': random.getstate(),
    'np_rng_state': np.random.get_state(),
    'torch_rng_state': torch.get_rng_state(),
    'torch_cuda_rng_state': torch.cuda.get_rng_state(),
    'epoch': epoch + 1,
    'best_loss': best_loss,
}
model.save_checkpoint(output_dir, client_state=client_state)
This saves all RNG states to ensure exact reproducibility when resuming.
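A minimal round trip with Python's stdlib generator illustrates why the RNG states belong in client_state: restoring a saved state reproduces the exact draw sequence after a resume. The numpy and torch states in the snippet above are handled the same way via np.random.set_state() and torch.set_rng_state():

```python
import random

state = random.getstate()                        # what save_checkpoint stores
expected = [random.random() for _ in range(3)]   # draws after the "checkpoint"

random.setstate(state)                           # what the resume path restores
resumed = [random.random() for _ in range(3)]
assert resumed == expected  # identical sequence -> bit-exact reproducibility
```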
Key Considerations
- Mixed precision -- Training uses either FP16 or BF16 (controlled by --precision). FP16 is recommended for typical use; BF16 for larger models to avoid overflow.
- Gradient clipping -- Gradients are clipped to a max norm of 1.0 (gradient_clipping: 1.0) to prevent training instability.
- Batch size configuration -- The effective batch size is per_device_batch_size * world_size * gradient_accumulation_steps. DeepSpeed manages gradient accumulation internally.
- Evaluation frequency -- Evaluation runs once per epoch on the held-out eval split using the evaluation() function with model.eval() and torch.no_grad().
- TensorBoard logging -- Optional TensorBoard integration is available via --enable_tensorboard.
- CPU offloading -- ZeRO supports offloading optimizer states and parameters to CPU when GPU memory is insufficient (controlled in get_train_ds_config).
Related Pages
- Implementation:Microsoft_DeepSpeedExamples_DeepSpeed_Initialize_VisualChat -- The concrete training initialization
- Principle:Microsoft_DeepSpeedExamples_Multimodal_Model_Composition -- The model being trained
- Principle:Microsoft_DeepSpeedExamples_Multi_Dataset_VQA_Preparation -- The data fed into training
- Principle:Microsoft_DeepSpeedExamples_LoRA_Fusion_And_Export -- Post-training LoRA fusion and export