
Implementation:Zai org CogVideo SAT Training Main

From Leeroopedia


Metadata

Field Value
Page Type Implementation (Wrapper Doc)
Knowledge Sources CogVideo, SwissArmyTransformer
Domains Training, Distributed_Computing
Last Updated 2026-02-10 00:00 GMT

Overview

Concrete tool for executing SAT distributed training using the SwissArmyTransformer training_main function, with custom forward step functions for the CogVideoX video diffusion objective.

Description

The train_video.py module provides the entry point for SAT-based CogVideoX training. It defines custom forward_step and forward_step_eval functions that implement the video diffusion training and evaluation logic, then delegates the complete training lifecycle to SAT's training_main.

The module's main block:

  1. Handles OpenMPI environment variable mapping (OMPI_COMM_WORLD_LOCAL_RANK to LOCAL_RANK, etc.).
  2. Parses CLI arguments and YAML configs via get_args.
  3. Resolves the dataset class from args.data_config["target"] and creates a partial create_dataset_function.
  4. Loads the raw YAML configs for logging purposes.
  5. Calls training_main with the model class, forward step functions, and dataset creation function.
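Step 1 above can be condensed into a small sketch. The OpenMPI variable names follow the list; the helper name itself is illustrative, not from the source:

```python
import os

# Illustrative mapping of OpenMPI launcher variables to the names that
# torchrun/DeepSpeed expect (helper name is not from the source module).
_OMPI_TO_TORCH = {
    "OMPI_COMM_WORLD_LOCAL_RANK": "LOCAL_RANK",
    "OMPI_COMM_WORLD_RANK": "RANK",
    "OMPI_COMM_WORLD_SIZE": "WORLD_SIZE",
}

def map_openmpi_env() -> None:
    """Copy OpenMPI launcher variables to their torch.distributed names."""
    for ompi_name, torch_name in _OMPI_TO_TORCH.items():
        if ompi_name in os.environ:
            # setdefault: never clobber values an outer launcher already set
            os.environ.setdefault(torch_name, os.environ[ompi_name])
```

This lets the same entry point run unchanged whether it was launched by mpirun or by torchrun.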

The forward_step function:

  1. Loads the next batch from the data iterator on model-parallel rank 0.
  2. Moves all tensor values in the batch to CUDA.
  3. On rank 0, saves the merged training config as training_config.yaml (first iteration only).
  4. Broadcasts the batch across model-parallel ranks via broad_cast_batch.
  5. Calls model.shared_step(batch) to compute the diffusion loss.
  6. Returns the scalar loss and loss dictionary.
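The steps above can be sketched as follows. For illustration, the rank check, broadcast, and shared-step hooks are passed in explicitly; the real function obtains them from sat.mpu and the model instance, and runs on CUDA:

```python
from typing import Callable, Dict, Iterator, Tuple
import torch

def forward_step_sketch(
    data_iterator: Iterator[Dict],
    shared_step: Callable[[Dict], Tuple[torch.Tensor, Dict]],
    broadcast: Callable[[Dict], None],
    mp_rank: int,
    device: str = "cpu",  # "cuda" in the real module
) -> Tuple[torch.Tensor, Dict]:
    """Illustrative recreation of the forward_step flow (not the source code)."""
    batch: Dict = {}
    if mp_rank == 0:                       # 1. load batch on model-parallel rank 0
        batch = dict(next(data_iterator))
        for key, value in batch.items():   # 2. move tensor values to the device
            if torch.is_tensor(value):
                batch[key] = value.to(device)
    broadcast(batch)                       # 4. sync batch across model-parallel ranks
    loss, loss_dict = shared_step(batch)   # 5. diffusion loss via model.shared_step
    return loss, loss_dict                 # 6. scalar loss + breakdown dict
```

Step 3 (saving training_config.yaml on rank 0) is a one-time side effect and is omitted here.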

The forward_step_eval function extends the training step with video logging:

  1. Processes multi-view batch reshaping if needed (6D tensors).
  2. Calls log_video on data-parallel rank 0 to generate and save sample videos.
  3. Computes and returns the validation loss.
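The multi-view reshaping in step 1 can be illustrated as below; the exact axis layout (batch, views, channels, frames, height, width) is an assumption, and the source module's reshape may differ:

```python
import torch

def flatten_multiview(batch: dict) -> dict:
    """Merge the batch and view axes of any 6D tensor before evaluation.

    Assumes a (B, V, C, T, H, W) layout -- an illustrative convention,
    not confirmed by the source.
    """
    for key, value in batch.items():
        if torch.is_tensor(value) and value.dim() == 6:
            b, v = value.shape[:2]
            # (B, V, ...) -> (B*V, ...): each view becomes its own sample
            batch[key] = value.reshape(b * v, *value.shape[2:])
    return batch
```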

Usage

This module is the standard entry point for all SAT-based CogVideoX training. It is invoked via the shell scripts finetune_single_gpu.sh or finetune_multi_gpus.sh, or directly via torchrun.

Code Reference

Source Location

  • sat/train_video.py:L205-240 (main block)
  • sat/train_video.py:L178-202 (forward_step)
  • sat/train_video.py:L144-175 (forward_step_eval)
  • sat/train_video.py:L115-141 (broad_cast_batch)
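The broad_cast_batch helper listed above can be sketched with torch.distributed primitives. Here src and group are plain parameters; the real helper derives them from sat.mpu's model-parallel group and also communicates tensor shapes so non-source ranks can allocate buffers:

```python
import torch
import torch.distributed as dist

def broad_cast_batch_sketch(batch: dict, src: int = 0, group=None) -> dict:
    """Broadcast every tensor in the batch from the source rank.

    Sketch only: the real function resolves src/group via sat.mpu and
    handles shape metadata and non-tensor entries as well.
    """
    for key in sorted(batch):  # deterministic key order on every rank
        if torch.is_tensor(batch[key]):
            dist.broadcast(batch[key], src, group=group)
    return batch
```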

Signature

from sat.training.deepspeed_training import training_main

# Main training invocation
training_main(
    args,
    model_cls=SATVideoDiffusionEngine,
    forward_step_function=partial(forward_step, data_class=data_class),
    forward_step_eval=partial(
        forward_step_eval,
        data_class=data_class,
        only_log_video_latents=args.only_log_video_latents,
    ),
    create_dataset_function=create_dataset_function,
)

def forward_step(data_iterator, model, args, timers, data_class=None) -> Tuple[torch.Tensor, Dict]:
    """
    Custom forward step computing video diffusion loss.

    Args:
        data_iterator: Iterator yielding training batches.
        model: SATVideoDiffusionEngine instance.
        args: Parsed training arguments.
        timers: SAT timer utilities.
        data_class: Dataset class (unused, for compatibility).

    Returns:
        loss: Scalar training loss tensor.
        loss_dict: Dictionary with loss breakdown.
    """

def forward_step_eval(data_iterator, model, args, timers,
                      only_log_video_latents=False, data_class=None) -> Tuple[torch.Tensor, Dict]:
    """
    Evaluation forward step with video logging.

    Additionally generates sample videos via model.log_video() and saves them
    as mp4 files (and optionally to wandb) on data-parallel rank 0.
    """

Import

from sat.training.deepspeed_training import training_main

I/O Contract

Inputs

Parameter Type Required Description
args argparse.Namespace Yes Complete training configuration from get_args.
model_cls type Yes SATVideoDiffusionEngine class (not an instance).
forward_step_function Callable Yes Custom forward step computing the diffusion loss.
forward_step_eval Callable No Evaluation forward step with video logging.
create_dataset_function Callable Yes Factory function for dataset creation, typically partial(data_class.create_dataset_function, **args.data_config["params"]).
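The create_dataset_function row above can be illustrated by sketching how args.data_config["target"] is resolved into a class and bound to its YAML params. The dotted-path convention follows the Description; the helper name is illustrative:

```python
import importlib
from functools import partial

def resolve_dataset_factory(data_config: dict):
    """Resolve a dotted "module.Class" target and bind its YAML params.

    Sketch of the pattern described above (helper name not from the source).
    """
    module_name, class_name = data_config["target"].rsplit(".", 1)
    data_class = getattr(importlib.import_module(module_name), class_name)
    # Equivalent to partial(data_class.create_dataset_function,
    #                       **args.data_config["params"]) from the table above
    return partial(data_class.create_dataset_function,
                   **data_config.get("params", {}))
```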

Key training args:

Argument Default Description
args.train_iters 10000 Number of training iterations (if epochs not set).
args.epochs None Number of training epochs (mutually exclusive with train_iters).
args.batch_size From DeepSpeed config Micro batch size per GPU per step.
args.lr From DeepSpeed config Learning rate.
args.gradient_accumulation_steps From DeepSpeed config Number of micro-batches before optimizer step.
args.save From YAML Checkpoint output directory.
args.save_interval From YAML Steps between checkpoint saves.
args.eval_interval From YAML Steps between evaluation runs.
args.zero_stage From YAML DeepSpeed ZeRO optimization stage (0, 1, 2, or 3).

Outputs

Output Type Description
Checkpoints Directory Model checkpoints saved to args.save at every save_interval steps. Contains model weights, optimizer state, and training state.
Training config YAML file training_config.yaml saved in the checkpoint directory on first iteration.
Video samples mp4 files Generated sample videos saved during evaluation in args.save/video/ directory.
Video latents pt files If only_log_video_latents=True, latent tensors saved as latent.pt instead of decoded videos.

Usage Examples

Single GPU Launch

# From sat/ directory
WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1 \
    python train_video.py \
    --base configs/cogvideox_2b_lora.yaml configs/sft.yaml \
    --seed $RANDOM

Multi GPU Launch (8 GPUs)

# From sat/ directory
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
    torchrun --standalone --nproc_per_node=8 \
    train_video.py \
    --base configs/cogvideox_5b_lora.yaml configs/sft.yaml \
    --seed $RANDOM

External Dependencies

  • sat.training.deepspeed_training: Provides training_main for the complete distributed training lifecycle.
  • deepspeed: Distributed training engine with ZeRO optimization.
  • torch.distributed: PyTorch distributed communication primitives for broadcasting.
  • wandb: Optional experiment tracking and video logging.
  • imageio: Saving generated video samples as mp4 files.
  • sat.mpu: Model parallelism utilities for rank management and group communication.
