

Implementation:Zai org CogVideo MoVQ Decoder3D

From Leeroopedia


Knowledge Sources
Domains Video_Generation, Autoencoding, 3D_Decoding
Last Updated 2026-02-10 00:00 GMT

Overview

MOVQDecoder3D is a 3D video decoder for the MoVQ (Modulating Quantized Vectors) VAE. It reconstructs video from quantized latent representations using causal 3D convolutions and spatial normalization conditioned on the quantized codes themselves.

Description

The MOVQDecoder3D class implements a multi-resolution upsampling decoder that reconstructs 3D video from quantized latents. Its distinguishing feature is spatial normalization conditioned on the quantized codes (similar to SPADE): the quantized tensor zq is passed to every normalization layer, where it modulates the feature maps through a learned, spatially varying scale and bias. In MOVQDecoder3D.forward, the decoder input z itself serves as zq.
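Written as a formula (our notation, not the source's), each zq-conditioned normalization layer maps a feature map f of shape (B, C, T, H, W) to:

```latex
\hat{f} = \mathrm{GroupNorm}(f) \odot \gamma(\tilde{z}_q) + \beta(\tilde{z}_q),
\qquad \tilde{z}_q = \mathrm{resize}_{(T,H,W)}(z_q)
```

Here resize interpolates zq to the temporal and spatial size of f, ⊙ is element-wise multiplication, and γ and β are learned convolutions that output C channels. Unlike the per-channel affine parameters of plain GroupNorm, γ and β vary over (T, H, W), which lets the quantized codes inject position-dependent information at every layer.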

The architecture is built from several key components:

  • SpatialNorm3D: Resizes zq to the (T, H, W) size of the feature map by nearest-neighbor interpolation, resizing the first frame separately from the remaining frames so that temporal causality is preserved. It applies group normalization to the features, then multiplies them by a learned scale and adds a learned bias, both computed from zq by 1x1x1 CausalConv3d layers.
  • ResnetBlock3D: Uses CausalConv3d for its main convolutions, so each output frame depends only on the current and previous input frames. Features are normalized with SpatialNorm3D conditioned on zq. The block can also inject timestep embeddings, but the decoder passes none.
  • AttnBlock2D: Performs spatial self-attention within each frame: it reshapes the 5D tensor to (B*T, C, H, W), computes Q/K/V projections and scaled dot-product attention over the H*W positions, then reshapes the result back to (B, C, T, H, W). Its normalization is also conditioned on zq.
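The two reshaping steps used by these blocks, the first-frame-aware resizing of zq in SpatialNorm3D and the per-frame attention in AttnBlock2D, can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the source code: resize_zq_causal, SpatialNormSketch and per_frame_attention are hypothetical names, plain nn.Conv3d stands in for CausalConv3d (a 1x1x1 kernel has no temporal extent, so it is causal by construction), 32 GroupNorm groups is an assumed setting, and the attention sketch omits the zq-conditioned normalization and any output projection or residual connection.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def resize_zq_causal(zq: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
    """Nearest-resize zq (B, Cq, Tq, Hq, Wq) to the (T, H, W) of f, keeping frame 0 separate."""
    if zq.shape[2] > 1:
        # Latent frame 0 maps only onto feature frame 0; the remaining frames share the rest.
        first = F.interpolate(zq[:, :, :1], size=f[:, :, :1].shape[-3:], mode="nearest")
        rest = F.interpolate(zq[:, :, 1:], size=f[:, :, 1:].shape[-3:], mode="nearest")
        return torch.cat([first, rest], dim=2)
    return F.interpolate(zq, size=f.shape[-3:], mode="nearest")


class SpatialNormSketch(nn.Module):
    """Hypothetical SpatialNorm3D stand-in: GroupNorm(f) * scale(zq) + bias(zq)."""

    def __init__(self, f_channels: int, zq_channels: int, num_groups: int = 32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, f_channels, eps=1e-6)
        # A 1x1x1 kernel has no temporal extent, so a plain Conv3d is already causal.
        self.to_scale = nn.Conv3d(zq_channels, f_channels, kernel_size=1)
        self.to_bias = nn.Conv3d(zq_channels, f_channels, kernel_size=1)

    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
        zq = resize_zq_causal(zq, f)
        return self.norm(f) * self.to_scale(zq) + self.to_bias(zq)


def per_frame_attention(x: torch.Tensor, q_proj: nn.Module, k_proj: nn.Module,
                        v_proj: nn.Module) -> torch.Tensor:
    """Spatial self-attention applied to each frame of x (B, C, T, H, W) independently."""
    b, c, t, h, w = x.shape
    frames = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)  # (B*T, C, H, W)
    q = q_proj(frames).flatten(2).transpose(1, 2)  # (B*T, HW, C)
    k = k_proj(frames).flatten(2)  # (B*T, C, HW)
    v = v_proj(frames).flatten(2)  # (B*T, C, HW)
    attn = torch.softmax(torch.bmm(q, k) * c ** -0.5, dim=-1)  # (B*T, HW_query, HW_key)
    out = torch.bmm(v, attn.transpose(1, 2)).reshape(b * t, c, h, w)
    # Undo the frame folding: (B*T, C, H, W) -> (B, C, T, H, W)
    return out.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4)
```

Folding time into the batch axis is what keeps the attention strictly per frame: no query in one frame can see keys from another, so the block adds spatial context without breaking the causal temporal structure built by the convolutions.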

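The decoder follows the hierarchical upsampling path of VQGAN-style decoders, with configurable channel multipliers, number of resolution levels, and attention placement. Despite the level-by-level structure it has no encoder skip connections, so it is not a U-Net. Temporal upsampling is controlled by temporal_compress_times: log2(temporal_compress_times) of the Upsample3D stages, namely the first ones applied (at the lowest spatial resolutions), upsample the time axis in addition to height and width.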
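Assuming the level-indexing convention of VQGAN-style decoders (levels run from the coarsest, index len(ch_mult) - 1, down to 0, and level 0 has no upsampler), the upsampling schedule and the resulting output shape can be worked out with a small helper. The function names are hypothetical, and the frame arithmetic is an assumption to verify against Upsample3D: a temporal step is taken to double an even frame count, but to keep the first frame separate for an odd count, giving 1 + 2(T - 1) frames.

```python
import math
from typing import List, Sequence, Tuple


def upsample_schedule(ch_mult: Sequence[int], temporal_compress_times: int) -> List[Tuple[int, bool]]:
    """(level, also_upsamples_time) for each Upsample3D stage, in execution order."""
    num_levels = len(ch_mult)
    temporal_levels = int(math.log2(temporal_compress_times))
    schedule = []
    # Levels run from the coarsest (num_levels - 1) to the finest (0); level 0 has no upsampler.
    for i_level in reversed(range(num_levels)):
        if i_level != 0:
            schedule.append((i_level, i_level >= num_levels - temporal_levels))
    return schedule


def decoded_shape(z_shape: Tuple[int, int, int, int, int], out_ch: int,
                  ch_mult: Sequence[int], temporal_compress_times: int) -> Tuple[int, int, int, int, int]:
    """Predicted (B, out_ch, T_out, H_out, W_out) for a latent of shape (B, C, T, H, W)."""
    b, _, t, h, w = z_shape
    for _, upsamples_time in upsample_schedule(ch_mult, temporal_compress_times):
        h, w = 2 * h, 2 * w  # every stage doubles H and W
        if upsamples_time and t > 1:
            # Assumed rule: odd T keeps frame 0 separate (1 + 2 * (T - 1)); even T doubles.
            t = 1 + 2 * (t - 1) if t % 2 == 1 else 2 * t
    return (b, out_ch, t, h, w)
```

For example, ch_mult=(1, 2, 4, 4) with temporal_compress_times=4 gives the schedule [(3, True), (2, True), (1, False)]: the two coarsest stages upsample time and space, the last one only space. Under the same assumptions, a (1, 16, 13, 60, 90) latent decodes to 49 frames at 480x720.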

The file also provides NewDecoder3D, an extended variant that adds an optional post_quant_conv layer (a CausalConv3d applied to the latent before the main decoder) for post-quantization refinement.
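A hypothetical sketch of how such a front end can bridge channel counts: the latent keeps its quantized channel count (zq_ch) for conditioning, while an optional post-quantization convolution maps it to z_channels before conv_in. The class name, the 1x1x1 kernels, the zq_ch to z_channels mapping, and the choice to condition on the un-convolved latent are all assumptions for illustration, not confirmed details of NewDecoder3D.

```python
import torch
import torch.nn as nn


class PostQuantFrontEnd(nn.Module):
    """Hypothetical decoder front end with an optional post-quantization conv."""

    def __init__(self, zq_ch: int, z_channels: int, block_in: int, post_quant_conv: bool):
        super().__init__()
        # Assumed channel mapping zq_ch -> z_channels; 1x1x1 Conv3d stands in for CausalConv3d.
        self.post_quant_conv = nn.Conv3d(zq_ch, z_channels, kernel_size=1) if post_quant_conv else None
        self.conv_in = nn.Conv3d(z_channels, block_in, kernel_size=1)

    def forward(self, z: torch.Tensor):
        zq = z  # conditioning tensor for the normalization layers keeps zq_ch channels
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)  # bridge zq_ch -> z_channels
        return self.conv_in(z), zq
```

Keeping zq separate from the convolved latent means the normalization layers can be built for zq_ch input channels regardless of what the post-quantization convolution outputs.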

Usage

Use MOVQDecoder3D as the decoder component of a 3D VQ-VAE pipeline for video generation. It is designed to work with CausalConv3d-based encoders and vector-quantized latent spaces. Use NewDecoder3D when post-quantization convolution is needed to bridge a mismatch between quantized code channels and the decoder's expected input channels.

Code Reference

Source Location

  • Repository: Zai_org_CogVideo
  • File: sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py
  • Lines: 48-84 (SpatialNorm3D), 100-156 (ResnetBlock3D), 159-202 (AttnBlock2D), 205-341 (MOVQDecoder3D), 343-505 (NewDecoder3D)

Signature

class MOVQDecoder3D(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        zq_ch=None,
        add_conv=False,
        pad_mode="first",
        temporal_compress_times=4,
        **ignorekwargs,
    ):
        ...

    def forward(self, z, use_cp=False):
        ...

    def get_last_layer(self):
        ...


class NewDecoder3D(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        zq_ch=None,
        add_conv=False,
        pad_mode="first",
        temporal_compress_times=4,
        post_quant_conv=False,
        **ignorekwargs,
    ):
        ...

    def forward(self, z):
        ...

Import

from sat.sgm.modules.autoencoding.vqvae.movq_dec_3d import MOVQDecoder3D, NewDecoder3D, SpatialNorm3D

I/O Contract

MOVQDecoder3D Constructor Inputs

Name | Type | Required | Description
ch | int | Yes | Base channel count for the decoder
out_ch | int | Yes | Number of output channels (e.g., 3 for RGB)
ch_mult | tuple | No (default (1, 2, 4, 8)) | Channel multipliers per resolution level; its length sets the number of levels, and the output is upsampled spatially by 2^(len(ch_mult) - 1)
num_res_blocks | int | Yes | Number of residual blocks per resolution level
attn_resolutions | list | Yes | Feature resolutions at which to apply spatial attention
dropout | float | No (default 0.0) | Dropout rate in residual blocks
resamp_with_conv | bool | No (default True) | Whether each upsampling step is followed by a convolution
in_channels | int | Yes | Number of input channels to the model
resolution | int | Yes | Nominal output resolution, used to derive the feature resolution at each level (matched against attn_resolutions); the actual output size follows from the latent size
z_channels | int | Yes | Number of channels in the latent representation
give_pre_end | bool | No (default False) | If True, return features before the final normalization and convolution
zq_ch | int | No (default None) | Channel count for the quantized conditioning tensor; defaults to z_channels
add_conv | bool | No (default False) | Whether SpatialNorm3D applies an extra convolution to zq before computing the scale and bias
pad_mode | str | No (default "first") | Temporal padding mode for CausalConv3d: "first" (repeat the first frame) or "constant" (zeros)
temporal_compress_times | int | No (default 4) | Temporal compression factor; log2 of this value is the number of upsampling stages that also upsample in time
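The two pad_mode options differ only in what is placed before the first frame. Below is a minimal sketch of the intended semantics, assuming a cubic kernel with stride and dilation 1; the class name is hypothetical and this is not the source implementation. The temporal axis is padded with kernel_size - 1 frames on the past side only, and H and W are zero-padded symmetrically.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalConv3dSketch(nn.Module):
    """Causal Conv3d: pad time on the past side only, pad H and W symmetrically."""

    def __init__(self, c_in: int, c_out: int, kernel_size: int = 3, pad_mode: str = "constant"):
        super().__init__()
        assert pad_mode in ("constant", "first")
        self.pad_mode = pad_mode
        self.time_pad = kernel_size - 1
        self.space_pad = kernel_size // 2
        self.conv = nn.Conv3d(c_in, c_out, kernel_size)  # no built-in padding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s = self.space_pad
        if self.pad_mode == "constant":
            # F.pad pairs run from the last dim backwards: (W_l, W_r, H_t, H_b, T_before, T_after)
            x = F.pad(x, (s, s, s, s, self.time_pad, 0), value=0.0)
        else:
            # "first": repeat frame 0 as the temporal padding, then zero-pad H and W only
            if self.time_pad > 0:
                x = torch.cat([x[:, :, :1]] * self.time_pad + [x], dim=2)
            x = F.pad(x, (s, s, s, s), value=0.0)
        return self.conv(x)
```

Either way, output frame t sees only input frames t - 2 through t for a kernel size of 3. With "first", the earliest outputs see copies of the opening frame instead of zeros, which keeps the padding closer to real content.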

Forward Inputs

Name | Type | Required | Description
z | torch.Tensor | Yes | Quantized latent tensor of shape (B, C, T, H, W)
use_cp | bool | No (default False) | Checkpoint flag (MOVQDecoder3D only)

Outputs

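Name | Type | Description
output | torch.Tensor | Reconstructed video tensor of shape (B, out_ch, T_out, H_out, W_out); H_out and W_out are H and W times 2^(len(ch_mult) - 1), and T is upsampled at log2(temporal_compress_times) of those stages
get_last_layer() | torch.Tensor | Return value of get_last_layer(): self.conv_out.conv.weight, the final convolution weight, for use in discriminator-based training losses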
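get_last_layer() lets a training loss measure gradients at the decoder's final convolution. One common consumer is the adaptive discriminator weight from the taming-transformers VQGAN loss, sketched below; whether CogVideo's own loss computes exactly this is an assumption, and the function name and default constants are illustrative. The weight rescales the adversarial term so that its gradient at the last layer has a magnitude comparable to the reconstruction gradient.

```python
import torch


def adaptive_disc_weight(rec_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor,
                         max_weight: float = 1e4, eps: float = 1e-4) -> torch.Tensor:
    """VQGAN-style balance: ||d rec/dW|| / (||d adv/dW|| + eps), W = last layer weight."""
    rec_grads = torch.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + eps)
    # Clamp for stability and detach so the weight itself is not trained.
    return torch.clamp(d_weight, 0.0, max_weight).detach()
```

In a training step this would be called as adaptive_disc_weight(rec_loss, g_loss, decoder.get_last_layer()), with both losses computed from the same decoder output.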

Usage Examples

import torch
from sat.sgm.modules.autoencoding.vqvae.movq_dec_3d import MOVQDecoder3D

# Create a 3D MoVQ decoder
decoder = MOVQDecoder3D(
    ch=128,
    out_ch=3,
    ch_mult=(1, 2, 4, 4),
    num_res_blocks=2,
    attn_resolutions=[16],  # per-level resolutions here are 32, 64, 128, 256, so 16 adds no attention
    dropout=0.0,
    in_channels=256,
    resolution=256,
    z_channels=256,
    zq_ch=256,
    add_conv=False,
    pad_mode="first",
    temporal_compress_times=4,
)

# Decode from a quantized latent (B=2, C=256, T=4, H=16, W=16)
z_q = torch.randn(2, 256, 4, 16, 16)
video = decoder(z_q)
# Output: (2, 3, T_out, 128, 128); len(ch_mult) - 1 = 3 upsampling stages give 8x spatial upsampling,
# and T_out depends on T and temporal_compress_times

# NewDecoder3D with post-quantization convolution
from sat.sgm.modules.autoencoding.vqvae.movq_dec_3d import NewDecoder3D

decoder_v2 = NewDecoder3D(
    ch=128,
    out_ch=3,
    ch_mult=(1, 2, 4, 4),
    num_res_blocks=2,
    attn_resolutions=[16],
    in_channels=256,
    resolution=256,
    z_channels=256,
    post_quant_conv=True,
)
video = decoder_v2(z_q)
