
Principle:Microsoft DeepSpeedExamples Multimodal Distributed Training

From Leeroopedia



Metadata

Page Type: Principle
Title: Multimodal_Distributed_Training
Sources: Paper: ZeRO (https://arxiv.org/abs/1910.02054), Paper: DeepSpeed-VisualChat (https://arxiv.org/abs/2309.14327)
Domains: Distributed_Training, Multimodal
Repository: Microsoft/DeepSpeedExamples
Application: DeepSpeed-VisualChat
Status: Active

Overview

A distributed training technique combining ZeRO optimization with LoRA for efficient multimodal model fine-tuning across multiple GPUs.

Description

Training a multimodal model with billions of parameters (e.g., a LLaMA-2-7B language decoder combined with a ViT-bigG vision encoder) requires careful orchestration of what to train, how to distribute, and how to optimize. DeepSpeed-VisualChat uses a combination of:

  • Selective parameter training -- Only the projection layer, language embeddings, and optional LoRA adapter weights are trained, while the vision encoder and base language decoder remain frozen.
  • ZeRO optimization -- DeepSpeed's ZeRO (Zero Redundancy Optimizer) partitions optimizer states, gradients, and optionally parameters across GPUs to reduce per-GPU memory consumption.
  • Multi-group optimizer -- Different parameter groups receive different learning rates and weight decay settings, allowing fine-grained control over the training dynamics.
  • LoRA (Low-Rank Adaptation) -- Optional low-rank adapters are injected into the language decoder and/or vision encoder for parameter-efficient fine-tuning.

Trainable Parameter Budget

The total trainable parameter count is a small fraction of the full model:

trainable_params = projection_params + lang_embed_params + lora_params
total_params = vis_encoder_params + lang_decoder_params + projection_params + lang_embed_params

trainable_params << total_params

For example, with LLaMA-2-7B and a Perceiver projection:

  • Projection: ~50M parameters (trainable)
  • Language embedding: ~130K parameters per new token (trainable)
  • LoRA adapters (rank 16): ~10M parameters (trainable)
  • Vision encoder (frozen): ~1.8B parameters
  • Language decoder (frozen base): ~7B parameters

This means well under 1% of the total model parameters (roughly 0.7% in this example) are updated during training, enabling fine-tuning on modest hardware.
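
The fraction can be checked with simple arithmetic. The sketch below uses the approximate per-component counts from the example above; the language-embedding figure is an illustrative order-of-magnitude placeholder, not an exact count:

```python
# Approximate parameter counts from the example above (illustrative values).
projection = 50e6        # Perceiver projection layer (trainable)
lora = 10e6              # rank-16 LoRA adapters (trainable)
lang_embed = 1e6         # new-token embeddings (trainable, order of magnitude)
vis_encoder = 1.8e9      # ViT-bigG vision encoder (frozen)
lang_decoder = 7e9       # LLaMA-2-7B decoder (frozen)

trainable = projection + lora + lang_embed
total = vis_encoder + lang_decoder + projection + lang_embed
fraction = trainable / total
print(f"trainable fraction: {fraction:.2%}")  # ~0.7% of all parameters
```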

Theoretical Basis

ZeRO Optimization Stages

DeepSpeed-VisualChat supports ZeRO stages 0-3:

  • Stage 0 -- nothing partitioned (plain data parallelism); baseline memory
  • Stage 1 -- optimizer states partitioned; ~4x reduction in optimizer memory
  • Stage 2 -- optimizer states + gradients partitioned; ~8x reduction
  • Stage 3 -- optimizer states + gradients + parameters partitioned; memory savings scale linearly with the number of GPUs

For ZeRO Stage 3, special handling is required:

  • HfDeepSpeedConfig must be initialized before model loading to enable parameter sharding during from_pretrained()
  • Parameters must be gathered across ranks before saving or fusing LoRA weights
  • The stage3_param_persistence_threshold (1e4) controls which small parameters remain replicated
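
A minimal ZeRO Stage 3 configuration fragment illustrating these settings. The field names follow the DeepSpeed JSON config schema; the batch size and precision values here are illustrative, not the repository's defaults:

```python
# Minimal DeepSpeed config fragment for ZeRO Stage 3 (illustrative values).
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_clipping": 1.0,
    "zero_optimization": {
        "stage": 3,
        # Parameters with fewer elements than this threshold stay replicated
        # on every rank instead of being partitioned.
        "stage3_param_persistence_threshold": 1e4,
    },
    "bf16": {"enabled": True},
}

# With Hugging Face transformers, HfDeepSpeedConfig must be constructed from
# this dict BEFORE from_pretrained() so weights are sharded as they load:
#   from transformers.integrations import HfDeepSpeedConfig
#   dschf = HfDeepSpeedConfig(ds_config)   # keep a reference alive
#   model = AutoModelForCausalLM.from_pretrained(...)
```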

Multi-Group Optimizer Configuration

The optimizer uses four parameter groups organized along two axes:

  • Group 1 -- weight_decay, normal LR: non-embedding trainable params without "bias" or "LayerNorm" in the name
  • Group 2 -- weight decay 0.0, normal LR: non-embedding trainable params with "bias" or "LayerNorm" in the name
  • Group 3 -- weight_decay, small LR: embedding-related trainable params without "bias" or "LayerNorm"
  • Group 4 -- weight decay 0.0, small LR: embedding-related trainable params with "bias" or "LayerNorm"

The small learning rate group (controlled by --learning_rate_pretraining_components) is applied to parameters containing "embed" in their name. This provides a lower learning rate for pre-trained embedding weights to prevent catastrophic forgetting, while the projection layer and LoRA weights receive the full learning rate.
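
The grouping logic can be sketched as a pure function over parameter names. This is a simplified sketch of the convention described above (the real script iterates over `named_parameters()` and also checks `requires_grad`); the example names are hypothetical:

```python
def group_index(name: str) -> int:
    """Assign a trainable parameter to one of the four optimizer groups by
    substring matching, mirroring the two axes above: weight decay
    (no "bias"/"LayerNorm" in the name) x learning rate ("embed" in the
    name gets the small pretraining-components LR)."""
    no_decay = ("bias" in name) or ("LayerNorm" in name)
    small_lr = "embed" in name
    if not small_lr:
        return 2 if no_decay else 1
    return 4 if no_decay else 3

print(group_index("model.layers.0.self_attn.q_proj.lora_right_weight"))  # -> 1
print(group_index("model.embed_tokens.weight"))                          # -> 3
```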

LoRA Integration

LoRA adapters are optionally applied to specific layers:

Language decoder LoRA:
    --lang_lora_dim 16                    # rank of LoRA decomposition
    --lang_lora_module_name model.layers.  # target module scope

Vision encoder LoRA:
    --vis_lora_dim 16
    --vis_lora_module_name encoder.layers.

When --only_optimize_lora is set, all parameters except LoRA weights are frozen in the target module, providing maximum parameter efficiency.
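
As a rough check on adapter size: a rank-r LoRA adapter on a d_out x d_in weight adds r*(d_in + d_out) parameters. The sketch below is a hypothetical count assuming adapters on all four attention projections of every LLaMA-2-7B decoder layer; the ~10M figure quoted earlier corresponds to a smaller set of wrapped modules, which is controlled by --lang_lora_module_name:

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    # A (rank x d_in) down-projection plus B (d_out x rank) up-projection.
    return rank * d_in + d_out * rank

# Hypothetical example: rank-16 adapters on the four attention projections
# (4096 x 4096 in LLaMA-2-7B) across 32 decoder layers.
per_matrix = lora_param_count(4096, 4096, 16)   # 131,072 per wrapped matrix
total = per_matrix * 4 * 32
print(f"{total / 1e6:.1f}M LoRA parameters")     # 16.8M LoRA parameters
```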

Training Loop

The training loop follows the standard DeepSpeed pattern:

for epoch in range(num_train_epochs):
    for step, batch in enumerate(train_dataloader):
        batch = to_device(batch, device)
        loss = model(images, input_ids, attention_mask, labels, image_num)[0]
        model.backward(loss)    # DeepSpeed handles gradient accumulation
        model.step()            # DeepSpeed handles optimizer step + ZeRO sync

    # Epoch-end: fuse LoRA, save, unfuse LoRA
    model = fuse_lora(model)
    save_hf_format(model, tokenizer, args, f'epoch-{epoch}')
    if args.zero_stage == 3:
        save_zero_three_model(model, global_rank, output_dir, ...)
    model = unfuse_lora(model)

Learning Rate Schedule

The learning rate follows a cosine schedule with warmup:

from transformers import get_scheduler

lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_epochs * steps_per_epoch
)
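
The schedule's shape is simple enough to sketch directly. This is a plain-Python version of cosine decay with linear warmup, matching the general shape of the "cosine" scheduler (the library's exact endpoint handling may differ slightly):

```python
import math

def cosine_warmup_lr(step: int, warmup_steps: int,
                     total_steps: int, base_lr: float) -> float:
    """Linear warmup to base_lr, then cosine decay toward zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

halfway_up = cosine_warmup_lr(50, 100, 1000, 2e-4)   # mid-warmup: 1e-4
final = cosine_warmup_lr(1000, 100, 1000, 2e-4)      # decayed to 0.0
```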

Warmup can be specified as either:

  • An absolute step count (if > 1): --num_warmup_steps 100
  • A ratio of total steps (if <= 1): --num_warmup_steps 0.03
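
This dual interpretation can be resolved with a small helper (a sketch of the convention described above, not the repository's exact code):

```python
def resolve_warmup_steps(num_warmup_steps: float, total_steps: int) -> int:
    """Interpret values > 1 as an absolute step count and values <= 1
    as a fraction of the total number of training steps."""
    if num_warmup_steps > 1:
        return int(num_warmup_steps)
    return int(num_warmup_steps * total_steps)

print(resolve_warmup_steps(100, 5000))   # -> 100 (absolute count)
print(resolve_warmup_steps(0.03, 5000))  # -> 150 (3% of total steps)
```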

Checkpoint Resumption

The training state is fully recoverable via DeepSpeed checkpoints:

import random

import numpy as np
import torch

client_state = {
    'random_rng_state': random.getstate(),
    'np_rng_state': np.random.get_state(),
    'torch_rng_state': torch.get_rng_state(),
    'torch_cuda_rng_state': torch.cuda.get_rng_state(),
    'epoch': epoch + 1,
    'best_loss': best_loss,
}
model.save_checkpoint(output_dir, client_state=client_state)

This saves all RNG states to ensure exact reproducibility when resuming.
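
The stdlib `random` module alone illustrates why storing RNG state gives exact resumption (a minimal sketch; the real checkpoint additionally captures the NumPy and torch/CUDA generator states):

```python
import random

random.seed(42)
saved_state = random.getstate()          # what the checkpoint would store

run_a = [random.random() for _ in range(3)]

random.setstate(saved_state)             # "resume" from the checkpoint
run_b = [random.random() for _ in range(3)]

print(run_a == run_b)  # True: identical draws after restoring the state
```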

Key Considerations

  • Mixed precision -- Training uses either FP16 or BF16 (controlled by --precision). FP16 is recommended for typical use; BF16 for larger models to avoid overflow.
  • Gradient clipping -- Gradients are clipped to a max norm of 1.0 (gradient_clipping: 1.0) to prevent training instability.
  • Batch size configuration -- The effective batch size is per_device_batch_size * world_size * gradient_accumulation_steps. DeepSpeed manages gradient accumulation internally.
  • Evaluation frequency -- Evaluation runs once per epoch on the held-out eval split using the evaluation() function with model.eval() and torch.no_grad().
  • TensorBoard logging -- Optional TensorBoard integration is available via --enable_tensorboard.
  • CPU offloading -- ZeRO supports offloading optimizer states and parameters to CPU when GPU memory is insufficient (controlled in get_train_ds_config).
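
The batch-size arithmetic from the bullet above can be checked directly (a trivial sketch; the example numbers are hypothetical):

```python
def effective_batch_size(per_device: int, world_size: int,
                         accum_steps: int) -> int:
    # Samples seen per optimizer step across all GPUs and accumulation steps.
    return per_device * world_size * accum_steps

print(effective_batch_size(4, 8, 2))  # -> 64 (4 per GPU, 8 GPUs, 2 accum steps)
```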
