
Principle:Zai org CogVideo SAT Training Execution

From Leeroopedia


Metadata

Page Type: Principle
Knowledge Sources: CogVideo, SwissArmyTransformer
Domains: Training, Distributed_Computing
Last Updated: 2026-02-10 00:00 GMT

Overview

Technique for executing distributed training of video diffusion models using the SwissArmyTransformer (SAT) framework with DeepSpeed integration.

Description

SAT training execution wraps the model, data, and forward step functions into a DeepSpeed-managed training loop. The process is orchestrated by the training_main function from SwissArmyTransformer, which handles all aspects of distributed training: initialization, data loading, forward/backward passes, gradient accumulation, checkpointing, and optional evaluation.

Training Pipeline

The training execution follows this pipeline:

1. Entry Point and Launch

Training is launched via shell scripts that set up the distributed environment:

  • Single GPU: finetune_single_gpu.sh sets WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 and calls python train_video.py directly.
  • Multi GPU: finetune_multi_gpus.sh uses torchrun --standalone --nproc_per_node=8 to launch distributed processes.

Both scripts pass --base with YAML config files and --seed $RANDOM for stochastic variation.

2. Argument Parsing and Dataset Creation

The train_video.py main block:

  1. Parses arguments via get_args, which loads and merges YAML configs.
  2. Resolves the dataset class from args.data_config["target"] (e.g., data_video.SFTDataset).
  3. Creates a partial function create_dataset_function that will be called by training_main to instantiate the dataset.
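A minimal sketch of this target-resolution step, using only the standard library. The `data_config` dict and the `collections.OrderedDict` target are illustrative stand-ins for the real YAML-derived config and `data_video.SFTDataset`:

```python
import importlib
from functools import partial

def resolve_target(dotted_path):
    """Resolve a dotted 'module.ClassName' string to the class object."""
    module_name, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

# Illustrative config; the real one is merged from the --base YAML files.
data_config = {"target": "collections.OrderedDict", "params": {}}

dataset_cls = resolve_target(data_config["target"])

# training_main later calls this partial to instantiate the dataset.
create_dataset_function = partial(dataset_cls, **data_config["params"])
```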

3. training_main Orchestration

The SAT training_main function manages the complete training lifecycle:

  1. Model construction: Instantiates SATVideoDiffusionEngine(args).
  2. DeepSpeed wrapping: Wraps the model with a DeepSpeed engine using the config from args.deepspeed_config.
  3. Data loading: Creates the dataset and distributed data loader.
  4. Training loop: Iterates for args.train_iters iterations or args.epochs epochs, calling the custom forward_step_function for each batch.
  5. Evaluation: Periodically calls forward_step_eval at args.eval_interval steps to log validation metrics and generate sample videos.
  6. Checkpointing: Saves model state at args.save_interval steps to args.save directory.
  7. Logging: Reports loss and timing metrics at args.log_interval steps.
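The lifecycle above can be sketched as a plain-Python control loop. Everything here (the stub `forward_step`, the dict of intervals, the toy dataset) is illustrative; the real loop lives inside SAT's `training_main`, and a DeepSpeed engine handles the backward pass, optimizer step, and checkpoint I/O:

```python
def training_main(args, model, forward_step, create_dataset_function):
    """Skeleton of the SAT training loop (stubs stand in for DeepSpeed)."""
    dataset = create_dataset_function()
    data_iter = iter(dataset)
    losses = []
    for step in range(1, args["train_iters"] + 1):
        loss, loss_dict = forward_step(data_iter, model)
        losses.append(loss)          # backward + optimizer step: DeepSpeed
        if step % args["log_interval"] == 0:
            print(f"step {step}: loss {loss:.4f}")
        if step % args["eval_interval"] == 0:
            pass                     # forward_step_eval would run here
        if step % args["save_interval"] == 0:
            pass                     # checkpoint saved to args['save']
    return losses

# Toy stand-ins that exercise the control flow only.
args = {"train_iters": 4, "log_interval": 2, "eval_interval": 4, "save_interval": 4}
losses = training_main(
    args, object(),
    forward_step=lambda it, m: (next(it), {}),
    create_dataset_function=lambda: [0.9, 0.7, 0.5, 0.3],
)
```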

4. Custom Forward Step

The custom forward_step function implements the video diffusion training objective:

  1. Data loading: Retrieves the next batch from the data iterator (on model-parallel rank 0).
  2. Broadcasting: Broadcasts batch data across model-parallel ranks via broad_cast_batch to ensure all ranks in a model-parallel group see the same data.
  3. Config saving: On the first iteration (rank 0 only), saves the merged training config as training_config.yaml in the checkpoint directory.
  4. Forward pass: Calls model.shared_step(batch), which:
    • Extracts the video tensor from the batch.
    • Permutes dimensions from [B, T, C, H, W] to [B, C, T, H, W].
    • Encodes the video to latent space via the frozen VAE.
    • For I2V: extracts and noises the first frame, encodes it, and stores as batch["concat_images"].
    • Calls self.forward(x, batch), which invokes loss_fn(model, denoiser, conditioner, x, batch).
  5. Loss return: Returns the scalar loss and loss dictionary.
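A dependency-free sketch of this control flow. The stub engine stands in for SATVideoDiffusionEngine, JSON stands in for the YAML config dump, and the placeholder loss is invented for illustration:

```python
import json
import os
import tempfile

class StubEngine:
    """Stand-in for SATVideoDiffusionEngine; the loss here is a placeholder."""
    def shared_step(self, batch):
        loss = float(len(batch["mp4"]))
        return loss, {"loss": loss}

first_iteration = True

def forward_step(data_iterator, model, args, mp_rank=0):
    global first_iteration
    # Steps 1-2: load on model-parallel rank 0; broad_cast_batch would
    # then replicate the batch across the model-parallel group.
    batch = next(data_iterator) if mp_rank == 0 else None
    # Step 3: first iteration, rank 0 only - persist the merged config.
    if first_iteration and mp_rank == 0:
        with open(os.path.join(args["save"], "training_config.json"), "w") as f:
            json.dump(args, f)   # the real code writes training_config.yaml
        first_iteration = False
    # Steps 4-5: forward pass, return scalar loss and loss dict.
    return model.shared_step(batch)

args = {"save": tempfile.mkdtemp()}
batches = iter([{"mp4": [0, 1], "txt": ["a cat riding a skateboard"]}])
loss, loss_dict = forward_step(batches, StubEngine(), args)
```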

5. Evaluation Forward Step

The forward_step_eval function extends the training forward step with video generation logging:

  1. Computes the training loss (same as forward_step).
  2. On data-parallel rank 0, calls model.log_video(batch) to generate sample videos using the current model weights.
  3. Saves generated videos as mp4 files and optionally logs them to wandb.
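A compact sketch of the eval step's rank gating. The stub loss and the sample-file naming are invented for illustration; only data-parallel rank 0 pays the cost of sampling:

```python
def sample_loss(batch):
    # invented stand-in for the diffusion loss computed by shared_step
    return 0.5

def log_video(batch):
    # invented stand-in for model.log_video; returns a saved-file name
    return f"sample_{batch['id']:04d}.mp4"

def forward_step_eval(batch, dp_rank, logged_videos):
    """Eval = training loss, plus video generation on data-parallel rank 0."""
    loss = sample_loss(batch)
    if dp_rank == 0:                 # only one rank renders and saves mp4s
        logged_videos.append(log_video(batch))
    return loss

videos = []
loss = forward_step_eval({"id": 3}, dp_rank=0, logged_videos=videos)
forward_step_eval({"id": 3}, dp_rank=1, logged_videos=videos)  # no new video
```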

Distributed Communication

The broad_cast_batch function handles model-parallel data distribution:

  1. Rank 0 of each model-parallel group loads and preprocesses the batch.
  2. Batch tensor shapes are broadcast via torch.distributed.broadcast_object_list.
  3. Non-rank-0 processes allocate zero tensors of the correct shapes.
  4. Batch tensors (mp4, fps, num_frames) are broadcast via torch.distributed.broadcast.
  5. Text data (txt) is broadcast as a Python object list.
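The two-phase protocol above can be simulated in a single process. `broadcast_object` below is a stand-in for the torch.distributed collectives, and the batch keys mirror those listed above; this is a sketch of the protocol, not the real multi-process implementation:

```python
def broadcast_object(obj, src_value):
    """Stand-in for a distributed broadcast: every rank gets src's value."""
    return src_value

def simulate_broad_cast_batch(rank, src_batch):
    if rank == 0:
        return src_batch                       # rank 0 already holds the data
    # Phase 1: receive tensor shapes, then allocate zero buffers for them.
    shapes = broadcast_object(None, {k: len(v) for k, v in src_batch.items()
                                     if k != "txt"})
    batch = {k: [0.0] * n for k, n in shapes.items()}
    # Phase 2: receive tensor contents (mp4, fps, num_frames).
    for k in batch:
        batch[k] = broadcast_object(batch[k], src_batch[k])
    # Text is broadcast as a Python object list, not a tensor.
    batch["txt"] = broadcast_object(None, src_batch["txt"])
    return batch

src = {"mp4": [1.0, 2.0], "fps": [8], "num_frames": [49], "txt": ["a dog"]}
replica = simulate_broad_cast_batch(rank=1, src_batch=src)
```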

Usage

Use to launch SAT-based CogVideoX training. The standard invocation is via the provided shell scripts:

Single GPU

cd sat/
bash finetune_single_gpu.sh

Multi GPU (8 GPUs)

cd sat/
bash finetune_multi_gpus.sh

Custom Launch

WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 python train_video.py \
    --base configs/cogvideox_2b_lora.yaml configs/sft.yaml \
    --seed 42

Theoretical Basis

DeepSpeed ZeRO Optimization

DeepSpeed provides ZeRO (Zero Redundancy Optimizer) state partitioning that distributes memory consumption across data-parallel ranks. The SAT pipeline typically uses ZeRO Stage 2, which partitions both optimizer states and gradients:

  • Stage 1: Partitions optimizer states (e.g., Adam momentum and variance) across ranks. Each rank only stores 1/N of the optimizer state.
  • Stage 2: Additionally partitions gradients. After the backward pass, gradients are reduced and scattered so each rank only stores its partition.

With mixed-precision Adam, each parameter costs roughly 16 bytes: 2 for the fp16 weights, 2 for the fp16 gradients, and 12 for the optimizer states (fp32 master weights, momentum, and variance). Standard data parallelism replicates all 16 bytes per parameter on every rank. ZeRO Stage 2 keeps only the 2-byte weights replicated and partitions the remaining 14 bytes, giving a per-GPU footprint of approximately 2x + 14x/N, where x is the parameter count and N is the number of data-parallel ranks.
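Plugging in illustrative numbers (a 2B-parameter model across 8 data-parallel ranks):

```python
def per_gpu_bytes_stage2(num_params, num_ranks):
    """Approximate per-GPU memory for mixed-precision Adam under ZeRO Stage 2.

    Per parameter: 2 B fp16 weights (replicated on every rank), plus
    2 B fp16 gradients and 12 B optimizer state (partitioned across ranks).
    """
    return 2 * num_params + (2 + 12) * num_params / num_ranks

psi = 2e9                                # 2B-parameter model
baseline = 16 * psi                      # plain data parallelism: ~32 GB/GPU
stage2 = per_gpu_bytes_stage2(psi, 8)    # 8 ranks: ~7.5 GB/GPU
```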

Diffusion Training Objective

The forward step implements the standard denoising diffusion training objective: given a clean latent z_0, sample a noise level sigma from the discretization schedule, add noise to produce z_t = z_0 + sigma * epsilon, and train the model to predict either the noise epsilon or the clean sample z_0 (depending on the parameterization). The loss is computed as a weighted mean squared error between the prediction and the target, with weights determined by the EpsWeighting schedule.
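A pure-Python sketch of the eps-parameterized objective. The oracle denoiser and the constant weight are illustrative (the real weighting comes from the EpsWeighting schedule, and the real code operates on tensors):

```python
import random

def diffusion_loss(z0, sigma, weight, denoiser):
    """Weighted MSE between the denoiser's noise prediction and the
    sampled noise (eps-parameterization)."""
    eps = [random.gauss(0.0, 1.0) for _ in z0]
    zt = [z + sigma * e for z, e in zip(z0, eps)]   # z_t = z_0 + sigma * eps
    pred = denoiser(zt, sigma)
    return weight * sum((p - e) ** 2 for p, e in zip(pred, eps)) / len(z0)

random.seed(0)
z0 = [0.2, -0.4, 0.1]
# An oracle that inverts the noising exactly drives the loss to ~0.
oracle = lambda zt, sigma: [(zt_i - z0_i) / sigma
                            for zt_i, z0_i in zip(zt, z0)]
loss_oracle = diffusion_loss(z0, sigma=1.5, weight=1.0, denoiser=oracle)
```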

Model Parallelism and Data Broadcasting

When model parallelism is used (model split across multiple GPUs), each model-parallel group processes the same data but different model partitions. The broad_cast_batch function ensures all ranks in a model-parallel group receive identical input data, while different data-parallel groups receive different batches for gradient diversity.
