Implementation: Zai org CogVideo SAT Training Main
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (Wrapper Doc) |
| Knowledge Sources | CogVideo, SwissArmyTransformer |
| Domains | Training, Distributed_Computing |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Concrete tool for executing SAT distributed training using the SwissArmyTransformer training_main function, with custom forward step functions for the CogVideoX video diffusion objective.
Description
The train_video.py module provides the entry point for SAT-based CogVideoX training. It defines custom forward_step and forward_step_eval functions that implement the video diffusion training and evaluation logic, then delegates the complete training lifecycle to SAT's training_main.
The module's main block:
- Handles OpenMPI environment variable mapping (`OMPI_COMM_WORLD_LOCAL_RANK` to `LOCAL_RANK`, etc.).
- Parses CLI arguments and YAML configs via `get_args`.
- Resolves the dataset class from `args.data_config["target"]` and creates a partial `create_dataset_function`.
- Loads the raw YAML configs for logging purposes.
- Calls `training_main` with the model class, forward step functions, and dataset creation function.
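The environment variable mapping in the first step can be sketched as follows. This is a minimal illustration, not the module's actual code; the exact set of variables mapped (and the helper name) are assumptions, so check `train_video.py` for the authoritative list.

```python
import os

# Assumed mapping of OpenMPI launcher variables to the names that
# torch.distributed / torchrun expect (illustrative, not exhaustive).
OMPI_TO_TORCH = {
    "OMPI_COMM_WORLD_LOCAL_RANK": "LOCAL_RANK",
    "OMPI_COMM_WORLD_RANK": "RANK",
    "OMPI_COMM_WORLD_SIZE": "WORLD_SIZE",
}

def map_openmpi_env(env: dict) -> dict:
    """Copy OpenMPI rank variables to the torch-style names,
    without overwriting values the user has already set."""
    for ompi_key, torch_key in OMPI_TO_TORCH.items():
        if ompi_key in env and torch_key not in env:
            env[torch_key] = env[ompi_key]
    return env

# In the real module this would operate on os.environ:
# map_openmpi_env(os.environ)
```

This lets the same script run under both `mpirun` and `torchrun` launchers without duplicating rank configuration.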
The forward_step function:
- Loads the next batch from the data iterator on model-parallel rank 0.
- Moves all tensor values in the batch to CUDA.
- On rank 0, saves the merged training config as `training_config.yaml` (first iteration only).
- Broadcasts the batch across model-parallel ranks via `broad_cast_batch`.
- Calls `model.shared_step(batch)` to compute the diffusion loss.
- Returns the scalar loss and loss dictionary.
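The "move tensors to CUDA" step can be sketched device-agnostically as below. This is a simplified stand-in for the module's actual logic (which operates on `torch.Tensor` values); the function name is hypothetical.

```python
def batch_to_device(batch: dict, device: str = "cuda") -> dict:
    """Move every tensor-valued entry of a batch dict to the target
    device, leaving non-tensor metadata (strings, ints) untouched.
    Duck-typed on .to() so the sketch runs without torch installed."""
    return {
        key: value.to(device) if hasattr(value, "to") else value
        for key, value in batch.items()
    }
```

The real forward_step applies the equivalent of `batch_to_device(batch, "cuda")` before broadcasting the batch to the other model-parallel ranks.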
The forward_step_eval function extends the training step with video logging:
- Reshapes multi-view batches if needed (6D tensors).
- Calls `log_video` on data-parallel rank 0 to generate and save sample videos.
- Computes and returns the validation loss.
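The multi-view reshape folds the view dimension into the batch dimension so a 6D batch can pass through the 5D video pipeline. The sketch below shows the shape arithmetic only; the exact dimension order `(B, V, T, C, H, W)` is an assumption, not confirmed by the source.

```python
def merge_view_dim(shape: tuple) -> tuple:
    """Collapse an assumed 6D multi-view layout (B, V, T, C, H, W)
    into (B*V, T, C, H, W); pass 5D single-view shapes through."""
    if len(shape) != 6:
        return tuple(shape)
    batch, views, *rest = shape
    return (batch * views, *rest)
```

In the real code the same fold would be done with a tensor `reshape`/`view` call rather than on the shape tuple.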
Usage
This module is the standard entry point for all SAT-based CogVideoX training. It is invoked via the shell scripts finetune_single_gpu.sh or finetune_multi_gpus.sh, or directly via torchrun.
Code Reference
Source Location
- sat/train_video.py:L205-240 (main block)
- sat/train_video.py:L178-202 (forward_step)
- sat/train_video.py:L144-175 (forward_step_eval)
- sat/train_video.py:L115-141 (broad_cast_batch)
Signature
```python
from sat.training.deepspeed_training import training_main

# Main training invocation
training_main(
    args,
    model_cls=SATVideoDiffusionEngine,
    forward_step_function=partial(forward_step, data_class=data_class),
    forward_step_eval=partial(
        forward_step_eval,
        data_class=data_class,
        only_log_video_latents=args.only_log_video_latents,
    ),
    create_dataset_function=create_dataset_function,
)
```
```python
def forward_step(data_iterator, model, args, timers, data_class=None) -> Tuple[torch.Tensor, Dict]:
    """
    Custom forward step computing the video diffusion loss.

    Args:
        data_iterator: Iterator yielding training batches.
        model: SATVideoDiffusionEngine instance.
        args: Parsed training arguments.
        timers: SAT timer utilities.
        data_class: Dataset class (unused, kept for compatibility).

    Returns:
        loss: Scalar training loss tensor.
        loss_dict: Dictionary with loss breakdown.
    """

def forward_step_eval(data_iterator, model, args, timers,
                      only_log_video_latents=False, data_class=None) -> Tuple[torch.Tensor, Dict]:
    """
    Evaluation forward step with video logging.

    Additionally generates sample videos via model.log_video() and saves them
    as mp4 files (and optionally to wandb) on data-parallel rank 0.
    """
```
Import
```python
from sat.training.deepspeed_training import training_main
```
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `args` | `argparse.Namespace` | Yes | Complete training configuration from `get_args`. |
| `model_cls` | `type` | Yes | `SATVideoDiffusionEngine` class (not an instance). |
| `forward_step_function` | `Callable` | Yes | Custom forward step computing the diffusion loss. |
| `forward_step_eval` | `Callable` | No | Evaluation forward step with video logging. |
| `create_dataset_function` | `Callable` | Yes | Factory function for dataset creation, typically `partial(data_class.create_dataset_function, **args.data_config["params"])`. |
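The dataset-class resolution described above (`args.data_config["target"]` to a class, then a partial factory) follows a common dotted-path convention; a sketch under that assumption:

```python
import importlib
from functools import partial

def resolve_target(target: str):
    """Resolve a dotted 'package.module.ClassName' string to the
    class object, in the style of args.data_config["target"]."""
    module_path, cls_name = target.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), cls_name)

# Hypothetical usage mirroring the main block (names assumed):
# data_class = resolve_target(args.data_config["target"])
# create_dataset_function = partial(
#     data_class.create_dataset_function, **args.data_config["params"])
```

This keeps the dataset implementation swappable purely through the YAML config, with no code changes in train_video.py.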
Key training args:
| Argument | Default | Description |
|---|---|---|
| `args.train_iters` | 10000 | Number of training iterations (if epochs not set). |
| `args.epochs` | None | Number of training epochs (mutually exclusive with `train_iters`). |
| `args.batch_size` | From DeepSpeed config | Micro batch size per GPU per step. |
| `args.lr` | From DeepSpeed config | Learning rate. |
| `args.gradient_accumulation_steps` | From DeepSpeed config | Number of micro-batches before each optimizer step. |
| `args.save` | From YAML | Checkpoint output directory. |
| `args.save_interval` | From YAML | Steps between checkpoint saves. |
| `args.eval_interval` | From YAML | Steps between evaluation runs. |
| `args.zero_stage` | From YAML | DeepSpeed ZeRO optimization stage (0, 1, 2, or 3). |
Outputs
| Output | Type | Description |
|---|---|---|
| Checkpoints | Directory | Model checkpoints saved to `args.save` every `save_interval` steps; contain model weights, optimizer state, and training state. |
| Training config | YAML file | `training_config.yaml` saved in the checkpoint directory on the first iteration. |
| Video samples | mp4 files | Generated sample videos saved during evaluation to the `args.save/video/` directory. |
| Video latents | pt files | If `only_log_video_latents=True`, latent tensors are saved as `latent.pt` instead of decoded videos. |
Usage Examples
Single GPU Launch
```shell
# From sat/ directory
WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1 \
python train_video.py \
  --base configs/cogvideox_2b_lora.yaml configs/sft.yaml \
  --seed $RANDOM
```
Multi GPU Launch (8 GPUs)
```shell
# From sat/ directory
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
torchrun --standalone --nproc_per_node=8 \
  train_video.py \
  --base configs/cogvideox_5b_lora.yaml configs/sft.yaml \
  --seed $RANDOM
```
External Dependencies
- sat.training.deepspeed_training: Provides `training_main` for the complete distributed training lifecycle.
- deepspeed: Distributed training engine with ZeRO optimization.
- torch.distributed: PyTorch distributed communication primitives for broadcasting.
- wandb: Optional experiment tracking and video logging.
- imageio: Saving generated video samples as mp4 files.
- sat.mpu: Model parallelism utilities for rank management and group communication.