
Implementation:Zai org CogVideo MagViT2 Tokenizer

From Leeroopedia


Knowledge Sources
Domains Video_Generation, Autoencoding, Video_Tokenization
Last Updated 2026-02-10 00:00 GMT

Overview

VideoTokenizer implements a full MagViT2-style video tokenizer with a configurable encoder-decoder architecture using causal 3D convolutions, lookup-free or finite-scalar quantization, GAN discriminator training, and VGG perceptual loss.

Description

The VideoTokenizer class (the main class within the magvit2_pytorch.py module) is a complete video-to-token-to-video system. Its architecture consists of:

Encoder: Begins with a CausalConv3d input convolution that applies temporally causal padding (padding only over past frames, never future ones). The encoder then applies a configurable sequence of layers, specified by a layers tuple. Supported layer types include:

  • residual / consecutive_residual: Residual units with CausalConv3d, ELU activations, and SqueezeExcite (global context attention).
  • compress_space / compress_time: Strided downsampling via SpatialDownsample2x or TimeDownsample2x, each with optional anti-aliased Blur.
  • attend_space / attend_time: Full attention (SpaceAttention, TimeAttention) with memory key-values and flash attention support.
  • linear_attend_space: Efficient LinearSpaceAttention using Taylor series approximation.
  • gateloop_time: Temporal processing via SimpleGateLoopLayer.
  • cond_residual / cond_attend_*: Conditioned variants that accept external conditioning via AdaptiveRMSNorm.

A final LayerNorm is applied before quantization.
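The temporally causal padding described above can be sketched with a plain functional convolution. This is a simplified illustration, not the repository's CausalConv3d module; `causal_conv3d` is a hypothetical helper:

```python
import torch
import torch.nn.functional as F

def causal_conv3d(x, weight, time_kernel=3, space_kernel=3):
    # Causal temporal padding: pad (time_kernel - 1) frames on the past side
    # only, with symmetric spatial padding, so frame t never sees frames > t.
    # x: (B, C, T, H, W); weight: (C_out, C_in, kt, kh, kw)
    sp = space_kernel // 2
    # F.pad order for 5D input: (W_left, W_right, H_top, H_bottom, T_front, T_back)
    x = F.pad(x, (sp, sp, sp, sp, time_kernel - 1, 0))
    return F.conv3d(x, weight)

x = torch.randn(1, 3, 8, 16, 16)
w = torch.randn(4, 3, 3, 3, 3)
out = causal_conv3d(x, w)  # shape preserved: (1, 4, 8, 16, 16)
```

Because all temporal padding sits on the past side, perturbing the last frame leaves the outputs for all earlier frames unchanged, which is what makes single-image (one-frame) inputs compatible with video inputs.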

Quantization: Supports two strategies:

  • LFQ (Lookup-Free Quantization): Binary thresholding with entropy-based auxiliary loss.
  • FSQ (Finite Scalar Quantization): Quantization to predefined discrete levels.
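The core of the LFQ strategy can be sketched in a few lines. This is a minimal stand-in that omits the entropy and commitment losses; `lookup_free_quantize` is a hypothetical helper, not the repository's LFQ class:

```python
import torch

def lookup_free_quantize(z):
    # LFQ: each latent channel is binarized by sign; the token index is the
    # integer whose bits are those signs, so no learned codebook is needed.
    # z: (..., num_bits)
    bits = (z > 0).long()                        # hard binary code per channel
    powers = 2 ** torch.arange(z.shape[-1])
    indices = (bits * powers).sum(dim=-1)        # integer token ids
    quantized = torch.where(z > 0, 1.0, -1.0)    # quantized latent in {-1, +1}
    # straight-through estimator: gradients flow through z unchanged
    quantized = z + (quantized - z).detach()
    return quantized, indices

z = torch.tensor([[0.3, -1.2, 0.7]])
q, idx = lookup_free_quantize(z)  # idx: bit pattern 101 -> token 5
```

With num_bits latent channels this yields an implicit codebook of size 2^num_bits; FSQ generalizes the same idea from two levels per channel to a predefined set of levels.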

Decoder: Mirrors the encoder in reverse order, using SpatialUpsample2x and TimeUpsample2x for upsampling, with corresponding residual and attention layers.
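A depth-to-space spatial upsampling step in the spirit of SpatialUpsample2x can be sketched as follows. This is a simplified stand-in; `spatial_upsample_2x` is hypothetical and omits the repository's initialization details:

```python
import torch
import torch.nn as nn

def spatial_upsample_2x(x, conv):
    # Run a conv that quadruples channels, then pixel-shuffle each frame
    # from (4C, H, W) to (C, 2H, 2W), folding time into the batch dimension.
    b, c, t, h, w = x.shape
    x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
    x = nn.functional.pixel_shuffle(conv(x), 2)
    return x.reshape(b, t, -1, 2 * h, 2 * w).permute(0, 2, 1, 3, 4)

conv = nn.Conv2d(8, 32, 1)  # 8 -> 32 channels, pixel-shuffle back down to 8
out = spatial_upsample_2x(torch.randn(1, 8, 4, 6, 6), conv)
# out: (1, 8, 4, 12, 12) -- spatial extent doubled, time untouched
```

Temporal upsampling works analogously along the time axis, using a causal transposed convolution so the causality of the encoder is preserved in reverse.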

Training: The forward method supports multiple return modes: token codes, reconstructed video, generator loss (with VGG perceptual loss and adversarial loss from a Discriminator with anti-aliased downsampling), and discriminator loss with gradient penalty. Adaptive weighting balances perceptual and adversarial gradient norms. Supports optional separate first-frame encoding via dedicated 2D convolutions.
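The adaptive weighting mentioned above follows the VQGAN recipe: scale the adversarial loss so its gradient norm at the decoder's last layer matches that of the reconstruction/perceptual loss. A sketch, not the repository's exact code (`adaptive_adversarial_weight` is a hypothetical helper):

```python
import torch

def adaptive_adversarial_weight(recon_loss, adv_loss, last_layer,
                                eps=1e-8, max_weight=1e4):
    # Ratio of gradient norms at a shared parameter tensor; detached so the
    # weight itself is treated as a constant during backpropagation.
    g_recon = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
    g_adv = torch.autograd.grad(adv_loss, last_layer, retain_graph=True)[0]
    w = g_recon.norm() / (g_adv.norm() + eps)
    return w.clamp(max=max_weight).detach()

last = torch.randn(4, requires_grad=True)
recon = (last ** 2).sum()
adv = (2 * last).sum()
w = adaptive_adversarial_weight(recon, adv, last)
```

The total generator loss then combines the reconstruction term with `w * adversarial_loss_weight * adv_loss`, keeping either term from dominating as training progresses.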

Other Key Components:

  • Conv3DMod: StyleGAN2-style modulated 3D convolution for conditioning on latent vectors.
  • CausalConvTranspose3d: Temporally causal transposed convolution for upsampling.
  • RMSNorm / AdaptiveRMSNorm: Root-mean-square normalization with optional conditioning.
  • Blur: Anti-aliased 3D blurpool filtering using kornia's filter3d.
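The RMSNorm building block is small enough to sketch in full. A minimal version under the usual definition (no mean subtraction, a learned per-channel gain), which may differ in detail from the repository's module:

```python
import torch

class RMSNorm(torch.nn.Module):
    # Normalize by the root-mean-square of the features and rescale with a
    # learned gain; unlike LayerNorm, the mean is not subtracted.
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.scale = dim ** 0.5
        self.gamma = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = torch.nn.functional.normalize(x, dim=-1, eps=self.eps)
        return normed * self.scale * self.gamma

norm = RMSNorm(8)
y = norm(torch.randn(2, 8))  # per-row RMS of y is 1 at initialization
```

AdaptiveRMSNorm replaces the fixed gain with one predicted from the conditioning vector, which is how the cond_residual and cond_attend_* layers inject external conditioning.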

Usage

Use this tokenizer as the discrete bottleneck in a latent video generation pipeline. It compresses video into a compact token sequence suitable for autoregressive or diffusion-based generation models. Supports both image pretraining (4D input) and full video tokenization (5D input) through curriculum learning.
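The 4D/5D input handling can be sketched as a tiny shim: image batches become single-frame videos so the same causal-convolution pipeline serves both curriculum stages (`to_video` is a hypothetical helper, not the repository's code):

```python
import torch

def to_video(x):
    # 4D image batch (B, C, H, W) -> single-frame video (B, C, 1, H, W);
    # 5D video input (B, C, T, H, W) passes through unchanged.
    if x.ndim == 4:
        x = x.unsqueeze(2)
    assert x.ndim == 5
    return x

images = torch.randn(2, 3, 64, 64)
video = to_video(images)  # (2, 3, 1, 64, 64)
```

Because CausalConv3d pads only on the past side, a single-frame video sees no padding leakage from nonexistent future frames, so weights pretrained on images transfer directly to video.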

Code Reference

Source Location

  • Repository: Zai_org_CogVideo
  • File: sat/sgm/modules/autoencoding/magvit2_pytorch.py
  • Lines: 1007-1878 (VideoTokenizer), 1883-1888 (MagViT2 placeholder)

Signature

class VideoTokenizer(Module):
    def __init__(
        self,
        *,
        image_size,
        layers: Tuple[Union[str, Tuple[str, int]], ...] = (
            "residual", "residual", "residual"
        ),
        residual_conv_kernel_size=3,
        num_codebooks=1,
        codebook_size: Optional[int] = None,
        channels=3,
        init_dim=64,
        max_dim=float("inf"),
        dim_cond=None,
        dim_cond_expansion_factor=4.0,
        input_conv_kernel_size: Tuple[int, int, int] = (7, 7, 7),
        output_conv_kernel_size: Tuple[int, int, int] = (3, 3, 3),
        pad_mode: str = "constant",
        lfq_entropy_loss_weight=0.1,
        lfq_commitment_loss_weight=1.0,
        lfq_diversity_gamma=2.5,
        quantizer_aux_loss_weight=1.0,
        lfq_activation=nn.Identity(),
        use_fsq=False,
        fsq_levels: Optional[List[int]] = None,
        attn_dim_head=32,
        attn_heads=8,
        attn_dropout=0.0,
        linear_attn_dim_head=8,
        linear_attn_heads=16,
        vgg: Optional[Module] = None,
        vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
        perceptual_loss_weight=1e-1,
        discr_kwargs: Optional[dict] = None,
        multiscale_discrs: Tuple[Module, ...] = tuple(),
        use_gan=True,
        adversarial_loss_weight=1.0,
        grad_penalty_loss_weight=10.0,
        multiscale_adversarial_loss_weight=1.0,
        flash_attn=True,
        separate_first_frame_encoding=False,
    ):

Import

from sat.sgm.modules.autoencoding.magvit2_pytorch import VideoTokenizer

I/O Contract

Inputs

  • video_or_images (torch.Tensor, required): Input video (B, C, T, H, W) or images (B, C, H, W); images are expanded to single-frame video
  • cond (Optional[torch.Tensor], optional): Conditioning tensor of shape (B, dim_cond) for conditioned layers
  • return_loss (bool, optional): If True, returns total generator loss and loss breakdown
  • return_codes (bool, optional): If True, returns discrete token codes (and optionally the reconstructed video)
  • return_recon (bool, optional): If True together with return_codes, also returns the reconstructed video
  • return_discr_loss (bool, optional): If True, returns discriminator loss and breakdown
  • return_recon_loss_only (bool, optional): If True, returns only the reconstruction MSE loss and the reconstructed video
  • apply_gradient_penalty (bool, optional): Whether to apply gradient penalty in the discriminator loss; defaults to True
  • video_contains_first_frame (bool, optional): Whether the input video includes the first frame for temporal padding; defaults to True
  • adversarial_loss_weight (Optional[float], optional): Override for the adversarial loss weight
  • multiscale_adversarial_loss_weight (Optional[float], optional): Override for the multiscale adversarial loss weight

Outputs

  • recon_video (torch.Tensor): Reconstructed video tensor (default return when no loss flags are set)
  • codes (torch.Tensor): Discrete token indices (when return_codes=True)
  • total_loss (torch.Tensor): Scalar total loss (when return_loss=True or return_discr_loss=True)
  • loss_breakdown (LossBreakdown or DiscrLossBreakdown): Named tuple of individual loss components: recon_loss, lfq_aux_loss, perceptual_loss, adversarial_gen_loss, adaptive_adversarial_weight, etc.

Usage Examples

import torch
from sat.sgm.modules.autoencoding.magvit2_pytorch import VideoTokenizer

# Create a video tokenizer
tokenizer = VideoTokenizer(
    image_size=128,
    layers=(
        "residual",
        "compress_space",
        "residual",
        "compress_time",
        "attend_space",
        "residual",
    ),
    channels=3,
    init_dim=64,
    codebook_size=1024,
    num_codebooks=1,
    use_gan=True,
)

# Encode video to tokens
video = torch.randn(2, 3, 17, 128, 128)
codes = tokenizer(video, return_codes=True)

# Decode tokens back to video
recon_video = tokenizer.decode_from_code_indices(codes)

# Training: get generator loss
total_loss, loss_breakdown = tokenizer(video, return_loss=True)

# Training: get discriminator loss
discr_loss, discr_breakdown = tokenizer(video, return_discr_loss=True)
