Implementation: Zai_org CogVideo MagViT2 Tokenizer
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Autoencoding, Video_Tokenization |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
VideoTokenizer implements a full MagViT2-style video tokenizer with a configurable encoder-decoder architecture using causal 3D convolutions, lookup-free or finite-scalar quantization, GAN discriminator training, and VGG perceptual loss.
Description
The VideoTokenizer class (the main class within the magvit2_pytorch.py module) is a complete video-to-token-to-video system. Its architecture consists of:
Encoder: Begins with a CausalConv3d input convolution that applies temporally causal padding (padding only on the past, not future frames). The encoder then applies a configurable sequence of layers, specified by a layers tuple. Supported layer types include:
- residual / consecutive_residual: Residual units with CausalConv3d, ELU activations, and SqueezeExcite (global context attention).
- compress_space / compress_time: Strided downsampling via SpatialDownsample2x or TimeDownsample2x, each with optional anti-aliased Blur.
- attend_space / attend_time: Full attention (SpaceAttention, TimeAttention) with memory key-values and flash attention support.
- linear_attend_space: Efficient LinearSpaceAttention using Taylor series approximation.
- gateloop_time: Temporal processing via SimpleGateLoopLayer.
- cond_residual / cond_attend_*: Conditioned variants that accept external conditioning via AdaptiveRMSNorm.
A final LayerNorm is applied before quantization.
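The key property of the input convolution is that temporal padding is applied only on the past side, so an output frame never depends on future frames. A minimal sketch of this idea (the class and names below are illustrative, not the repository's actual code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv3dSketch(nn.Module):
    """Illustrative causal 3D conv: temporal padding only on the past side."""

    def __init__(self, chan_in, chan_out, kernel_size=(7, 7, 7)):
        super().__init__()
        t, h, w = kernel_size
        self.time_pad = t - 1                      # all temporal pad goes to the front
        self.spatial_pad = (w // 2, w // 2, h // 2, h // 2)
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size)

    def forward(self, x):  # x: (B, C, T, H, W)
        # F.pad order: (W_left, W_right, H_top, H_bottom, T_front, T_back)
        x = F.pad(x, (*self.spatial_pad, self.time_pad, 0))
        return self.conv(x)

conv = CausalConv3dSketch(3, 8)
out = conv(torch.randn(1, 3, 5, 16, 16))
print(out.shape)  # torch.Size([1, 8, 5, 16, 16])
```

Because of the one-sided temporal padding, perturbing the last input frame leaves all earlier output frames unchanged, which is what makes streaming/autoregressive decoding possible.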
Quantization: Supports two strategies:
- LFQ (Lookup-Free Quantization): Binary thresholding with entropy-based auxiliary loss.
- FSQ (Finite Scalar Quantization): Quantization to predefined discrete levels.
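Both quantizers avoid a learned nearest-neighbor codebook. A rough sketch of the two core operations, under simplifying assumptions (function names are hypothetical; the real implementations also handle entropy/commitment losses and multiple codebooks):

```python
import torch

def lfq_quantize(z):
    """LFQ sketch: each latent dim becomes one bit; the code index is the
    binary number formed by the sign pattern (implicit codebook of size 2^d)."""
    bits = (z > 0).int()
    quantized = torch.where(z > 0, 1.0, -1.0)      # codebook vectors in {-1, +1}^d
    quantized = z + (quantized - z).detach()        # straight-through estimator
    powers = 2 ** torch.arange(z.shape[-1])
    indices = (bits * powers).sum(dim=-1)
    return quantized, indices

def fsq_quantize(z, levels=(8, 5, 5, 5)):
    """FSQ sketch: bound each dim, then round it to a fixed grid of levels."""
    levels = torch.tensor(levels, dtype=torch.float32)
    half = (levels - 1) / 2
    z = torch.tanh(z) * half                        # bound each dim to its range
    quantized = z + (z.round() - z).detach()        # straight-through rounding
    return quantized / half                          # renormalize to [-1, 1]

z = torch.randn(2, 4)
q, idx = lfq_quantize(z)
```

With 4 latent dims, LFQ yields an implicit codebook of 2^4 = 16 entries; FSQ with levels (8, 5, 5, 5) yields 8 * 5 * 5 * 5 = 1000 entries, no codebook tensor needed in either case.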
Decoder: Mirrors the encoder in reverse order, using SpatialUpsample2x and TimeUpsample2x for upsampling, with corresponding residual and attention layers.
Training: The forward method supports multiple return modes: token codes, reconstructed video, generator loss (with VGG perceptual loss and adversarial loss from a Discriminator with anti-aliased downsampling), and discriminator loss with gradient penalty. Adaptive weighting balances perceptual and adversarial gradient norms. Supports optional separate first-frame encoding via dedicated 2D convolutions.
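The adaptive weighting follows the VQGAN-style trick of balancing the perceptual and adversarial losses by the ratio of their gradient norms at the decoder's last layer. A hedged sketch (the function name and clamping constants are assumptions, not the repository's exact values):

```python
import torch

def adaptive_adversarial_weight(perceptual_loss, adversarial_loss, last_layer_weight):
    """Scale the adversarial loss so its gradient magnitude at the decoder's
    last layer matches that of the perceptual loss (VQGAN-style balancing)."""
    p_grad = torch.autograd.grad(perceptual_loss, last_layer_weight, retain_graph=True)[0]
    a_grad = torch.autograd.grad(adversarial_loss, last_layer_weight, retain_graph=True)[0]
    weight = p_grad.norm() / a_grad.norm().clamp(min=1e-8)
    return weight.clamp(max=1e4).detach()

# Toy usage: both losses flow through the same "last layer"
layer = torch.nn.Linear(4, 4)
out = layer(torch.randn(2, 4))
p_loss = out.pow(2).mean()
a_loss = out.abs().mean()
w = adaptive_adversarial_weight(p_loss, a_loss, layer.weight)
```

The `.detach()` matters: the weight is treated as a constant during backpropagation so the balancing itself contributes no gradient.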
Other Key Components:
- Conv3DMod: StyleGAN2-style modulated 3D convolution for conditioning on latent vectors.
- CausalConvTranspose3d: Temporally causal transposed convolution for upsampling.
- RMSNorm / AdaptiveRMSNorm: Root-mean-square normalization with optional conditioning.
- Blur: Anti-aliased 3D blurpool filtering using kornia's filter3d.
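The purpose of the blur is to low-pass filter activations before strided downsampling so that aliasing artifacts are suppressed. A minimal sketch using a separable [1, 2, 1] binomial kernel and a plain depthwise conv (the real module uses kornia's filter3d; `blur3d` here is a hypothetical stand-in):

```python
import torch
import torch.nn.functional as F

def blur3d(x):
    """Depthwise 3D blur with a normalized [1, 2, 1]^3 binomial kernel,
    typically applied just before a strided downsample to reduce aliasing."""
    f = torch.tensor([1.0, 2.0, 1.0])
    k = torch.einsum('i,j,k->ijk', f, f, f)
    k = (k / k.sum())[None, None]                   # (1, 1, 3, 3, 3)
    c = x.shape[1]
    k = k.expand(c, 1, 3, 3, 3).contiguous()        # one kernel per channel
    x = F.pad(x, (1,) * 6, mode='replicate')        # keep spatial/temporal size
    return F.conv3d(x, k, groups=c)

y = blur3d(torch.randn(1, 2, 4, 8, 8))
print(y.shape)  # torch.Size([1, 2, 4, 8, 8])
```

Since the kernel sums to 1 and the padding replicates border values, a constant input passes through unchanged, which is a quick sanity check for any blurpool filter.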
Usage
Use this tokenizer as the discrete bottleneck in a latent video generation pipeline. It compresses video into a compact token sequence suitable for autoregressive or diffusion-based generation models. Supports both image pretraining (4D input) and full video tokenization (5D input) through curriculum learning.
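The 4D/5D input handling amounts to promoting image batches to single-frame videos. A tiny sketch of that convention (`as_video` is a hypothetical helper mirroring the tokenizer's input handling, not part of the repository):

```python
import torch

def as_video(x):
    """Accept (B, C, H, W) image batches or (B, C, T, H, W) videos;
    images become single-frame videos for curriculum image pretraining."""
    if x.ndim == 4:
        x = x.unsqueeze(2)   # insert a time axis of length 1
    assert x.ndim == 5, 'expected (B, C, T, H, W)'
    return x

images = torch.randn(2, 3, 64, 64)
video = as_video(images)
print(video.shape)  # torch.Size([2, 3, 1, 64, 64])
```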
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: sat/sgm/modules/autoencoding/magvit2_pytorch.py
- Lines: 1007-1878 (VideoTokenizer), 1883-1888 (MagViT2 placeholder)
Signature
class VideoTokenizer(Module):
def __init__(
self,
*,
image_size,
layers: Tuple[Union[str, Tuple[str, int]], ...] = (
"residual", "residual", "residual"
),
residual_conv_kernel_size=3,
num_codebooks=1,
codebook_size: Optional[int] = None,
channels=3,
init_dim=64,
max_dim=float("inf"),
dim_cond=None,
dim_cond_expansion_factor=4.0,
input_conv_kernel_size: Tuple[int, int, int] = (7, 7, 7),
output_conv_kernel_size: Tuple[int, int, int] = (3, 3, 3),
pad_mode: str = "constant",
lfq_entropy_loss_weight=0.1,
lfq_commitment_loss_weight=1.0,
lfq_diversity_gamma=2.5,
quantizer_aux_loss_weight=1.0,
lfq_activation=nn.Identity(),
use_fsq=False,
fsq_levels: Optional[List[int]] = None,
attn_dim_head=32,
attn_heads=8,
attn_dropout=0.0,
linear_attn_dim_head=8,
linear_attn_heads=16,
vgg: Optional[Module] = None,
vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
perceptual_loss_weight=1e-1,
discr_kwargs: Optional[dict] = None,
multiscale_discrs: Tuple[Module, ...] = tuple(),
use_gan=True,
adversarial_loss_weight=1.0,
grad_penalty_loss_weight=10.0,
multiscale_adversarial_loss_weight=1.0,
flash_attn=True,
separate_first_frame_encoding=False,
):
Import
from sat.sgm.modules.autoencoding.magvit2_pytorch import VideoTokenizer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| video_or_images | torch.Tensor | Yes | Input video (B, C, T, H, W) or images (B, C, H, W); images are expanded to single-frame video |
| cond | Optional[torch.Tensor] | No | Conditioning tensor of shape (B, dim_cond) for conditioned layers |
| return_loss | bool | No | If True, returns total generator loss and loss breakdown |
| return_codes | bool | No | If True, returns discrete token codes (and optionally reconstructed video) |
| return_recon | bool | No | If True with return_codes, also returns reconstructed video |
| return_discr_loss | bool | No | If True, returns discriminator loss and breakdown |
| return_recon_loss_only | bool | No | If True, returns only reconstruction MSE loss and reconstructed video |
| apply_gradient_penalty | bool | No | Whether to apply gradient penalty in discriminator loss; defaults to True |
| video_contains_first_frame | bool | No | Whether the input video includes the first frame for temporal padding; defaults to True |
| adversarial_loss_weight | Optional[float] | No | Override for the adversarial loss weight |
| multiscale_adversarial_loss_weight | Optional[float] | No | Override for multiscale adversarial loss weight |
Outputs
| Name | Type | Description |
|---|---|---|
| recon_video | torch.Tensor | Reconstructed video tensor (default return when no loss flags set) |
| codes | torch.Tensor | Discrete token indices (when return_codes=True) |
| total_loss | torch.Tensor | Scalar total loss (when return_loss=True or return_discr_loss=True) |
| loss_breakdown | LossBreakdown or DiscrLossBreakdown | Named tuple with individual loss components: recon_loss, lfq_aux_loss, perceptual_loss, adversarial_gen_loss, adaptive_adversarial_weight, etc. |
Usage Examples
import torch
from sat.sgm.modules.autoencoding.magvit2_pytorch import VideoTokenizer

# Create a video tokenizer
tokenizer = VideoTokenizer(
image_size=128,
layers=(
"residual",
"compress_space",
"residual",
"compress_time",
"attend_space",
"residual",
),
channels=3,
init_dim=64,
codebook_size=1024,
num_codebooks=1,
use_gan=True,
)
# Encode video to tokens
video = torch.randn(2, 3, 17, 128, 128)
codes = tokenizer(video, return_codes=True)
# Decode tokens back to video
recon_video = tokenizer.decode_from_code_indices(codes)
# Training: get generator loss
total_loss, loss_breakdown = tokenizer(video, return_loss=True)
# Training: get discriminator loss
discr_loss, discr_breakdown = tokenizer(video, return_discr_loss=True)