Implementation:Zai org CogVideo SATVideoDiffusionEngine Init


Metadata

Field | Value
Page Type | Implementation (API Doc)
Knowledge Sources | CogVideo
Domains | Model_Architecture, Video_Diffusion
Last Updated | 2026-02-10 00:00 GMT

Overview

Concrete tool, provided by the CogVideo SAT module, for constructing the SAT video diffusion engine from a YAML configuration.

Description

SATVideoDiffusionEngine is the top-level nn.Module that orchestrates all components of the CogVideoX video diffusion model. Its __init__ method reads the model_config from the parsed args and constructs each sub-component using instantiate_from_config.

The constructor performs the following steps in order:

  1. Extracts all configuration sub-dictionaries from args.model_config.
  2. Determines compute precision (fp16, bf16, or fp32) from args.
  3. Constructs the DiT backbone via instantiate_from_config(network_config) and wraps it in the wrapper class named by the OPENAIUNETWRAPPER constant.
  4. Constructs the denoiser via instantiate_from_config(denoiser_config).
  5. Constructs the sampler via instantiate_from_config(sampler_config) (if provided).
  6. Constructs the conditioner via instantiate_from_config(conditioner_config).
  7. Initializes the VAE first stage model in eval mode with frozen parameters.
  8. Constructs the loss function via instantiate_from_config(loss_fn_config) (if provided).
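
The construction order above can be sketched with standard-library stand-ins. This is a minimal illustration, not the code in sat/diffusion_video.py: ToyEngine, pick_precision, and toy are hypothetical names, the local instantiate_from_config only imitates the target/params convention of sgm.util, and the fp16-before-bf16 priority is an assumption. Wrapping the backbone and freezing the VAE are omitted.

```python
import importlib
from argparse import Namespace


def instantiate_from_config(config):
    """Stand-in for sgm.util.instantiate_from_config: import `target`, call it with `params`."""
    module_name, class_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get("params", {}))


def pick_precision(args):
    """Step 2. Checking fp16 before bf16 is an assumption of this sketch."""
    if getattr(args, "fp16", False):
        return "fp16"
    if getattr(args, "bf16", False):
        return "bf16"
    return "fp32"


class ToyEngine:
    """Hypothetical class that mirrors only the construction order."""

    def __init__(self, args):
        cfg = args.model_config                                          # step 1
        self.dtype_str = pick_precision(args)                            # step 2
        self.model = instantiate_from_config(cfg["network_config"])      # step 3 (wrapping omitted)
        self.denoiser = instantiate_from_config(cfg["denoiser_config"])  # step 4
        sampler_cfg = cfg.get("sampler_config")                          # step 5 (optional)
        self.sampler = instantiate_from_config(sampler_cfg) if sampler_cfg else None
        self.conditioner = instantiate_from_config(cfg["conditioner_config"])        # step 6
        self.first_stage_model = instantiate_from_config(cfg["first_stage_config"])  # step 7 (eval/freeze omitted)
        loss_cfg = cfg.get("loss_fn_config")                             # step 8 (optional)
        self.loss_fn = instantiate_from_config(loss_cfg) if loss_cfg else None


def toy(name):
    """Build a config whose target is a stdlib class, so the sketch runs without sat or sgm."""
    return {"target": "types.SimpleNamespace", "params": {"name": name}}


args = Namespace(
    fp16=False,
    bf16=True,
    model_config={
        "network_config": toy("dit"),
        "denoiser_config": toy("denoiser"),
        "conditioner_config": toy("conditioner"),
        "first_stage_config": toy("vae"),
        "loss_fn_config": toy("loss"),
        # sampler_config is left out on purpose to exercise the optional path
    },
)
engine = ToyEngine(args)
print(engine.dtype_str, engine.model.name, engine.sampler)  # bf16 dit None
```

Because every component is resolved from a target string, swapping a sampler or loss should only require editing the YAML, not the engine class.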

The disable_untrainable_params method handles parameter freezing after construction, supporting both LoRA training (setting lr_scale=0 for non-LoRA params) and full fine-tuning (setting requires_grad=False for frozen prefixes).
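
The two freezing modes can be illustrated with a small, framework-free sketch. It assumes that LoRA parameters are recognized by the substring "lora_layer" in their name (taken from the docstring wording) and that frozen prefixes are matched with startswith; the function apply_freezing, the helper make_params, and the parameter names are hypothetical. Parameters are plain namespace objects here, so numel is an integer attribute rather than the torch method.

```python
from types import SimpleNamespace


def apply_freezing(named_params, lora_train, not_trainable_prefixes, lora_marker="lora_layer"):
    """Apply the freezing rule in place and return the trainable element count."""
    total_trainable = 0
    for name, p in named_params:
        if not p.requires_grad:
            continue  # already frozen, e.g. the VAE
        if lora_train:
            if lora_marker in name:
                total_trainable += p.numel
            else:
                p.lr_scale = 0  # non-LoRA parameter: learning rate scaled to zero
        elif any(name.startswith(prefix) for prefix in not_trainable_prefixes):
            p.requires_grad = False  # full fine-tuning: freeze by prefix
        else:
            total_trainable += p.numel
    return total_trainable


def make_params():
    """Hypothetical parameter set: (name, stand-in parameter) pairs."""
    sizes = {
        "model.blocks.0.weight": 100,
        "model.blocks.0.lora_layer.A": 8,
        "first_stage_model.conv.weight": 50,
        "conditioner.t5.weight": 30,
    }
    return [(n, SimpleNamespace(requires_grad=True, lr_scale=1.0, numel=s)) for n, s in sizes.items()]


prefixes = ["first_stage_model", "conditioner"]

lora_params = make_params()
lora_total = apply_freezing(lora_params, lora_train=True, not_trainable_prefixes=prefixes)

full_params = make_params()
full_total = apply_freezing(full_params, lora_train=False, not_trainable_prefixes=prefixes)

print("LoRA trainable:", lora_total)              # 8
print("Full fine-tuning trainable:", full_total)  # 108
```

In LoRA mode the non-LoRA parameters keep requires_grad=True and are only given a zero learning-rate scale, whereas full fine-tuning removes the frozen prefixes from gradient computation entirely.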

Usage

Import SATVideoDiffusionEngine to construct the complete model from parsed YAML configuration. It is passed as the model_cls argument to training_main, which instantiates and manages it during distributed training.

Code Reference

Source Location

  • sat/diffusion_video.py:L25-98 (__init__)
  • sat/diffusion_video.py:L101-132 (disable_untrainable_params)

Signature

class SATVideoDiffusionEngine(nn.Module):
    def __init__(self, args, **kwargs):
        """
        Constructs the complete video diffusion model from args.model_config.

        Expected model_config keys:
            network_config: DiT backbone architecture
            denoiser_config: noise prediction strategy
            sampler_config: inference sampling algorithm
            conditioner_config: text conditioning pipeline
            first_stage_config: 3D VAE for video encoding/decoding
            loss_fn_config: training loss computation
            scale_factor: float, latent space scaling (1.15258426 for 2B)
            lora_train: bool, enable LoRA training mode
            not_trainable_prefixes: List[str], parameter prefixes to freeze
            noised_image_input: bool, enable I2V noised image conditioning
            log_keys: Optional[List[str]], keys to log during training
            input_key: str, batch key for video data (default "mp4")
        """

    def disable_untrainable_params(self):
        """
        Freeze parameters based on training mode.
        - LoRA mode: sets lr_scale=0 for non-lora_layer params
        - Full fine-tuning: sets requires_grad=False for not_trainable_prefixes
        Prints total trainable parameter count.
        """

Import

from diffusion_video import SATVideoDiffusionEngine  # within sat/ directory

I/O Contract

Inputs

Parameter | Type | Required | Description
args | argparse.Namespace | Yes | Parsed configuration namespace containing model_config (OmegaConf DictConfig), fp16, bf16, and device attributes.

Key model_config sub-keys:

Config Key | Description
network_config | DiT backbone: target class, hidden_size, num_layers, num_attention_heads, patch_size, in_channels, LoRA config, positional embedding config.
denoiser_config | Denoiser: DiscreteDenoiser with 1000 timesteps, EpsWeighting, VideoScaling, ZeroSNR-DDPM discretization.
sampler_config | Sampler: VPSDEDPMPP2MSampler with 50 steps, DynamicCFG guidance.
conditioner_config | Conditioner: GeneralConditioner with frozen T5-XXL, max_length=226, UCG rate=0.1.
first_stage_config | VAE: VideoAutoencoderInferenceWrapper with context-parallel encoder/decoder.
loss_fn_config | Loss: VideoDiffusionLoss with DiscreteSampling and ZeroSNR-DDPM.
scale_factor | Latent scaling: 1.15258426 (2B) or 0.7 (5B).
lora_train | Boolean: enables LoRA training mode.
noised_image_input | Boolean: enables image-to-video noised image conditioning.
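
As a rough illustration of how these keys fit together, the sketch below checks a model_config mapping and fills in defaults. The helper check_model_config and the placeholder target strings are hypothetical. Treating the four listed sub-configs as required, and defaulting the boolean flags to False, are assumptions of this sketch; only the "mp4" default for input_key and the optional nature of sampler_config and loss_fn_config come from the text above.

```python
# Sub-configs assumed to be required by this sketch.
REQUIRED = ("network_config", "denoiser_config", "conditioner_config", "first_stage_config")
# Sub-configs that are built only if provided.
OPTIONAL = ("sampler_config", "loss_fn_config")


def check_model_config(model_config):
    """Validate the sub-config keys and return scalar settings with defaults applied."""
    missing = [key for key in REQUIRED if key not in model_config]
    if missing:
        raise KeyError(f"model_config is missing sub-configs: {missing}")
    return {
        "input_key": model_config.get("input_key", "mp4"),
        "log_keys": model_config.get("log_keys"),
        "scale_factor": model_config.get("scale_factor"),
        "lora_train": bool(model_config.get("lora_train", False)),
        "noised_image_input": bool(model_config.get("noised_image_input", False)),
        "absent_optional": [key for key in OPTIONAL if key not in model_config],
    }


# Placeholder targets: the real class paths are not given here, so these strings are hypothetical.
model_config = {
    "scale_factor": 1.15258426,  # value listed above for the 2B model
    "lora_train": True,
    "network_config": {"target": "hypothetical.module.DiT", "params": {}},
    "denoiser_config": {"target": "hypothetical.module.DiscreteDenoiser", "params": {}},
    "sampler_config": {"target": "hypothetical.module.VPSDEDPMPP2MSampler", "params": {}},
    "conditioner_config": {"target": "hypothetical.module.GeneralConditioner", "params": {}},
    "first_stage_config": {"target": "hypothetical.module.VideoAutoencoderInferenceWrapper", "params": {}},
}

settings = check_model_config(model_config)
print(settings["input_key"], settings["lora_train"], settings["absent_optional"])
```

A check of this kind run before training starts can surface a missing sub-config early, instead of failing partway through model construction.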

Outputs

Output | Type | Description
Engine instance | SATVideoDiffusionEngine | A complete nn.Module with attributes: model (wrapped DiT backbone), denoiser, sampler, conditioner, first_stage_model (frozen VAE), loss_fn, dtype, scale_factor.

Usage Examples

Construction via training_main

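import argparse

from sat.training.deepspeed_training import training_main
from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args

# Collect the command-line arguments and hand them to get_args.
py_parser = argparse.ArgumentParser(add_help=False)
known, args_list = py_parser.parse_known_args()
args = get_args(args_list)

# forward_step and create_dataset_function are assumed to be defined
# by the calling training script; they are not part of this class.
# training_main internally calls SATVideoDiffusionEngine(args) to construct the model
training_main(
    args,
    model_cls=SATVideoDiffusionEngine,
    forward_step_function=forward_step,
    create_dataset_function=create_dataset_function,
)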

Key Module Attributes After Init

engine = SATVideoDiffusionEngine(args)  # args as returned by get_args

# Core components
engine.model           # Wrapped DiffusionTransformer backbone
engine.denoiser        # DiscreteDenoiser
engine.sampler         # VPSDEDPMPP2MSampler (built only if sampler_config is provided)
engine.conditioner     # GeneralConditioner with T5-XXL
engine.first_stage_model  # Frozen 3D VAE
engine.loss_fn         # VideoDiffusionLoss (built only if loss_fn_config is provided)

# Configuration
engine.dtype           # torch.float16, torch.bfloat16, or torch.float32
engine.scale_factor    # 1.15258426 (for 2B)
engine.lora_train      # True for LoRA training
engine.noised_image_input  # True for I2V

External Dependencies

  • sat: SwissArmyTransformer framework for model parallelism utilities.
  • sgm.util: Provides instantiate_from_config, get_obj_from_str, default, disabled_train.
  • torch.nn: PyTorch neural network module base class.
