
Implementation:Zai org CogVideo SAT Get Model Load Checkpoint

From Leeroopedia


Attribute Value
Implementation Name SAT Get Model Load Checkpoint
Workflow SAT Video Generation
Step 2 of 5
Type Wrapper Doc
Source File sat/sample_video.py:L137-144
Repository zai-org/CogVideo
External Dependencies sat
Last Updated 2026-02-10 00:00 GMT

Overview

Implementation of model loading for SAT-based CogVideoX inference. This wrapper combines the SAT framework's get_model and load_checkpoint functions to instantiate the SATVideoDiffusionEngine with pretrained weights and set it to evaluation mode for inference.

Description

The model loading implementation consists of three sequential calls:

  1. get_model(args, model_cls=SATVideoDiffusionEngine) -- Constructs the model architecture from configuration
  2. load_checkpoint(model, args) -- Loads pretrained weights from the checkpoint path
  3. model.eval() -- Sets the model to evaluation mode

The get_model function reads model architecture parameters from args.model_config (populated from the YAML config file) and instantiates the specified model class. The load_checkpoint function reads weights from args.load and applies them to the model.
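The contract between the three calls can be sketched with lightweight stand-ins. Everything below (the StubEngine class, its fields, and the config keys and paths) is hypothetical and stands in for the real SAT APIs, which require the SAT framework to run; the sketch only illustrates the build → load → eval sequence.

```python
import argparse

class StubEngine:
    """Hypothetical stand-in for SATVideoDiffusionEngine."""
    def __init__(self, config):
        self.config = config      # architecture parameters (from args.model_config)
        self.weights = None       # populated by the checkpoint loader
        self.training = True      # mirrors torch.nn.Module's default mode

    def eval(self):
        self.training = False
        return self

def get_model(args, model_cls):
    # Step 1: construct the architecture from args.model_config
    return model_cls(args.model_config)

def load_checkpoint(model, args):
    # Step 2: read weights from args.load and apply them to the model
    model.weights = f"loaded from {args.load}"
    return model

# args.model_config is normally populated from the YAML config file;
# the key and path here are illustrative only
args = argparse.Namespace(
    model_config={"latent_channels": 16},
    load="ckpts/cogvideox-sat",
)

model = get_model(args, model_cls=StubEngine)
load_checkpoint(model, args)
model.eval()  # Step 3: switch to evaluation mode
```

After the third call the stub reports `training == False`, matching the state the real wrapper leaves the model in before sampling.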

Usage

import torch

from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from diffusion_video import SATVideoDiffusionEngine

# args is obtained from get_args()
model = get_model(args, model_cls=SATVideoDiffusionEngine)
load_checkpoint(model, args)
model.eval()

# Model is now ready for inference; cond, uc, and shape are
# produced by earlier steps of the sampling pipeline
with torch.no_grad():
    samples = model.sample(cond, uc, batch_size=1, shape=shape)

Code Reference

Source Location

File Lines Description
sat/sample_video.py L137-144 Model loading wrapper

Signature

from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint

model = get_model(args, model_cls=SATVideoDiffusionEngine)
load_checkpoint(model, args)
model.eval()

Import

from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint

I/O Contract

Inputs

Parameter Type Default Description
args argparse.Namespace Required Parsed arguments containing model config and checkpoint path
args.model_config dict From YAML Model architecture configuration loaded from YAML
args.load str Required Path to pretrained checkpoint directory
model_cls type SATVideoDiffusionEngine Model class to instantiate

Outputs

Output Type Description
model SATVideoDiffusionEngine Loaded model in eval mode, ready for inference
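The required inputs of this contract can be checked before the loader is invoked, so a missing checkpoint path fails early with a clear message. This is a minimal sketch, not part of the source file; the validation function and the example path are hypothetical.

```python
import argparse
import os

def validate_load_args(args):
    """Check the I/O-contract inputs before model construction."""
    if not isinstance(args.model_config, dict):
        raise TypeError("args.model_config must be a dict loaded from the YAML config")
    if not args.load:
        raise ValueError("args.load must point to a pretrained checkpoint directory")
    if not os.path.isdir(args.load):
        raise FileNotFoundError(f"checkpoint directory not found: {args.load}")

# Example: a path that does not exist fails the directory check
args = argparse.Namespace(model_config={}, load="/nonexistent/ckpt")
try:
    validate_load_args(args)
except FileNotFoundError as e:
    print(f"validation failed: {e}")
```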

Usage Examples

Example 1: Standard model loading

from arguments import get_args
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from diffusion_video import SATVideoDiffusionEngine

args = get_args()
model = get_model(args, model_cls=SATVideoDiffusionEngine)
load_checkpoint(model, args)
model.eval()

Example 2: Model loading with distributed setup

import torch
from arguments import get_args
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from diffusion_video import SATVideoDiffusionEngine

args = get_args()
model = get_model(args, model_cls=SATVideoDiffusionEngine)
load_checkpoint(model, args)
model.eval()

# Distributed inference uses SAT's built-in model parallelism
# configured via args (--model-parallel-size, etc.)
