Implementation: Zai org CogVideo SATVideoDiffusionEngine Init
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | CogVideo |
| Domains | Model_Architecture, Video_Diffusion |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Concrete tool for constructing the SAT video diffusion engine from YAML configuration provided by the CogVideo SAT module.
Description
SATVideoDiffusionEngine is the top-level nn.Module that orchestrates all components of the CogVideoX video diffusion model. Its __init__ method reads the model_config from the parsed args and constructs each sub-component using instantiate_from_config.
The constructor performs the following steps in order:
- Extracts all configuration sub-dictionaries from `args.model_config`.
- Determines compute precision (`fp16`, `bf16`, or `fp32`) from `args`.
- Constructs the DiT backbone via `instantiate_from_config(network_config)` and wraps it with `OPENAIUNETWRAPPER`.
- Constructs the denoiser via `instantiate_from_config(denoiser_config)`.
- Constructs the sampler via `instantiate_from_config(sampler_config)` (if provided).
- Constructs the conditioner via `instantiate_from_config(conditioner_config)`.
- Initializes the VAE first stage model in eval mode with frozen parameters.
- Constructs the loss function via `instantiate_from_config(loss_fn_config)` (if provided).
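Every construction step above funnels through the same `instantiate_from_config` pattern. The helper names below match the real ones in `sgm.util`, but the bodies are a simplified, torch-free re-implementation with a toy stdlib target, shown only to illustrate the mechanics:

```python
# Simplified sketch of the instantiate_from_config pattern (the real helpers
# live in sgm.util); the toy target below stands in for network_config etc.
import importlib

def get_obj_from_str(string: str):
    # Resolve a dotted "module.ClassName" path to the class object
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config: dict):
    # Each YAML sub-dictionary carries a "target" dotted path plus optional "params"
    return get_obj_from_str(config["target"])(**config.get("params", {}))

# Toy config: the engine does the same with denoiser_config, sampler_config, ...
cfg = {"target": "fractions.Fraction", "params": {"numerator": 3, "denominator": 6}}
obj = instantiate_from_config(cfg)
print(obj)  # Fraction reduces 3/6 -> 1/2
```

Because every component is named by a `target` path in YAML, swapping the sampler or denoiser requires no code changes, only a different config file.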
The disable_untrainable_params method handles parameter freezing after construction, supporting both LoRA training (setting lr_scale=0 for non-LoRA params) and full fine-tuning (setting requires_grad=False for frozen prefixes).
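A torch-free sketch of that freezing logic, using stand-in parameter objects in place of `nn.Parameter` (the real method iterates `self.named_parameters()`; the exact substring/prefix matching here is an assumption):

```python
# Hedged sketch of the disable_untrainable_params behavior described above.
class Param:
    """Stand-in for nn.Parameter with the two attributes the method touches."""
    def __init__(self):
        self.requires_grad = True
        self.lr_scale = 1.0

def disable_untrainable_params(named_params, lora_train, not_trainable_prefixes):
    total_trainable = 0
    for name, p in named_params.items():
        if lora_train:
            # LoRA mode: zero the learning-rate scale of everything
            # that is not a lora_layer parameter
            if "lora_layer" not in name:
                p.lr_scale = 0.0
            else:
                total_trainable += 1
        else:
            # Full fine-tuning: hard-freeze parameters under the listed prefixes
            if any(name.startswith(pref) for pref in not_trainable_prefixes):
                p.requires_grad = False
            else:
                total_trainable += 1
    return total_trainable  # the real method prints this count

# Usage: only the LoRA parameter stays trainable in LoRA mode
params = {"model.lora_layer.A": Param(), "model.attn.weight": Param()}
n = disable_untrainable_params(params, lora_train=True, not_trainable_prefixes=[])
```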
Usage
Import SATVideoDiffusionEngine to construct the complete model from parsed YAML configuration. It is passed as the model_cls argument to training_main, which instantiates and manages it during distributed training.
Code Reference
Source Location
- `sat/diffusion_video.py:L25-98` (`__init__`)
- `sat/diffusion_video.py:L101-132` (`disable_untrainable_params`)
Signature
```python
class SATVideoDiffusionEngine(nn.Module):
    def __init__(self, args, **kwargs):
        """
        Constructs the complete video diffusion model from args.model_config.

        Expected model_config keys:
            network_config: DiT backbone architecture
            denoiser_config: noise prediction strategy
            sampler_config: inference sampling algorithm
            conditioner_config: text conditioning pipeline
            first_stage_config: 3D VAE for video encoding/decoding
            loss_fn_config: training loss computation
            scale_factor: float, latent space scaling (1.15258426 for 2B)
            lora_train: bool, enable LoRA training mode
            not_trainable_prefixes: List[str], parameter prefixes to freeze
            noised_image_input: bool, enable I2V noised image conditioning
            log_keys: Optional[List[str]], keys to log during training
            input_key: str, batch key for video data (default "mp4")
        """

    def disable_untrainable_params(self):
        """
        Freeze parameters based on training mode.

        - LoRA mode: sets lr_scale=0 for non-lora_layer params
        - Full fine-tuning: sets requires_grad=False for not_trainable_prefixes

        Prints total trainable parameter count.
        """
```
Import
```python
from diffusion_video import SATVideoDiffusionEngine  # within sat/ directory
```
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `args` | `argparse.Namespace` | Yes | Parsed configuration namespace containing `model_config` (OmegaConf `DictConfig`), `fp16`, `bf16`, and device attributes. |
Key model_config sub-keys:
| Config Key | Description |
|---|---|
| `network_config` | DiT backbone: target class, hidden_size, num_layers, num_attention_heads, patch_size, in_channels, LoRA config, positional embedding config. |
| `denoiser_config` | Denoiser: DiscreteDenoiser with 1000 timesteps, EpsWeighting, VideoScaling, ZeroSNR-DDPM discretization. |
| `sampler_config` | Sampler: VPSDEDPMPP2MSampler with 50 steps, DynamicCFG guidance. |
| `conditioner_config` | Conditioner: GeneralConditioner with frozen T5-XXL, max_length=226, UCG rate=0.1. |
| `first_stage_config` | VAE: VideoAutoencoderInferenceWrapper with context-parallel encoder/decoder. |
| `loss_fn_config` | Loss: VideoDiffusionLoss with DiscreteSampling and ZeroSNR-DDPM. |
| `scale_factor` | Latent scaling: 1.15258426 (2B) or 0.7 (5B). |
| `lora_train` | Boolean: enables LoRA training mode. |
| `noised_image_input` | Boolean: enables image-to-video noised image conditioning. |
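Assembled from the table above, a heavily abbreviated YAML sketch of how these keys nest under `model_config`. Field names follow the table; the `target` paths and parameter values are illustrative and may differ from the shipped CogVideoX configs:

```yaml
model:
  scale_factor: 1.15258426   # 2B checkpoint; 0.7 for 5B
  lora_train: false
  noised_image_input: false  # true for I2V fine-tuning
  network_config:
    target: dit_video_concat.DiffusionTransformer   # assumed path
    params:
      patch_size: 2
      num_layers: 30
  denoiser_config:
    target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser  # assumed path
    params:
      num_idx: 1000          # 1000 diffusion timesteps
  sampler_config:
    target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler  # assumed path
    params:
      num_steps: 50
```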
Outputs
| Output | Type | Description |
|---|---|---|
| Engine instance | `SATVideoDiffusionEngine` | A complete `nn.Module` with attributes: `model` (wrapped DiT backbone), `denoiser`, `sampler`, `conditioner`, `first_stage_model` (frozen VAE), `loss_fn`, `dtype`, `scale_factor`. |
Usage Examples
Construction via training_main
```python
from sat.training.deepspeed_training import training_main
from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args

args = get_args(args_list)

# training_main internally calls SATVideoDiffusionEngine(args) to construct the model
training_main(
    args,
    model_cls=SATVideoDiffusionEngine,
    forward_step_function=forward_step,
    create_dataset_function=create_dataset_function,
)
```
Key Module Attributes After Init
```python
engine = SATVideoDiffusionEngine(args)

# Core components
engine.model              # Wrapped DiffusionTransformer backbone
engine.denoiser           # DiscreteDenoiser
engine.sampler            # VPSDEDPMPP2MSampler
engine.conditioner        # GeneralConditioner with T5-XXL
engine.first_stage_model  # Frozen 3D VAE
engine.loss_fn            # VideoDiffusionLoss

# Configuration
engine.dtype              # torch.float16, torch.bfloat16, or torch.float32
engine.scale_factor       # 1.15258426 (for 2B)
engine.lora_train         # True for LoRA training
engine.noised_image_input # True for I2V
```
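The `dtype` attribute is resolved from the `fp16`/`bf16` flags on `args` during the precision step of `__init__`. A minimal sketch of the likely precedence (fp16 over bf16 over fp32 is an assumption; torch dtypes are replaced with strings to keep the sketch dependency-free):

```python
# Hedged sketch: maps args.fp16 / args.bf16 flags to a dtype name.
def resolve_dtype(fp16: bool, bf16: bool) -> str:
    if fp16:
        return "float16"   # engine would set torch.float16
    if bf16:
        return "bfloat16"  # torch.bfloat16
    return "float32"       # default full precision

print(resolve_dtype(fp16=False, bf16=True))  # bfloat16
```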
External Dependencies
- `sat`: SwissArmyTransformer framework for model parallelism utilities.
- `sgm.util`: Provides `instantiate_from_config`, `get_obj_from_str`, `default`, `disabled_train`.
- `torch.nn`: PyTorch neural network module base class.