Implementation:Zai org CogVideo SAT Get Model Load Checkpoint
| Attribute | Value |
|---|---|
| Implementation Name | SAT Get Model Load Checkpoint |
| Workflow | SAT Video Generation |
| Step | 2 of 5 |
| Type | Wrapper Doc |
| Source File | sat/sample_video.py:L137-144 |
| Repository | zai-org/CogVideo |
| External Dependencies | sat |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Implementation of model loading for SAT-based CogVideoX inference. This wrapper combines SAT framework's get_model and load_checkpoint functions to instantiate the SATVideoDiffusionEngine with pretrained weights and prepare it for inference.
Description
The model loading implementation consists of three sequential calls:
1. `get_model(args, model_cls=SATVideoDiffusionEngine)` -- Constructs the model architecture from configuration
2. `load_checkpoint(model, args)` -- Loads pretrained weights from the checkpoint path
3. `model.eval()` -- Sets the model to evaluation mode
The get_model function reads model architecture parameters from args.model_config (populated from the YAML config file) and instantiates the specified model class. The load_checkpoint function reads weights from args.load and applies them to the model.
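The three-step pattern can be sketched with hypothetical stand-ins. `DummyEngine` and the two loader functions below are illustrative only, not the real sat implementations, but they mirror the construct/load/eval flow described above:

```python
from types import SimpleNamespace

class DummyEngine:
    """Hypothetical stand-in for SATVideoDiffusionEngine."""
    def __init__(self, config):
        self.config = config      # architecture params (from args.model_config)
        self.weights = None       # filled in by load_checkpoint
        self.training = True      # models start in training mode
    def eval(self):
        # Mirrors torch.nn.Module.eval(): switch to inference mode.
        self.training = False
        return self

def get_model(args, model_cls):
    # Step 1: build the architecture from args.model_config.
    return model_cls(args.model_config)

def load_checkpoint(model, args):
    # Step 2: apply pretrained weights found at args.load.
    model.weights = f"loaded from {args.load}"

args = SimpleNamespace(model_config={"latent_channels": 16},
                       load="ckpts/cogvideox-sat")
model = get_model(args, model_cls=DummyEngine)
load_checkpoint(model, args)
model.eval()  # Step 3: inference mode
```

The real functions do considerably more (distributed initialization, weight sharding), but the call sequence and data flow are the same.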
Usage
```python
import torch

from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from diffusion_video import SATVideoDiffusionEngine

# args is obtained from get_args()
model = get_model(args, model_cls=SATVideoDiffusionEngine)
load_checkpoint(model, args)
model.eval()

# Model is now ready for inference
with torch.no_grad():
    samples = model.sample(cond, uc, batch_size=1, shape=shape)
```
Code Reference
Source Location
| File | Lines | Description |
|---|---|---|
| sat/sample_video.py | L137-144 | Model loading wrapper |
Signature
```python
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint

model = get_model(args, model_cls=SATVideoDiffusionEngine)
load_checkpoint(model, args)
model.eval()
```
Import
```python
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
```
I/O Contract
Inputs
| Parameter | Type | Default | Description |
|---|---|---|---|
| `args` | `argparse.Namespace` | Required | Parsed arguments containing model config and checkpoint path |
| `args.model_config` | `dict` | From YAML | Model architecture configuration loaded from YAML |
| `args.load` | `str` | Required | Path to pretrained checkpoint directory |
| `model_cls` | `type` | `SATVideoDiffusionEngine` | Model class to instantiate |
Outputs
| Output | Type | Description |
|---|---|---|
| `model` | `SATVideoDiffusionEngine` | Loaded model in eval mode, ready for inference |
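For illustration, the required `args` fields can be assembled by hand with an `argparse.Namespace`. The field values here are hypothetical; in the real pipeline, `get_args()` populates them from the command line and the YAML config file:

```python
from argparse import Namespace

# Hypothetical values -- get_args() normally fills these from the CLI
# and the YAML model config.
args = Namespace(
    model_config={"scale_factor": 0.7},  # illustrative YAML-derived dict
    load="ckpts/CogVideoX-2b-sat",       # example checkpoint directory
)

# The contract: model_config drives get_model, load drives load_checkpoint.
assert isinstance(args.model_config, dict)
assert isinstance(args.load, str)
```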
Usage Examples
Example 1: Standard model loading
```python
from arguments import get_args
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from diffusion_video import SATVideoDiffusionEngine

args = get_args()
model = get_model(args, model_cls=SATVideoDiffusionEngine)
load_checkpoint(model, args)
model.eval()
```
Example 2: Model loading with distributed setup
```python
import torch

from arguments import get_args
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from diffusion_video import SATVideoDiffusionEngine

args = get_args()
model = get_model(args, model_cls=SATVideoDiffusionEngine)
load_checkpoint(model, args)
model.eval()

# Distributed inference uses SAT's built-in model parallelism
# configured via args (--model-parallel-size, etc.)
```
Related Pages
- Principle:Zai_org_CogVideo_SAT_Model_Loading_for_Inference -- Principle governing model loading for inference
- Environment:Zai_org_CogVideo_SAT_Framework_Environment
- Zai_org_CogVideo_SAT_Inference_Get_Args -- Previous step: argument parsing that provides model config
- Zai_org_CogVideo_SAT_Read_From_CLI_File -- Next step: reading prompts for generation
- Zai_org_CogVideo_SAT_Diffusion_Sample -- Sampling step using the loaded model