

Implementation:Zai org CogVideo SAT Get Args

From Leeroopedia


Metadata

Page Type: Implementation (API Doc)
Knowledge Sources: CogVideo
Domains: Configuration, Training_Infrastructure
Last Updated: 2026-02-10 00:00 GMT

Overview

A concrete tool, provided by the CogVideo SAT module, that parses YAML configuration files and CLI arguments for SAT training.

Description

The get_args function in sat/arguments.py serves as the unified argument parsing entry point for the SAT training pipeline. It performs the following steps:

  1. Argument group registration: Registers argument groups for model configuration (--base, --model-parallel-size, --device), sampling configuration (--sampling-num-frames, --latent-channels, --output-dir), training, evaluation, and data arguments (from SAT's add_training_args, add_evaluation_args, and add_data_args), and DeepSpeed arguments (via deepspeed.add_config_arguments).
  2. CLI parsing: Parses all command-line arguments using argparse.
  3. YAML merging: Calls process_config_to_args, which loads all YAML files specified by --base, merges them via OmegaConf.merge, and extracts model, data, deepspeed, and args sections. Values from the args section override the parsed CLI arguments.
  4. Default handling: Sets train_iters to a default of 10000 when neither train_iters nor epochs is specified.
  5. DeepSpeed configuration: If no deepspeed_config is provided, loads a default JSON config based on args.zero_stage. Synchronizes batch size, learning rate, weight decay, and precision settings between the args namespace and the DeepSpeed config dictionary.
  6. Distributed initialization: Calls initialize_distributed to set up torch.distributed, model parallelism, and context parallelism groups.
  7. Random seed: Sets the random seed with a per-rank offset for data-parallel diversity.
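Steps 4, 5, and 7 can be sketched with the standard library alone. The helper finalize_args below is hypothetical, not SAT's code: it uses a plain dict in place of the JSON file selected by args.zero_stage, a SimpleNamespace in place of the parsed args, and takes the data-parallel rank as a parameter instead of querying mpu. Step 6 is omitted because it needs torch.distributed. The DeepSpeed keys used (train_micro_batch_size_per_gpu, optimizer.params.lr, optimizer.params.weight_decay, fp16.enabled, bf16.enabled) are standard DeepSpeed config fields, but exactly which fields SAT synchronizes, and in which direction, should be checked against sat/arguments.py.

```python
import copy
from types import SimpleNamespace


def finalize_args(args, default_ds_config, data_parallel_rank=0):
    """Hypothetical sketch of steps 4, 5 and 7 of get_args (not SAT's code)."""
    # Step 4: fall back to 10k iterations when neither budget is given.
    if args.train_iters is None and args.epochs is None:
        args.train_iters = 10000

    # Step 5 (one direction only): with no user-supplied DeepSpeed config,
    # copy the default and push the args values into it. The reverse case,
    # where a user config overrides args, is omitted here.
    if args.deepspeed_config is None:
        ds = copy.deepcopy(default_ds_config)
        ds["train_micro_batch_size_per_gpu"] = args.batch_size
        params = ds.setdefault("optimizer", {}).setdefault("params", {})
        params["lr"] = args.lr
        params["weight_decay"] = args.weight_decay
        ds.setdefault("fp16", {})["enabled"] = bool(args.fp16)
        ds.setdefault("bf16", {})["enabled"] = bool(args.bf16)
        args.deepspeed_config = ds

    # Step 7: offset the seed by the data-parallel rank so that each
    # data-parallel replica draws a different random stream.
    args.seed = args.seed + data_parallel_rank
    return args


# Illustrative stand-in for the default ZeRO-2 JSON file (hypothetical contents).
default_zero2 = {"zero_optimization": {"stage": 2},
                 "optimizer": {"type": "AdamW", "params": {}}}
args = SimpleNamespace(train_iters=None, epochs=None, deepspeed_config=None,
                       batch_size=2, lr=1e-4, weight_decay=0.0,
                       fp16=False, bf16=True, seed=42)
args = finalize_args(args, default_zero2, data_parallel_rank=3)
print(args.train_iters, args.seed)  # 10000 45
print(args.deepspeed_config["bf16"])  # {'enabled': True}
```

Deep-copying the default keeps the shared template untouched, so calling the helper more than once in a process does not leak settings from one call into the next.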

The companion function process_config_to_args handles the YAML-to-args translation by:

  1. Loading each YAML file listed in args.base using OmegaConf.load.
  2. Merging all configs using OmegaConf.merge.
  3. Extracting and assigning model_config, data_config, and deepspeed_config to the args namespace.
  4. Overriding args attributes with values from the YAML args section.
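A stdlib-only sketch of these four steps, assuming each YAML file has already been loaded into a plain dict (the real function uses OmegaConf.load and OmegaConf.merge on DictConfig objects). The deep_merge helper is hypothetical and only approximates OmegaConf.merge for plain dicts: later configs win, nested mappings are merged key by key, and lists are replaced rather than concatenated. The config values in the usage lines are illustrative, not taken from the CogVideo YAML files.

```python
import argparse
import copy


def deep_merge(*configs):
    """Hypothetical stand-in for OmegaConf.merge on plain dicts."""
    merged = {}
    for cfg in configs:
        for key, value in cfg.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                # Both sides are mappings: merge them recursively.
                merged[key] = deep_merge(merged[key], value)
            else:
                # Scalars and lists from later configs replace earlier ones.
                merged[key] = copy.deepcopy(value)
    return merged


def process_config_to_args(args, loaded_configs):
    """Sketch: loaded_configs is a list of dicts, one per --base file."""
    config = deep_merge(*loaded_configs)
    # The `args` section overrides values already parsed from the CLI.
    for key, value in config.pop("args", {}).items():
        setattr(args, key, value)
    # Remaining top-level sections become dedicated attributes.
    if "model" in config:
        args.model_config = config["model"]
    if "data" in config:
        args.data_config = config["data"]
    if "deepspeed" in config:
        args.deepspeed_config = config["deepspeed"]
    return args


first = {"model": {"network": {"num_layers": 30}}, "args": {"lr": 1e-3}}
second = {"model": {"network": {"hidden_size": 1920}}, "args": {"lr": 1e-4}}
args = argparse.Namespace(base=["first.yaml", "second.yaml"], lr=0.1, seed=42)
args = process_config_to_args(args, [first, second])
print(args.lr)  # 0.0001
print(args.model_config)  # {'network': {'num_layers': 30, 'hidden_size': 1920}}
```

Because the later file wins on conflicting keys, the order of paths passed to --base matters: a task-specific file such as configs/sft.yaml listed after a model file can override that file's settings.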

Usage

Import get_args when building the entry point for a SAT training script. It is called once at the beginning of train_video.py to produce the complete configuration namespace that is passed to all downstream components.

Code Reference

Source Location

  • sat/arguments.py:L58-185 (get_args)
  • sat/arguments.py:L271-298 (process_config_to_args)

Signature

import argparse
from typing import List, Optional

def get_args(args_list: Optional[List[str]] = None, parser: Optional[argparse.ArgumentParser] = None) -> argparse.Namespace:
    """
    Parse arguments from CLI and YAML configs.

    Args:
        args_list: Optional list of argument strings. If None, uses sys.argv.
        parser: Optional pre-configured ArgumentParser to extend.

    Returns:
        argparse.Namespace with model_config, data_config, deepspeed_config
        and all CLI/YAML args merged.
    """

def process_config_to_args(args: argparse.Namespace) -> argparse.Namespace:
    """
    Load YAML files from args.base, merge them, and extract model, data,
    and deepspeed configs into the args namespace.
    """

Import

from arguments import get_args  # within sat/ directory

I/O Contract

Inputs

  • --base (List[str], required): One or more YAML config file paths to load and merge. Example: configs/cogvideox_2b_lora.yaml configs/sft.yaml.
  • --model-parallel-size (int, optional): Size of the model-parallel group. Default: 1. Change only if you are experienced with tensor parallelism.
  • --sampling-num-frames (int, optional): Number of frames to sample during inference. Default: 32.
  • --latent-channels (int, optional): Number of latent channels in the VAE. Default: 32.
  • --device (int, optional): GPU device index. Default: -1 (auto-detect from LOCAL_RANK).
  • --force-pretrain (flag, optional): Force loading of pretrained weights.
  • --debug (flag, optional): Enable debug logging.
  • --image2video (flag, optional): Enable image-to-video mode.
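The flags in this table map onto ordinary argparse argument groups. The sketch below re-declares them with the defaults listed above; the function name build_parser, the group titles, and the nargs="*" choice for --base are assumptions, and the real module registers many more options (training, evaluation, data, and DeepSpeed groups). It also shows argparse converting dashed flag names into underscore attribute names, which is why the namespace exposes args.model_parallel_size and args.sampling_num_frames.

```python
import argparse


def build_parser():
    """Hypothetical subset of the SAT parser, using the table's defaults."""
    parser = argparse.ArgumentParser(description="sat (sketch)")

    model = parser.add_argument_group("model", "model configuration")
    model.add_argument("--base", type=str, nargs="*", help="YAML config files")
    model.add_argument("--model-parallel-size", type=int, default=1)
    model.add_argument("--device", type=int, default=-1)
    model.add_argument("--force-pretrain", action="store_true")
    model.add_argument("--debug", action="store_true")

    sampling = parser.add_argument_group("sampling", "sampling configuration")
    sampling.add_argument("--sampling-num-frames", type=int, default=32)
    sampling.add_argument("--latent-channels", type=int, default=32)
    sampling.add_argument("--image2video", action="store_true")
    return parser


args = build_parser().parse_args(
    ["--base", "configs/cogvideox_2b_lora.yaml", "configs/sft.yaml"]
)
print(args.base)  # ['configs/cogvideox_2b_lora.yaml', 'configs/sft.yaml']
print(args.model_parallel_size, args.sampling_num_frames)  # 1 32
```

Argument groups only affect how --help output is organized; all groups write into the same namespace.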

Outputs

  • args (argparse.Namespace): Complete configuration namespace containing model_config (OmegaConf DictConfig with network, denoiser, sampler, conditioner, VAE, and loss configs), data_config (OmegaConf DictConfig with the dataset target and params), deepspeed_config (dict with optimizer, precision, and ZeRO settings), plus all CLI arguments (train_iters, batch_size, lr, save, etc.).

Usage Examples

Standard Usage in train_video.py

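import argparse
from arguments import get_args  # run from within the sat/ directory

# py_parser declares no flags here, so parse_known_args() returns an empty
# namespace plus every command-line string, which is forwarded to get_args.
py_parser = argparse.ArgumentParser(add_help=False)
known, args_list = py_parser.parse_known_args()
args = get_args(args_list)
# Fold any script-specific flags into the final namespace.
args = argparse.Namespace(**vars(args), **vars(known))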

# args.model_config now contains the full model architecture
# args.data_config now contains dataset class and params
# args.deepspeed_config now contains optimizer and training settings
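The py_parser pattern above relies on argparse.ArgumentParser.parse_known_args, which returns a namespace of the flags it recognizes plus a list of the leftover strings. The sketch below adds a hypothetical script-only flag (--my-script-flag) to show how the split and the later Namespace merge behave; fake_get_args is a hypothetical stand-in for get_args that only understands --base and --seed.

```python
import argparse


def fake_get_args(args_list):
    """Hypothetical stand-in for get_args: parses only --base and --seed."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--base", nargs="*")
    parser.add_argument("--seed", type=int, default=1234)
    return parser.parse_args(args_list)


argv = ["--my-script-flag", "7", "--base", "a.yaml", "b.yaml", "--seed", "42"]

py_parser = argparse.ArgumentParser(add_help=False)
py_parser.add_argument("--my-script-flag", type=int, default=0)  # hypothetical
known, args_list = py_parser.parse_known_args(argv)
print(args_list)  # ['--base', 'a.yaml', 'b.yaml', '--seed', '42']

args = fake_get_args(args_list)
# Merging raises TypeError if both namespaces define the same attribute.
args = argparse.Namespace(**vars(args), **vars(known))
print(args.my_script_flag, args.seed, args.base)  # 7 42 ['a.yaml', 'b.yaml']
```

Any flag declared on py_parser must therefore not also be registered by get_args, or the final Namespace(...) call fails.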

CLI Invocation

python train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed 42

External Dependencies

  • omegaconf: YAML loading, merging, and structured configuration.
  • argparse: CLI argument parsing.
  • deepspeed: Adds DeepSpeed-specific argument groups to the parser.
  • sat: Provides add_training_args, add_evaluation_args, add_data_args, set_random_seed, and mpu (model parallelism utilities).
