Implementation:Zai_org_CogVideo_Trainer_Train
| Implementation Metadata | |
|---|---|
| Name | Trainer_Train |
| Type | API Doc |
| Category | Training |
| Domains | Video_Generation, Fine_Tuning, Diffusion_Models |
| Knowledge Sources | CogVideo Repository, CogVideoX Paper |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Trainer_Train is a concrete tool for executing the CogVideoX diffusion training loop, provided by the CogVideo finetune package.
Description
This implementation provides the main train() method in the base Trainer class and the model-specific compute_loss() method in the LoRA trainer subclasses. The train() method orchestrates the epoch loop, batch iteration, gradient accumulation, logging, checkpointing, and validation calls. The compute_loss() method implements the CogVideoX-specific forward pass: sampling random timesteps, adding noise to video latents, computing the transformer's prediction, and returning the weighted MSE loss.
Usage
Use when running the actual fine-tuning process after all components have been initialized and prepared. The train() method is called as the final step in the training pipeline.
Code Reference
Source Location
- finetune/trainer.py:L372-501 -- train() method (main training loop)
- finetune/models/cogvideox_t2v/lora_trainer.py:L103-176 -- compute_loss() for T2V
- finetune/models/cogvideox_i2v/lora_trainer.py -- compute_loss() for I2V
Signature
class Trainer:
    def train(self) -> None:
        """Main training loop.

        Iterates over epochs and batches, computes loss via compute_loss(),
        handles gradient accumulation, logging, checkpointing, and validation.
        """
        ...

class CogVideoXT2VLoraTrainer(Trainer):
    def compute_loss(self, batch) -> torch.Tensor:
        """Compute diffusion training loss for a single batch.

        Args:
            batch: Dictionary containing:
                - "prompt_embedding": [B, seq_len, hidden_size]
                - "encoded_videos": [B, C, F, H, W]

        Returns:
            Scalar loss tensor.
        """
        ...
Import
from finetune.trainer import Trainer
from finetune.models.cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| train_epochs | int | from args | Number of training epochs. |
| train_steps | int | from args | Maximum training steps (alternative to epochs). |
| batch_size | int | from args | Per-GPU batch size. |
| max_grad_norm | float | 1.0 | Maximum gradient norm for clipping. |
| gradient_accumulation_steps | int | 1 | Number of micro-batches per optimizer step. |
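The per-GPU batch_size and gradient_accumulation_steps together determine how many samples contribute to each optimizer step. A minimal sketch of the arithmetic (the helper function here is illustrative, not part of the finetune package):

```python
def effective_batch_size(per_gpu_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_gpus: int) -> int:
    """Number of samples contributing to each optimizer step.

    Each GPU processes `per_gpu_batch_size` samples per micro-batch,
    gradients are accumulated over `gradient_accumulation_steps`
    micro-batches, and all GPUs contribute to the same update.
    """
    return per_gpu_batch_size * gradient_accumulation_steps * num_gpus

# e.g. batch_size=2, gradient_accumulation_steps=4 on 8 GPUs
print(effective_batch_size(2, 4, 8))  # -> 64
```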
External Dependencies
- torch -- Core tensor operations and autograd
- diffusers -- scheduler.add_noise, scheduler.get_velocity for noise scheduling
- tqdm -- Progress bar display
I/O Contract
Inputs
| Input | Format | Description |
|---|---|---|
| Video latents | Tensor [B, C, F, H, W] | Pre-encoded video latents from the VAE encoder (batch key: "encoded_videos"). |
| Prompt embeddings | Tensor [B, seq_len, hidden_size] | Pre-encoded T5 text embeddings (batch key: "prompt_embedding"). |
Outputs
| Output | Format | Description |
|---|---|---|
| Updated LoRA weights | In-place parameter updates | Transformer LoRA adapter parameters are updated via optimizer steps. |
| Training loss | Scalar float (logged) | Per-step and averaged training loss reported to the logger (wandb/tensorboard). |
| Gradient norms | Scalar float (logged) | Pre-clip gradient norms reported for monitoring training stability. |
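The logged gradient norm is typically obtained as a by-product of clipping. A minimal sketch with plain PyTorch (the model and loss here are toy placeholders; in the real loop clipping goes through the accelerator, and max_norm=1.0 matches the max_grad_norm default above):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()

# clip_grad_norm_ returns the total norm *before* clipping --
# this is the value reported to the logger for stability monitoring
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(grad_norm))
```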
Usage Examples
Running the Training Loop
from finetune.schemas import Args
from finetune.models.cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
# Parse and validate configuration
args = Args.parse_args()
# Initialize trainer (loads components, injects LoRA, sets up accelerator)
trainer = CogVideoXT2VLoraTrainer(args=args)
# Run training
trainer.train()
# Training loop handles:
# - Epoch and batch iteration
# - Random timestep sampling
# - Noise addition to video latents
# - Forward pass through transformer
# - Weighted MSE loss computation
# - Gradient accumulation and clipping
# - Optimizer and LR scheduler steps
# - Periodic checkpointing and validation
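The accumulation and stepping logic that train() performs can be sketched in plain PyTorch (a simplified stand-in for the accelerate-based loop; the model, data, and hyperparameter values here are toy placeholders):

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

gradient_accumulation_steps = 4
max_grad_norm = 1.0
data = [torch.randn(2, 16) for _ in range(8)]  # toy micro-batches

for step, batch in enumerate(data):
    # Scale the loss so accumulated gradients average over micro-batches
    loss = model(batch).pow(2).mean() / gradient_accumulation_steps
    loss.backward()

    # Only step the optimizer once per accumulation window
    if (step + 1) % gradient_accumulation_steps == 0:
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
```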
Understanding the compute_loss Flow
# Pseudocode for compute_loss (T2V):
def compute_loss(self, batch):
    video_latents = batch["encoded_videos"]    # [B, C, F, H, W]
    prompt_embeds = batch["prompt_embedding"]  # [B, seq_len, hidden_size]
    batch_size = video_latents.shape[0]

    # Sample a random diffusion timestep for each sample in the batch
    timesteps = torch.randint(0, num_train_timesteps, (batch_size,))

    # Add noise to the clean latents according to the schedule
    noise = torch.randn_like(video_latents)
    noisy_latents = scheduler.add_noise(video_latents, noise, timesteps)

    # Compute the v-prediction target
    target = scheduler.get_velocity(video_latents, noise, timesteps)

    # Forward pass through the transformer
    model_output = transformer(noisy_latents, timesteps, prompt_embeds)

    # Weighted MSE loss between prediction and target
    loss = F.mse_loss(model_output, target, reduction="mean")
    return loss
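The scheduler calls in the pseudocode implement standard v-prediction math. A self-contained sketch of those two operations using a toy linear beta schedule (illustrative only; the real trainer uses the scheduler configured from the pretrained pipeline via diffusers):

```python
import torch

# Toy linear beta schedule (illustrative; the real schedule comes from diffusers)
num_train_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_train_timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def add_noise(latents, noise, t):
    # x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
    a = alphas_cumprod[t].view(-1, 1, 1, 1, 1)
    return a.sqrt() * latents + (1 - a).sqrt() * noise

def get_velocity(latents, noise, t):
    # v = sqrt(a_bar_t) * eps - sqrt(1 - a_bar_t) * x_0
    a = alphas_cumprod[t].view(-1, 1, 1, 1, 1)
    return a.sqrt() * noise - (1 - a).sqrt() * latents

latents = torch.randn(2, 16, 4, 8, 8)  # [B, C, F, H, W]
noise = torch.randn_like(latents)
t = torch.randint(0, num_train_timesteps, (2,))
noisy = add_noise(latents, noise, t)
target = get_velocity(latents, noise, t)
```

A useful sanity check on this parameterization: the clean latents are recoverable as sqrt(a_bar_t) * x_t - sqrt(1 - a_bar_t) * v, which is why predicting v is equivalent to jointly predicting noise and signal.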
Related Pages
- Principle:Zai_org_CogVideo_Diffusion_Training_Loop
- Environment:Zai_org_CogVideo_Diffusers_Finetuning_Environment
- Heuristic:Zai_org_CogVideo_BF16_FP16_Precision_Selection
- Heuristic:Zai_org_CogVideo_Memory_Optimization_Strategies
- Heuristic:Zai_org_CogVideo_Training_Hyperparameter_Defaults
- Implementation:Zai_org_CogVideo_CogVideoX_LoRA_Trainer_Load_Components
- Implementation:Zai_org_CogVideo_Accelerator_Setup
- Implementation:Zai_org_CogVideo_Trainer_Checkpoint_Validate