Implementation:Zai org CogVideo Video Autoencoder Loss

From Leeroopedia


Knowledge Sources
Domains Video_Generation, Autoencoding, Adversarial_Training
Last Updated 2026-02-10 00:00 GMT

Overview

VideoAutoencoderLoss is a comprehensive training loss for video autoencoders. It combines MSE reconstruction, LPIPS perceptual similarity, adversarial losses from 2D and 3D discriminators, and gradient penalty regularization, and it weights the adversarial term adaptively.

Description

This module is the primary loss computation class for CogVideo's video autoencoder. It contains and orchestrates multiple discriminator architectures and loss components:

Discriminator Architectures:

  • Discriminator (2D): A multi-layer convolutional discriminator with anti-aliased downsampling via Blur (Zhang et al. BlurPool), LinearSpaceAttention blocks, and FeedForward layers at each scale. It operates on individual frames extracted from the video.
  • Discriminator3D: A hybrid discriminator that uses DiscriminatorBlock3D (3D convolutions with 3D downsampling) for early layers to capture temporal structure, then transitions to 2D DiscriminatorBlock with attention for later layers. The temporal dimension is collapsed via reshape after the 3D stages.
  • Discriminator3DWithfirstframe: Similar to Discriminator3D but uses CausalConv3d and DownSample3D from the MOVQ encoder for first-frame-aware causal processing, collapsing the temporal dimension via mean pooling.
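The Blur step in the 2D discriminator refers to anti-aliased downsampling in the style of Zhang et al.'s BlurPool: low-pass filter the feature map before reducing its resolution. The sketch below is a minimal, standalone illustration that assumes a [1, 2, 1] binomial filter, reflect padding, and plain stride-2 subsampling; `blur_downsample` is a hypothetical helper, and the repository's Blur module and the way its discriminator blocks downsample after blurring may differ.

```python
import torch
import torch.nn.functional as F


def blur_downsample(x: torch.Tensor) -> torch.Tensor:
    """Anti-aliased 2x downsampling of a (B, C, H, W) batch: low-pass filter, then subsample."""
    channels = x.shape[1]
    k1d = torch.tensor([1.0, 2.0, 1.0], dtype=x.dtype, device=x.device)
    k2d = torch.outer(k1d, k1d)
    k2d = k2d / k2d.sum()  # normalized, so a constant image passes through unchanged
    # One copy of the 3x3 binomial kernel per channel -> depthwise convolution
    weight = k2d[None, None].repeat(channels, 1, 1, 1)  # (C, 1, 3, 3)
    x = F.pad(x, (1, 1, 1, 1), mode="reflect")  # keep H and W unchanged by the 3x3 filter
    x = F.conv2d(x, weight, groups=channels)
    return x[:, :, ::2, ::2]  # keep every second row and column
```

On an alternating +1/-1 checkerboard, naive stride-2 subsampling returns an image of all ones (the high-frequency pattern aliases into a constant), whereas `blur_downsample` returns all zeros because the filter removes that frequency before subsampling.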

Loss Components:

  1. MSE reconstruction loss: Computed pixel-wise between input and reconstruction via F.mse_loss.
  2. LPIPS perceptual loss: Computed on randomly sampled frames; pick_video_frame selects them by taking the top-k indices of random noise. Weighted by perceptual_weight.
  3. Adversarial generator loss: Hinge-based generator loss from the 3D discriminator. Its adaptive weight is the ratio of the gradient norm of the perceptual loss to the gradient norm of the generator loss, both taken with respect to the last decoder layer.
  4. Discriminator loss: Hinge discriminator loss on real vs. fake 3D video logits, with optional gradient penalty regularization on real inputs.
  5. Quantizer auxiliary loss: Passed in as aux_losses and weighted by quantizer_aux_loss_weight.
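The frame sampling, hinge losses, and gradient penalty listed above can be sketched as small PyTorch helpers. This is an illustration under assumptions, not the repository code: the helper names are hypothetical, frame selection is generalized to k frames per clip, and the gradient penalty is assumed to take the common WGAN-GP form (||grad|| - 1)^2 on real inputs.

```python
import torch
import torch.nn.functional as F


def sample_frame_indices(batch: int, num_frames: int, k: int = 1, generator=None) -> torch.Tensor:
    """Pick k distinct frame indices per clip by taking the top-k of i.i.d. random noise."""
    noise = torch.rand(batch, num_frames, generator=generator)
    return noise.topk(k, dim=-1).indices  # (batch, k)


def pick_video_frames(video: torch.Tensor, frame_indices: torch.Tensor) -> torch.Tensor:
    """Gather the chosen frames from (B, C, T, H, W) into a (B * k, C, H, W) image batch."""
    b, c, _, h, w = video.shape
    k = frame_indices.shape[1]
    idx = frame_indices.to(video.device)[:, None, :, None, None].expand(b, c, k, h, w)
    frames = video.gather(2, idx)  # (B, C, k, H, W)
    return frames.permute(0, 2, 1, 3, 4).reshape(b * k, c, h, w)


def hinge_discr_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
    # Push real logits above +1 and fake logits below -1
    return (F.relu(1 - real_logits) + F.relu(1 + fake_logits)).mean()


def hinge_gen_loss(fake_logits: torch.Tensor) -> torch.Tensor:
    # The generator tries to raise the discriminator's score on reconstructions
    return -fake_logits.mean()


def gradient_penalty(real: torch.Tensor, real_logits: torch.Tensor) -> torch.Tensor:
    """WGAN-GP-style penalty (||grad|| - 1)^2 on real inputs; `real` must require grad."""
    grads = torch.autograd.grad(
        outputs=real_logits,
        inputs=real,
        grad_outputs=torch.ones_like(real_logits),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
    )[0]
    grads = grads.flatten(start_dim=1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()
```

For the LPIPS term, the same `frame_indices` should be applied to both the inputs and the reconstructions so that each perceptual comparison is between corresponding frames.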

The forward method dispatches on optimizer_idx: 0 for the generator (autoencoder) update and 1 for the discriminator update.
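A toy module can make the optimizer_idx dispatch and the adaptive weighting concrete. Everything here is a hypothetical stand-in: `ToyVideoLoss` is not the CogVideo class, an L1 term replaces LPIPS, a single Conv3d plays the 3D discriminator, the `global_step >= disc_start` gate is an assumption, and the eps/clamp constants in `adaptive_weight` are illustrative. The quantizer auxiliary loss and gradient penalty are left out to keep the dispatch visible.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def adaptive_weight(perceptual_loss, gen_loss, last_layer, eps=1e-3, max_weight=1e3):
    """||d perceptual / d last_layer|| / ||d gen / d last_layer||, treated as a constant."""
    g_perc = torch.autograd.grad(perceptual_loss, last_layer, retain_graph=True)[0]
    g_gen = torch.autograd.grad(gen_loss, last_layer, retain_graph=True)[0]
    weight = g_perc.norm(p=2) / g_gen.norm(p=2).clamp(min=eps)
    return weight.clamp(max=max_weight).detach()


class ToyVideoLoss(nn.Module):
    """Hypothetical stand-in showing the optimizer_idx dispatch, not the CogVideo class."""

    def __init__(self, discriminator: nn.Module, disc_start: int, adversarial_loss_weight: float = 0.5):
        super().__init__()
        self.discriminator = discriminator
        self.disc_start = disc_start
        self.adversarial_loss_weight = adversarial_loss_weight

    def forward(self, inputs, reconstructions, optimizer_idx, global_step, last_layer=None):
        adversarial_on = global_step >= self.disc_start  # gating rule is an assumption

        if optimizer_idx == 0:  # generator (autoencoder) step
            rec_loss = F.mse_loss(reconstructions, inputs)
            perc_loss = F.l1_loss(reconstructions, inputs)  # hypothetical stand-in for LPIPS
            total = rec_loss + perc_loss
            if adversarial_on and self.adversarial_loss_weight > 0:
                if last_layer is None:
                    raise ValueError("last_layer is required for adaptive weighting")
                gen_loss = -self.discriminator(reconstructions).mean()  # hinge generator loss
                w = adaptive_weight(perc_loss, gen_loss, last_layer)
                total = total + self.adversarial_loss_weight * w * gen_loss
            return total

        if optimizer_idx == 1:  # discriminator step
            if not adversarial_on:
                return inputs.new_zeros(())
            real_logits = self.discriminator(inputs)
            fake_logits = self.discriminator(reconstructions.detach())
            return (F.relu(1 - real_logits) + F.relu(1 + fake_logits)).mean()

        raise ValueError(f"optimizer_idx must be 0 or 1, got {optimizer_idx}")
```

Detaching the adaptive weight means it only rescales the adversarial gradient; it is not itself optimized.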

Usage

Use this loss when training a video autoencoder that requires both spatial quality (via perceptual and 2D adversarial losses) and temporal coherence (via 3D discriminator losses). This is the dedicated video loss for CogVideo's autoencoder pipeline, complementing the image-focused GeneralLPIPSWithDiscriminator.

Code Reference

Source Location

  • Repository: Zai_org_CogVideo
  • File: sat/sgm/modules/autoencoding/losses/video_loss.py
  • Lines: 516-736

Signature

import torch.nn as nn
from torchvision.models import VGG16_Weights

class VideoAutoencoderLoss(nn.Module):
    def __init__(
        self,
        disc_start,
        perceptual_weight=1,
        adversarial_loss_weight=0,
        multiscale_adversarial_loss_weight=0,
        grad_penalty_loss_weight=0,
        quantizer_aux_loss_weight=0,
        vgg_weights=VGG16_Weights.DEFAULT,
        discr_kwargs=None,
        discr_3d_kwargs=None,
    ):
        ...  # body omitted; see the source file above

Import

from sat.sgm.modules.autoencoding.losses.video_loss import VideoAutoencoderLoss

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| inputs | torch.Tensor | Yes | Original video tensor of shape (B, C, T, H, W) |
| reconstructions | torch.Tensor | Yes | Reconstructed video tensor, same shape as inputs |
| optimizer_idx | int | Yes | 0 for the generator (autoencoder) update, 1 for the discriminator update |
| global_step | int | Yes | Current training step; the adversarial loss activates after disc_start |
| aux_losses | torch.Tensor | No | Auxiliary losses from the quantizer; defaults to zero if None |
| last_layer | torch.Tensor | No | Weight of the last decoder layer, used for adaptive weighting; required when adaptive adversarial weighting is used |
| split | str | No | Logging prefix string; defaults to "train" |

Outputs

| Name | Type | Description |
|---|---|---|
| total_loss | torch.Tensor | Scalar loss. For optimizer_idx=0 it combines the reconstruction, perceptual, adversarial, and quantizer auxiliary terms; for optimizer_idx=1 it is the discriminator loss plus the gradient penalty term |
| log | dict | Dictionary of detached scalar metrics: total loss, reconstruction loss, perceptual loss, generator loss, discriminator loss, gradient penalty, adaptive weight, and real/fake logit means |
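Reading the loss components and constructor arguments together, the two objectives can be summarized as follows. This is a sketch that assumes each weight simply multiplies its term; the symbols are introduced here for illustration (x is the input, \hat{x} the reconstruction, \theta_L the last decoder layer weight, \lambda_p = perceptual_weight, \lambda_{adv} = adversarial_loss_weight, \lambda_q = quantizer_aux_loss_weight, \lambda_{gp} = grad_penalty_loss_weight). The multiscale_adversarial_loss_weight term is omitted because its exact form is not described on this page, the adversarial terms apply only once global_step passes disc_start, and implementations typically clamp the denominator of w away from zero.

```latex
\mathcal{L}_{G} = \mathcal{L}_{\mathrm{MSE}}(x, \hat{x})
  + \lambda_{p}\, \mathcal{L}_{\mathrm{LPIPS}}(x, \hat{x})
  + \lambda_{adv}\, w\, \mathcal{L}_{\mathrm{gen}}
  + \lambda_{q}\, \mathcal{L}_{\mathrm{aux}},
\qquad
\mathcal{L}_{\mathrm{gen}} = -\,\mathbb{E}\left[ D_{3D}(\hat{x}) \right]

w = \frac{\lVert \nabla_{\theta_L} \mathcal{L}_{\mathrm{LPIPS}} \rVert_2}
         {\lVert \nabla_{\theta_L} \mathcal{L}_{\mathrm{gen}} \rVert_2}

\mathcal{L}_{D} = \mathbb{E}\left[ \max(0,\, 1 - D_{3D}(x)) \right]
  + \mathbb{E}\left[ \max(0,\, 1 + D_{3D}(\hat{x})) \right]
  + \lambda_{gp}\, \mathcal{L}_{gp}
```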

Usage Examples

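from sat.sgm.modules.autoencoding.losses.video_loss import VideoAutoencoderLoss

# Initialize with a 3D discriminator
loss_fn = VideoAutoencoderLoss(
    disc_start=5000,  # adversarial terms stay inactive until this step
    perceptual_weight=1.0,
    adversarial_loss_weight=0.5,
    grad_penalty_loss_weight=10.0,
    quantizer_aux_loss_weight=1.0,
    discr_3d_kwargs={
        "target": "sgm.modules.autoencoding.losses.video_loss.Discriminator3D",
        "params": {"dim": 64, "image_size": 256, "frame_num": 16},
    },
)

# video_batch, recon_video, step, quantizer_loss and model are placeholders
# for objects produced by your own training loop.

# Generator (autoencoder) update
gen_loss, gen_log = loss_fn(
    inputs=video_batch,           # (B, 3, T, H, W)
    reconstructions=recon_video,  # (B, 3, T, H, W)
    optimizer_idx=0,
    global_step=step,
    aux_losses=quantizer_loss,
    last_layer=model.decoder.conv_out.weight,  # weight of the final decoder layer
)

# Discriminator update; detaching keeps discriminator gradients out of the autoencoder
disc_loss, disc_log = loss_fn(
    inputs=video_batch,
    reconstructions=recon_video.detach(),
    optimizer_idx=1,
    global_step=step,
)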
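The discr_3d_kwargs dictionary uses the {"target": ..., "params": ...} convention found in sgm configs, where a helper (sgm.util.instantiate_from_config) imports the dotted target and calls it with params. The sketch below is a simplified, standalone version of that pattern; whether VideoAutoencoderLoss resolves discr_3d_kwargs exactly this way is an assumption based on the keys in the example, and the repository's helper may handle extra cases.

```python
import importlib
from typing import Any, Mapping


def get_obj_from_str(path: str) -> Any:
    """Resolve a dotted path such as 'package.module.ClassName' to the object it names."""
    module_name, attr_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr_name)


def instantiate_from_config(config: Mapping[str, Any]) -> Any:
    """Build target(**params) from a {"target": ..., "params": {...}} mapping."""
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    cls = get_obj_from_str(config["target"])
    return cls(**config.get("params", {}))
```

Under that assumption, the example configuration would amount to calling Discriminator3D(dim=64, image_size=256, frame_num=16).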
