
Implementation:Zai org CogVideo Trainer Train

From Leeroopedia


Implementation Metadata
Name Trainer_Train
Type API Doc
Category Training
Domains Video_Generation, Fine_Tuning, Diffusion_Models
Knowledge Sources CogVideo Repository, CogVideoX Paper
Last Updated 2026-02-10 00:00 GMT

Overview

Trainer_Train is a concrete tool for executing the CogVideoX diffusion training loop, provided by the CogVideo finetune package.

Description

This implementation provides the main train() method in the base Trainer class and the model-specific compute_loss() method in the LoRA trainer subclasses. The train() method orchestrates the epoch loop, batch iteration, gradient accumulation, logging, checkpointing, and validation calls. The compute_loss() method implements the CogVideoX-specific forward pass: sampling random timesteps, adding noise to video latents, computing the transformer's prediction, and returning the weighted MSE loss.
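
The orchestration described here can be condensed into a minimal skeleton. This is a sketch under assumptions, not the real implementation: `train_skeleton` and its arguments are hypothetical names, and the actual train() additionally manages the accelerator, progress bars, checkpointing, and validation.

```python
import torch

def train_skeleton(model, optimizer, lr_scheduler, dataloader, compute_loss,
                   train_epochs, gradient_accumulation_steps=1, max_grad_norm=1.0):
    """Minimal sketch of the epoch / accumulation / clipping structure."""
    step = 0
    for epoch in range(train_epochs):
        for i, batch in enumerate(dataloader):
            # Scale so the accumulated gradient matches one large batch
            loss = compute_loss(batch) / gradient_accumulation_steps
            loss.backward()
            if (i + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                step += 1
    return step
```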

Usage

Use when running the actual fine-tuning process after all components have been initialized and prepared. The train() method is called as the final step in the training pipeline.

Code Reference

Source Location

  • finetune/trainer.py:L372-501 -- train() method (main training loop)
  • finetune/models/cogvideox_t2v/lora_trainer.py:L103-176 -- compute_loss() for T2V
  • finetune/models/cogvideox_i2v/lora_trainer.py -- compute_loss() for I2V

Signature

class Trainer:
    def train(self) -> None:
        """Main training loop.

        Iterates over epochs and batches, computes loss via compute_loss(),
        handles gradient accumulation, logging, checkpointing, and validation.
        """
        ...

class CogVideoXT2VLoraTrainer(Trainer):
    def compute_loss(self, batch) -> torch.Tensor:
        """Compute diffusion training loss for a single batch.

        Args:
            batch: Dictionary containing:
                - "prompt_embedding": [B, seq_len, hidden_size]
                - "encoded_videos": [B, C, F, H, W]

        Returns:
            Scalar loss tensor.
        """
        ...

Import

from finetune.trainer import Trainer
from finetune.models.cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer

Key Parameters

  • train_epochs (int, default: from args) -- Number of training epochs.
  • train_steps (int, default: from args) -- Maximum number of training steps; an alternative stopping criterion to epochs.
  • batch_size (int, default: from args) -- Per-GPU batch size.
  • max_grad_norm (float, default: 1.0) -- Maximum gradient norm for clipping.
  • gradient_accumulation_steps (int, default: 1) -- Number of micro-batches accumulated per optimizer step.
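
A quick aid for sizing runs: these parameters combine into the effective (global) batch size seen by each optimizer step. The helper below is hypothetical, not part of the package.

```python
def effective_batch_size(per_gpu_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_gpus: int = 1) -> int:
    """Number of samples contributing to each optimizer step across all GPUs."""
    return per_gpu_batch_size * gradient_accumulation_steps * num_gpus
```

For example, batch_size=2 with gradient_accumulation_steps=4 on 8 GPUs yields 64 samples per optimizer step.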

External Dependencies

  • torch -- Core tensor operations and autograd
  • diffusers -- scheduler.add_noise, scheduler.get_velocity for noise scheduling
  • tqdm -- Progress bar display
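
For orientation, scheduler.add_noise and scheduler.get_velocity correspond to the standard DDPM closed forms. The torch-only sketch below reproduces that math for a precomputed alphas_cumprod table; the standalone function names mirror the scheduler methods but are assumptions of this sketch, not the diffusers API surface.

```python
import torch

def add_noise(x0, noise, alphas_cumprod, timesteps):
    """Forward diffusion: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    abar = alphas_cumprod[timesteps].view(-1, *([1] * (x0.dim() - 1)))
    return abar.sqrt() * x0 + (1 - abar).sqrt() * noise

def get_velocity(x0, noise, alphas_cumprod, timesteps):
    """v-prediction target: v_t = sqrt(abar_t) * eps - sqrt(1 - abar_t) * x0."""
    abar = alphas_cumprod[timesteps].view(-1, *([1] * (x0.dim() - 1)))
    return abar.sqrt() * noise - (1 - abar).sqrt() * x0
```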

I/O Contract

Inputs

  • Video latents -- Tensor [B, C, F, H, W]; pre-encoded video latents from the VAE encoder (batch key: "encoded_videos").
  • Prompt embeddings -- Tensor [B, seq_len, hidden_size]; pre-encoded T5 text embeddings (batch key: "prompt_embedding").
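
This input contract can be checked with a small helper (validate_batch is a hypothetical name, not part of the package; only the shapes come from the table above):

```python
import torch

def validate_batch(batch: dict) -> int:
    """Check a batch against the I/O contract and return the batch size."""
    vids = batch["encoded_videos"]       # [B, C, F, H, W]
    embeds = batch["prompt_embedding"]   # [B, seq_len, hidden_size]
    assert vids.dim() == 5, "encoded_videos must be [B, C, F, H, W]"
    assert embeds.dim() == 3, "prompt_embedding must be [B, seq_len, hidden_size]"
    assert vids.size(0) == embeds.size(0), "batch dimensions must match"
    return vids.size(0)
```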

Outputs

  • Updated LoRA weights -- In-place parameter updates; transformer LoRA adapter parameters are updated via optimizer steps.
  • Training loss -- Scalar float (logged); per-step and averaged training loss reported to the logger (wandb/tensorboard).
  • Gradient norms -- Scalar float (logged); pre-clip gradient norms reported for monitoring training stability.

Usage Examples

Running the Training Loop

from finetune.schemas import Args
from finetune.models.cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer

# Parse and validate configuration
args = Args.parse_args()

# Initialize trainer (loads components, injects LoRA, sets up accelerator)
trainer = CogVideoXT2VLoraTrainer(args=args)

# Run training
trainer.train()
# Training loop handles:
#   - Epoch and batch iteration
#   - Random timestep sampling
#   - Noise addition to video latents
#   - Forward pass through transformer
#   - Weighted MSE loss computation
#   - Gradient accumulation and clipping
#   - Optimizer and LR scheduler steps
#   - Periodic checkpointing and validation

Understanding the compute_loss Flow

# Pseudocode for compute_loss (T2V). `scheduler` and `transformer` stand in
# for the trainer's prepared components; the timestep-dependent loss
# weighting is omitted for brevity.
def compute_loss(self, batch):
    video_latents = batch["encoded_videos"]        # [B, C, F, H, W]
    prompt_embeds = batch["prompt_embedding"]      # [B, seq_len, hidden_size]
    B = video_latents.size(0)

    # Sample one random diffusion timestep per sample
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (B,),
                              device=video_latents.device)

    # Forward-diffuse: add noise to the clean latents
    noise = torch.randn_like(video_latents)
    noisy_latents = scheduler.add_noise(video_latents, noise, timesteps)

    # v-prediction target
    target = scheduler.get_velocity(video_latents, noise, timesteps)

    # Predict with the (LoRA-adapted) transformer
    model_output = transformer(noisy_latents, timesteps, prompt_embeds)

    # MSE between prediction and target
    loss = F.mse_loss(model_output, target, reduction="mean")
    return loss
