
Implementation:Zai org CogVideo Trainer Checkpoint Validate

From Leeroopedia


Implementation Metadata
Name Trainer_Checkpoint_Validate
Type API Doc
Category Training
Domains Fine_Tuning, Diffusion_Models
Knowledge Sources CogVideo Repository, CogVideoX Paper
Last Updated 2026-02-10 00:00 GMT

Overview

Trainer_Checkpoint_Validate is a concrete implementation of checkpoint saving and validation-video generation during CogVideoX training, provided by the CogVideo finetune package.

Description

This implementation provides the checkpointing and validation functionality within the Trainer class. The __maybe_save_checkpoint method periodically saves full training state to disk using Accelerate's save_state with safe serialization. The validate method constructs an inference pipeline from the current model state, generates sample videos from held-out prompts, exports them to files, and logs them to experiment trackers.

Usage

Use during CogVideoX fine-tuning to periodically save model state and generate validation samples. Both checkpointing and validation are triggered automatically by the training loop at configured step intervals. Checkpointing is also triggered at the end of training.
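The interval-based triggering described above can be sketched as two small predicates. This is a hypothetical illustration of the step arithmetic, not the trainer's actual control flow; the helper names `should_checkpoint` and `should_validate` are invented here, while the parameter names mirror those documented in Key Parameters below.

```python
def should_checkpoint(step: int, checkpointing_steps: int) -> bool:
    # Fire every N steps; skip step 0 so the untrained state is not saved.
    return step > 0 and step % checkpointing_steps == 0

def should_validate(step: int, validation_steps: int, do_validation: bool) -> bool:
    # Validation follows the same interval logic but can be disabled entirely.
    return do_validation and step > 0 and step % validation_steps == 0

# With the default interval of 200, checkpoints fire at steps 200, 400, 600, ...
triggered = [s for s in range(1, 601) if should_checkpoint(s, 200)]
```

Note that with the defaults (`checkpointing_steps = validation_steps = 200`), both events fire on the same steps.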

Code Reference

Source Location

  • finetune/trainer.py:L799-811 -- __maybe_save_checkpoint method
  • finetune/utils/checkpointing.py:L15-57 -- Checkpoint utility functions
  • finetune/trainer.py:L502-668 -- validate method

Signature

Checkpointing:

class Trainer:
    def __maybe_save_checkpoint(self, step: int) -> None:
        """Save checkpoint if current step matches checkpointing interval."""
        if step % self.args.checkpointing_steps == 0:
            save_path = os.path.join(
                self.args.output_dir,
                f"checkpoint-{step}"
            )
            self.accelerator.save_state(save_path, safe_serialization=True)
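The `checkpointing_limit` parameter implies a rolling-retention policy alongside the save shown above. The following is a minimal sketch of that behavior, assuming checkpoint directories follow the `checkpoint-{step}` naming shown in this page; the function name `prune_checkpoints` is hypothetical (the actual utilities live in finetune/utils/checkpointing.py).

```python
import os
import shutil

def prune_checkpoints(output_dir: str, limit: int) -> None:
    """Keep at most `limit` checkpoint-{step} directories, deleting the oldest.

    Hypothetical sketch of the rolling retention implied by
    `checkpointing_limit`; sorts by step number, not lexicographically,
    so checkpoint-1000 correctly outranks checkpoint-200.
    """
    ckpts = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
    ckpts.sort(key=lambda d: int(d.split("-")[-1]))  # oldest first by step
    for stale in ckpts[: max(0, len(ckpts) - limit)]:
        shutil.rmtree(os.path.join(output_dir, stale))
```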

Validation:

class Trainer:
    def validate(self, step: int) -> None:
        """Generate validation videos using current model state.

        Args:
            step: Current training step (used for naming output files).
        """
        # Constructs inference pipeline
        # Generates videos from validation prompts
        # Exports to output_dir/validation_res/
        # Logs to wandb if configured
        ...

Import

from finetune.trainer import Trainer

Key Parameters

Parameter Type Default Description
checkpointing_steps int 200 Save a checkpoint every N training steps.
checkpointing_limit int 10 Maximum number of rolling checkpoints to retain (oldest are deleted).
do_validation bool True Whether to run validation during training.
validation_steps int 200 Run validation every N training steps.
validation_dir Path output_dir/validation_res Directory for saving validation video outputs.
gen_fps int 15 Frames per second for exported validation videos.
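For reference, the defaults in the table above can be expressed as a configuration dataclass. This is a hypothetical subset for illustration only; the real argument schema lives in finetune/schemas and contains many more fields.

```python
from dataclasses import dataclass
from pathlib import Path

@dataclass
class CheckpointValidationArgs:
    """Hypothetical subset of the trainer's argument schema, mirroring
    the defaults documented in the Key Parameters table."""
    output_dir: Path = Path("output")
    checkpointing_steps: int = 200   # save a checkpoint every N steps
    checkpointing_limit: int = 10    # rolling checkpoints to retain
    do_validation: bool = True       # run validation during training
    validation_steps: int = 200      # validate every N steps
    gen_fps: int = 15                # fps of exported validation videos
```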

External Dependencies

  • accelerate -- accelerator.save_state() for distributed checkpoint saving
  • wandb -- (optional) Experiment tracking and video logging
  • diffusers.utils.export_utils.export_to_video -- Video file export
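The validation outputs follow the naming pattern shown in the directory-structure example further below (`step_{step}_prompt_{idx}.mp4` under `validation_res/`); the generated frames are written via `export_to_video` and, when tracking is enabled, logged to wandb. A small sketch of the path construction, assuming that pattern (the helper name `validation_video_path` is invented here):

```python
import os

def validation_video_path(output_dir: str, step: int, prompt_idx: int) -> str:
    """Build the output path for one validation video, following the
    step_{step}_prompt_{idx}.mp4 pattern under output_dir/validation_res/."""
    return os.path.join(
        output_dir, "validation_res", f"step_{step}_prompt_{prompt_idx}.mp4"
    )
```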

I/O Contract

Inputs

Input Format Description
Current model state In-memory model parameters The transformer with LoRA adapters at the current training step.
Optimizer state In-memory optimizer state Adam optimizer first and second moments, step counts.
Validation prompts Text strings Held-out prompts for generating validation videos (configured via validation parameters).
Validation images/videos Image/video files (I2V only) Reference images or videos for image-to-video validation.

Outputs

Output Format Description
Checkpoint files Directory output_dir/checkpoint-{step}/ Full training state: model weights in .safetensors format (safe serialization), plus optimizer, scheduler, and RNG states as separate files.
Validation videos .mp4 files in output_dir/validation_res/ Generated videos from validation prompts at the current training step.
Logged media wandb media logs Videos logged to wandb for remote viewing (if wandb logging is enabled).

Usage Examples

Configuring Checkpointing and Validation

python train.py \
    --model_path /models/cogvideox-5b \
    --model_name cogvideox-5b \
    --model_type t2v \
    --training_type lora \
    --output_dir /output/lora_run \
    --data_root /data/my_videos \
    --caption_column prompts.txt \
    --video_column videos.txt \
    --train_epochs 100 \
    --batch_size 1 \
    --train_resolution 49 480 720 \
    --mixed_precision bf16 \
    --checkpointing_steps 200 \
    --checkpointing_limit 10 \
    --do_validation \
    --validation_steps 200

Resuming from a Checkpoint

from finetune.schemas import Args
from finetune.models.cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer

args = Args.parse_args()
trainer = CogVideoXT2VLoraTrainer(args=args)

# Resume from specific checkpoint
trainer.accelerator.load_state("/output/lora_run/checkpoint-1000")
trainer.train()
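Rather than hard-coding a checkpoint path as above, the most recent checkpoint can be located by parsing step numbers from the `checkpoint-{step}` directory names. This helper is a hypothetical convenience, not part of the finetune package:

```python
import os
from typing import Optional

def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the newest checkpoint-{step} directory,
    or None if no checkpoint exists yet. Compares numeric step values,
    so checkpoint-1000 beats checkpoint-200 despite lexicographic order."""
    ckpts = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
    if not ckpts:
        return None
    newest = max(ckpts, key=lambda d: int(d.split("-")[-1]))
    return os.path.join(output_dir, newest)
```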

Checkpoint Directory Structure

output_dir/
  checkpoint-200/
    model.safetensors
    optimizer.bin
    scheduler.bin
    random_states_0.pkl
  checkpoint-400/
    ...
  validation_res/
    step_200_prompt_0.mp4
    step_200_prompt_1.mp4
    step_400_prompt_0.mp4
    ...
