# Principle: zai-org/CogVideo SAT Training Execution
## Metadata
| Field | Value |
|---|---|
| Page Type | Principle |
| Knowledge Sources | CogVideo, SwissArmyTransformer |
| Domains | Training, Distributed_Computing |
| Last Updated | 2026-02-10 00:00 GMT |
## Overview
Technique for executing distributed training of video diffusion models using the SAT framework with DeepSpeed integration.
## Description
SAT training execution wraps the model, data, and forward step functions into a DeepSpeed-managed training loop. The process is orchestrated by the training_main function from SwissArmyTransformer, which handles all aspects of distributed training: initialization, data loading, forward/backward passes, gradient accumulation, checkpointing, and optional evaluation.
## Training Pipeline
The training execution follows this pipeline:
### 1. Entry Point and Launch
Training is launched via shell scripts that set up the distributed environment:
- Single GPU: `finetune_single_gpu.sh` sets `WORLD_SIZE=1 RANK=0 LOCAL_RANK=0` and calls `python train_video.py` directly.
- Multi GPU: `finetune_multi_gpus.sh` uses `torchrun --standalone --nproc_per_node=8` to launch distributed processes.

Both scripts pass `--base` with YAML config files and `--seed $RANDOM` for stochastic variation.
### 2. Argument Parsing and Dataset Creation
The `train_video.py` main block:

- Parses arguments via `get_args`, which loads and merges YAML configs.
- Resolves the dataset class from `args.data_config["target"]` (e.g., `data_video.SFTDataset`).
- Creates a partial function `create_dataset_function` that will be called by `training_main` to instantiate the dataset.
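The resolve-and-partial pattern above can be sketched in plain Python. The `resolve_class` helper and the `Args` stub are illustrative stand-ins, not the actual SAT code (which uses `get_args` and classes like `data_video.SFTDataset`):

```python
from functools import partial
import importlib

def resolve_class(path):
    """Resolve a dotted path like 'data_video.SFTDataset' to a class."""
    module_name, class_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

def create_dataset_function(dataset_cls, path, args):
    # training_main calls this later with a data path and the parsed args
    return dataset_cls(path, args)

# In the real script this comes from get_args(); a stub with a stand-in
# target path is used here so the example is self-contained.
class Args:
    data_config = {"target": "collections.OrderedDict"}

args = Args()
dataset_cls = resolve_class(args.data_config["target"])
make_dataset = partial(create_dataset_function, dataset_cls)
```

`training_main` later invokes the partial with the dataset path and args, so the dataset class is bound once at startup but instantiated inside the training harness.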
### 3. `training_main` Orchestration
The SAT `training_main` function manages the complete training lifecycle:

- Model construction: Instantiates `SATVideoDiffusionEngine(args)`.
- DeepSpeed wrapping: Wraps the model with a DeepSpeed engine using the config from `args.deepspeed_config`.
- Data loading: Creates the dataset and distributed data loader.
- Training loop: Iterates for `args.train_iters` iterations or `args.epochs` epochs, calling the custom `forward_step_function` for each batch.
- Evaluation: Periodically calls `forward_step_eval` at `args.eval_interval` steps to log validation metrics and generate sample videos.
- Checkpointing: Saves model state at `args.save_interval` steps to the `args.save` directory.
- Logging: Reports loss and timing metrics at `args.log_interval` steps.
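A stripped-down sketch of this lifecycle, assuming a dict-based `args` and a dummy `forward_step`; the real `training_main` additionally handles DeepSpeed wrapping, distributed data loading, backward passes, and checkpoint I/O:

```python
# Minimal sketch of the training_main loop structure, with stand-ins
# for the model, forward step, and data iterator.
def training_main(args, model, forward_step, data_iter):
    losses = []
    for step in range(1, args["train_iters"] + 1):
        loss = forward_step(data_iter, model)   # forward/backward in real SAT
        losses.append(loss)
        if step % args["log_interval"] == 0:
            print(f"step {step}: loss={loss:.4f}")  # logging interval
        if step % args["save_interval"] == 0:
            pass  # checkpoint model state to args['save'] in the real loop
    return losses

args = {"train_iters": 4, "log_interval": 2, "save_interval": 4}
data = iter([1.0, 0.8, 0.6, 0.5])
losses = training_main(args, model=None,
                       forward_step=lambda it, m: next(it), data_iter=data)
```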
### 4. Custom Forward Step
The custom `forward_step` function implements the video diffusion training objective:

- Data loading: Retrieves the next batch from the data iterator (on model-parallel rank 0).
- Broadcasting: Broadcasts batch data across model-parallel ranks via `broad_cast_batch` to ensure all ranks in a model-parallel group see the same data.
- Config saving: On the first iteration (rank 0 only), saves the merged training config as `training_config.yaml` in the checkpoint directory.
- Forward pass: Calls `model.shared_step(batch)`, which:
  - Extracts the video tensor from the batch.
  - Permutes dimensions from `[B, T, C, H, W]` to `[B, C, T, H, W]`.
  - Encodes the video to latent space via the frozen VAE.
  - For I2V: extracts and noises the first frame, encodes it, and stores it as `batch["concat_images"]`.
  - Calls `self.forward(x, batch)`, which invokes `loss_fn(model, denoiser, conditioner, x, batch)`.
- Loss return: Returns the scalar loss and loss dictionary.
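The dimension permute at the heart of `shared_step` can be illustrated with NumPy; the frozen VAE encode is replaced by an identity stand-in, and the loss call is only noted in a comment:

```python
# Sketch of the shared_step data flow: permute [B, T, C, H, W] -> [B, C, T, H, W]
# before VAE encoding.
import numpy as np

def shared_step(batch):
    video = batch["mp4"]                          # [B, T, C, H, W]
    video = np.transpose(video, (0, 2, 1, 3, 4))  # [B, C, T, H, W]
    latents = video                               # stand-in for frozen VAE encode
    # loss_fn(model, denoiser, conditioner, latents, batch) in the real code
    return latents.shape

batch = {"mp4": np.zeros((2, 9, 3, 32, 32))}      # 2 clips, 9 frames, RGB
print(shared_step(batch))  # (2, 3, 9, 32, 32)
```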
### 5. Evaluation Forward Step
The `forward_step_eval` function extends the training forward step with video generation logging:

- Computes the training loss (same as `forward_step`).
- On data-parallel rank 0, calls `model.log_video(batch)` to generate sample videos using the current model weights.
- Saves generated videos as MP4 files and optionally logs them to wandb.
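The rank-0 gating can be sketched with the model and loss computation stubbed out; the dict-based model and the literal file name are illustrative only:

```python
# Sketch of forward_step_eval: reuse the training loss, and only on
# data-parallel rank 0 generate sample videos (stubbed here).
def forward_step_eval(batch, model, forward_step, dp_rank):
    loss = forward_step(batch, model)
    videos = model["log_video"](batch) if dp_rank == 0 else None
    return loss, videos

model = {"log_video": lambda b: ["sample.mp4"]}     # stand-in generator
loss, vids = forward_step_eval({"mp4": []}, model,
                               forward_step=lambda b, m: 0.5, dp_rank=0)
```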
## Distributed Communication
The `broad_cast_batch` function handles model-parallel data distribution:

- Rank 0 of each model-parallel group loads and preprocesses the batch.
- Batch tensor shapes are broadcast via `torch.distributed.broadcast_object_list`.
- Non-rank-0 processes allocate zero tensors of the correct shapes.
- Batch tensors (`mp4`, `fps`, `num_frames`) are broadcast via `torch.distributed.broadcast`.
- Text data (`txt`) is broadcast as a Python object list.
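The shape-then-contents pattern can be simulated without a process group; `fake_broadcast` stands in for `torch.distributed.broadcast`, and only the control flow mirrors the real `broad_cast_batch`:

```python
# Simulated model-parallel broadcast: rank 0 owns the data; other ranks
# allocate zero buffers of the broadcast shapes, then receive the contents.
import numpy as np

def fake_broadcast(dst, src):
    """Stand-in for torch.distributed.broadcast: copy src into dst in place."""
    np.copyto(dst, src)

def broad_cast_batch(rank, src_batch):
    # Step 1: shapes travel first (broadcast_object_list in the real code;
    # here every rank can read them directly from the simulation input).
    shapes = {k: v.shape for k, v in src_batch.items()}
    if rank == 0:
        batch = src_batch
    else:
        # Step 2: non-rank-0 allocates zero tensors of the right shapes
        batch = {k: np.zeros(s) for k, s in shapes.items()}
    # Step 3: broadcast tensor contents from rank 0 into each buffer
    for k in batch:
        fake_broadcast(batch[k], src_batch[k])
    return batch

src = {"mp4": np.ones((1, 4, 3, 8, 8)), "num_frames": np.array([4.0])}
out = broad_cast_batch(rank=1, src_batch=src)
```

Text fields are skipped here because they are variable-length Python objects; the real code ships them with `broadcast_object_list` rather than preallocated buffers.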
## Usage
Use to launch SAT-based CogVideoX training. The standard invocation is via the provided shell scripts:
### Single GPU

```bash
cd sat/
bash finetune_single_gpu.sh
```

### Multi GPU (8 GPUs)

```bash
cd sat/
bash finetune_multi_gpus.sh
```
### Custom Launch

```bash
WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 python train_video.py \
  --base configs/cogvideox_2b_lora.yaml configs/sft.yaml \
  --seed 42
```
## Theoretical Basis
### DeepSpeed ZeRO Optimization
DeepSpeed provides ZeRO (Zero Redundancy Optimizer) state partitioning that distributes memory consumption across data-parallel ranks. The SAT pipeline typically uses ZeRO Stage 2, which partitions both optimizer states and gradients:
- Stage 1: Partitions optimizer states (e.g., Adam momentum and variance) across ranks. Each rank only stores 1/N of the optimizer state.
- Stage 2: Additionally partitions gradients. After the backward pass, gradients are reduced and scattered so each rank only stores its partition.
With mixed-precision Adam, standard data parallelism replicates roughly 16x bytes per GPU, where x is the parameter count: 2x for fp16 parameters, 2x for fp16 gradients, and 12x for fp32 optimizer state (master weights, momentum, and variance). ZeRO Stage 2 reduces this to approximately 2x + (2x + 12x)/N, where N is the number of data-parallel ranks: gradients and optimizer states are partitioned, while the fp16 parameters remain replicated.
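Using the ZeRO paper's accounting for mixed-precision Adam (2x fp16 parameters, 2x fp16 gradients, and about 12x fp32 optimizer state for a model with x parameters), the per-rank arithmetic can be sketched as:

```python
def per_rank_bytes(params, n_ranks, stage):
    """Approximate per-rank training-state bytes for a model with `params`
    parameters, assuming fp16 params/grads and fp32 Adam states (K = 12)."""
    p, g, opt = 2 * params, 2 * params, 12 * params
    if stage == 0:   # plain data parallelism: everything replicated
        return p + g + opt
    if stage == 1:   # optimizer states partitioned across ranks
        return p + g + opt / n_ranks
    if stage == 2:   # gradients also partitioned
        return p + (g + opt) / n_ranks
    raise ValueError(stage)

# 2B-parameter model on 8 data-parallel ranks:
baseline = per_rank_bytes(2e9, 8, stage=0)   # ~32 GB per GPU
zero2 = per_rank_bytes(2e9, 8, stage=2)      # ~7.5 GB per GPU
```

For a 2B-parameter model on 8 ranks, Stage 2 drops per-GPU training state from roughly 32 GB to 7.5 GB, before counting activations.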
### Diffusion Training Objective
The forward step implements the standard denoising diffusion training objective: given a clean latent z_0, sample a noise level sigma from the discretization schedule, add noise to produce z_t = z_0 + sigma * epsilon, and train the model to predict either the noise epsilon or the clean sample z_0 (depending on the parameterization). The loss is computed as a weighted mean squared error between the prediction and the target, with weights determined by the EpsWeighting schedule.
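A toy NumPy version of this objective, with a stand-in predictor and uniform weights instead of the EpsWeighting schedule:

```python
# Toy denoising objective: sample sigma, form z_t = z_0 + sigma * eps,
# predict eps, and take a (here uniformly) weighted MSE.
import numpy as np

rng = np.random.default_rng(0)

def diffusion_loss(z0, predict_eps, sigmas):
    sigma = rng.choice(sigmas)                 # sample a noise level
    eps = rng.standard_normal(z0.shape)        # Gaussian noise
    z_t = z0 + sigma * eps                     # noised latent
    pred = predict_eps(z_t, sigma)             # model predicts epsilon
    return float(np.mean((pred - eps) ** 2))   # weighted MSE (uniform here)

z0 = np.zeros((4, 8))
# Stand-in predictor: with z0 = 0, z_t / sigma recovers eps exactly.
loss = diffusion_loss(z0, predict_eps=lambda z, s: z / s,
                      sigmas=[0.5, 1.0, 2.0])
```

With the exact stand-in predictor the loss is zero; a real denoiser is trained to drive this weighted error down across the sigma schedule.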
### Model Parallelism and Data Broadcasting
When model parallelism is used (model split across multiple GPUs), each model-parallel group processes the same data but different model partitions. The broad_cast_batch function ensures all ranks in a model-parallel group receive identical input data, while different data-parallel groups receive different batches for gradient diversity.