Implementation:Zai org CogVideo Video Autoencoder Loss
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Autoencoding, Adversarial_Training |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
VideoAutoencoderLoss is a training loss for video autoencoders that combines MSE reconstruction, LPIPS perceptual similarity, adversarial losses from 2D and 3D discriminators, and gradient penalty regularization, with adaptive weighting of the adversarial term.
Description
This module is the primary loss computation class for CogVideo's video autoencoder. It contains and orchestrates multiple discriminator architectures and loss components:
Discriminator Architectures:
- Discriminator (2D): A multi-layer convolutional discriminator with anti-aliased downsampling via Blur (Zhang et al. blurpool), LinearSpaceAttention blocks, and FeedForward layers at each scale. Processes single frames extracted from video.
- Discriminator3D: A hybrid discriminator that uses DiscriminatorBlock3D (3D convolutions with 3D downsampling) for early layers to capture temporal structure, then transitions to 2D DiscriminatorBlock with attention for later layers. The temporal dimension is collapsed via reshape after the 3D stages.
- Discriminator3DWithfirstframe: Similar to Discriminator3D but uses CausalConv3d and DownSample3D from the MOVQ encoder for first-frame-aware causal processing, collapsing the temporal dimension via mean pooling.
Loss Components:
- MSE reconstruction loss: Computed pixel-wise between input and reconstruction via F.mse_loss.
- LPIPS perceptual loss: Applied to randomly sampled frames, which pick_video_frame selects via top-k on random noise. Weighted by perceptual_weight.
- Adversarial generator loss: Hinge-based generator loss from the 3D discriminator. Adaptive weighting computes the ratio of gradient norms of the perceptual loss and generator loss with respect to the last decoder layer.
- Discriminator loss: Hinge discriminator loss on real vs. fake 3D video logits, with optional gradient penalty regularization on real inputs.
- Quantizer auxiliary loss: Passed in as aux_losses and weighted by quantizer_aux_loss_weight.
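The components above can be illustrated with a minimal, dependency-free sketch. It uses plain Python lists in place of torch tensors, and every function name here (pick_frame_indices, hinge_discr_loss, hinge_gen_loss, adaptive_weight) is hypothetical rather than taken from video_loss.py. The hinge formulas are the standard ones; the eps and max_weight clamp values are assumptions, and the actual module may scale or clamp differently.

```python
import math
import random


def pick_frame_indices(num_frames, k, rng):
    """Pick k distinct frame indices by ranking per-frame random noise (top-k)."""
    noise = [rng.random() for _ in range(num_frames)]
    ranked = sorted(range(num_frames), key=lambda i: noise[i], reverse=True)
    return ranked[:k]


def hinge_discr_loss(fake_logits, real_logits):
    """Standard hinge discriminator loss: mean(relu(1 + fake)) + mean(relu(1 - real))."""
    fake_term = sum(max(0.0, 1.0 + f) for f in fake_logits) / len(fake_logits)
    real_term = sum(max(0.0, 1.0 - r) for r in real_logits) / len(real_logits)
    return fake_term + real_term


def hinge_gen_loss(fake_logits):
    """Standard hinge generator loss: the negated mean of the fake logits."""
    return -sum(fake_logits) / len(fake_logits)


def adaptive_weight(grad_perceptual, grad_gen, eps=1e-5, max_weight=1e4):
    """Ratio of L2 gradient norms (perceptual over generator) at the last decoder layer.

    eps and max_weight are assumed guard values, not confirmed from the source.
    """
    norm_perceptual = math.sqrt(sum(g * g for g in grad_perceptual))
    norm_gen = math.sqrt(sum(g * g for g in grad_gen))
    return min(norm_perceptual / max(norm_gen, eps), max_weight)
```

The ratio rescales the adversarial term so that its gradient at the last decoder layer is comparable in magnitude to the perceptual gradient, which keeps either term from dominating the generator update.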
The forward method dispatches on optimizer_idx: 0 for the generator (autoencoder) update and 1 for the discriminator update.
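A rough sketch of this dispatch, assuming the individual loss terms have already been computed as plain floats. The function video_loss_dispatch, the dictionary keys, the log key format, and the `global_step >= disc_start` gating rule are all illustrative assumptions and not the module's actual code.

```python
def video_loss_dispatch(terms, weights, optimizer_idx, global_step, disc_start, split="train"):
    """Combine precomputed loss terms according to which optimizer is being stepped."""
    if optimizer_idx == 0:
        # Generator (autoencoder) update. The adversarial term is switched off
        # until training reaches disc_start (exact boundary is an assumption).
        adv_on = 1.0 if global_step >= disc_start else 0.0
        total = (
            terms["recon"]
            + weights["perceptual"] * terms["perceptual"]
            + adv_on * weights["adversarial"] * terms["adaptive_weight"] * terms["gen"]
            + weights["quantizer_aux"] * terms["aux"]
        )
    elif optimizer_idx == 1:
        # Discriminator update: hinge loss plus the weighted gradient penalty.
        total = terms["discr"] + weights["grad_penalty"] * terms["grad_penalty"]
    else:
        raise ValueError(f"unknown optimizer_idx: {optimizer_idx}")
    # Hypothetical log layout: metrics keyed by the split prefix.
    log = {f"{split}/total_loss": total}
    return total, log


# Example values chosen only for illustration.
terms = {"recon": 1.0, "perceptual": 2.0, "gen": 4.0, "adaptive_weight": 0.5,
         "aux": 3.0, "discr": 1.5, "grad_penalty": 0.25}
weights = {"perceptual": 1.0, "adversarial": 0.5, "quantizer_aux": 1.0, "grad_penalty": 10.0}

warmup_total, _ = video_loss_dispatch(terms, weights, 0, global_step=100, disc_start=5000)
gen_total, _ = video_loss_dispatch(terms, weights, 0, global_step=6000, disc_start=5000)
disc_total, disc_log = video_loss_dispatch(terms, weights, 1, global_step=6000, disc_start=5000)
```

Keeping both branches behind a single entry point lets a training loop with two optimizers call the same loss object twice per step, once per optimizer_idx.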
Usage
Use this loss when training a video autoencoder that requires both spatial quality (via perceptual and 2D adversarial losses) and temporal coherence (via 3D discriminator losses). This is the dedicated video loss for CogVideo's autoencoder pipeline, complementing the image-focused GeneralLPIPSWithDiscriminator.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: sat/sgm/modules/autoencoding/losses/video_loss.py
- Lines: 516-736
Signature
class VideoAutoencoderLoss(nn.Module):
    def __init__(
        self,
        disc_start,
        perceptual_weight=1,
        adversarial_loss_weight=0,
        multiscale_adversarial_loss_weight=0,
        grad_penalty_loss_weight=0,
        quantizer_aux_loss_weight=0,
        vgg_weights=VGG16_Weights.DEFAULT,
        discr_kwargs=None,
        discr_3d_kwargs=None,
    ):
Import
from sat.sgm.modules.autoencoding.losses.video_loss import VideoAutoencoderLoss
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs | torch.Tensor | Yes | Original video tensor of shape (B, C, T, H, W) |
| reconstructions | torch.Tensor | Yes | Reconstructed video tensor, same shape as inputs |
| optimizer_idx | int | Yes | 0 for generator (autoencoder) update, 1 for discriminator update |
| global_step | int | Yes | Current training step; adversarial loss activates after disc_start |
| aux_losses | torch.Tensor | No | Auxiliary losses from the quantizer; defaults to zero if None |
| last_layer | torch.Tensor | No | Last decoder layer weight for adaptive weighting; required when using adaptive adversarial weighting |
| split | str | No | Logging prefix string, defaults to "train" |
Outputs
| Name | Type | Description |
|---|---|---|
| total_loss | torch.Tensor | Scalar total loss combining reconstruction, perceptual, adversarial, quantizer, and (for discriminator) gradient penalty terms |
| log | dict | Dictionary of detached scalar metrics: total loss, reconstruction loss, perceptual loss, generator loss, discriminator loss, gradient penalty, adaptive weight, and real/fake logit means |
Usage Examples
# Initialize with 3D discriminator
loss_fn = VideoAutoencoderLoss(
    disc_start=5000,
    perceptual_weight=1.0,
    adversarial_loss_weight=0.5,
    grad_penalty_loss_weight=10.0,
    quantizer_aux_loss_weight=1.0,
    discr_3d_kwargs={
        "target": "sgm.modules.autoencoding.losses.video_loss.Discriminator3D",
        "params": {"dim": 64, "image_size": 256, "frame_num": 16},
    },
)

# Generator update
gen_loss, gen_log = loss_fn(
    inputs=video_batch,           # (B, 3, T, H, W)
    reconstructions=recon_video,  # (B, 3, T, H, W)
    optimizer_idx=0,
    global_step=step,
    aux_losses=quantizer_loss,
    last_layer=model.decoder.conv_out.weight,
)

# Discriminator update
disc_loss, disc_log = loss_fn(
    inputs=video_batch,
    reconstructions=recon_video,
    optimizer_idx=1,
    global_step=step,
)