Implementation:Zai org CogVideo MoVQ Decoder3D
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Autoencoding, 3D_Decoding |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
MOVQDecoder3D is a 3D video decoder for the MoVQ (Modulating Quantized Vectors) VAE. It reconstructs video from quantized latent representations using causal 3D convolutions and spatial normalization conditioned on the quantized codes themselves.
Description
The MOVQDecoder3D class implements a multi-resolution upsampling decoder architecture designed specifically for 3D video reconstruction from quantized latents. Its distinguishing feature is spatial normalization conditioned on quantized codes (similar to SPADE), where the quantized tensor zq is used at every normalization layer to modulate feature maps through learned affine transforms.
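The conditioning idea can be sketched in a few lines of PyTorch. This is a simplified, hypothetical stand-in: the name SpatialNorm3DSketch and its constructor arguments are illustrative and are not the file's API, and the optional add_conv path is omitted. The sketch group-normalizes the features and resizes zq to the feature shape with nearest-neighbor interpolation. When the frame count is odd, it resizes the first frame separately. It then applies a per-position scale and bias predicted from zq by kernel-size-1 convolutions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialNorm3DSketch(nn.Module):
    """Hypothetical, simplified SpatialNorm3D: GroupNorm on the features, then a
    per-position scale and bias predicted from the quantized tensor zq."""

    def __init__(self, f_channels, zq_channels, num_groups=32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, f_channels, eps=1e-6, affine=True)
        # Kernel-size-1 convs: a causal conv of size 1 needs no temporal padding
        self.conv_y = nn.Conv3d(zq_channels, f_channels, kernel_size=1)
        self.conv_b = nn.Conv3d(zq_channels, f_channels, kernel_size=1)

    def forward(self, f, zq):
        # f: (B, C, T, H, W) features; zq: (B, Cz, Tz, Hz, Wz) quantized codes
        t, h, w = f.shape[-3:]
        if t > 1 and t % 2 == 1 and zq.shape[2] > 1:
            # Resize frame 0 on its own so feature frame 0 is conditioned
            # only on code frame 0; resize the remaining frames jointly.
            zq_first = F.interpolate(zq[:, :, :1], size=(1, h, w), mode="nearest")
            zq_rest = F.interpolate(zq[:, :, 1:], size=(t - 1, h, w), mode="nearest")
            zq = torch.cat([zq_first, zq_rest], dim=2)
        else:
            zq = F.interpolate(zq, size=(t, h, w), mode="nearest")
        # SPADE-style modulation: normalized features * scale(zq) + bias(zq)
        return self.norm(f) * self.conv_y(zq) + self.conv_b(zq)


norm = SpatialNorm3DSketch(f_channels=64, zq_channels=8)
out = norm(torch.randn(2, 64, 5, 16, 16), torch.randn(2, 8, 3, 4, 4))
print(out.shape)  # torch.Size([2, 64, 5, 16, 16])
```

Splitting off the first frame keeps the conditioning causal. Changing later code frames cannot alter the modulation applied to frame 0.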
The architecture is built from several key components:
- SpatialNorm3D: Performs adaptive normalization by interpolating the quantized tensor to match feature map dimensions, with special handling for the first frame versus remaining frames to maintain causal temporal processing. Group normalization is applied to the features, then learned scale and bias are computed from zq via 1x1 CausalConv3d layers.
- ResnetBlock3D: Uses CausalConv3d for all convolutions to ensure temporal causality (each output frame depends only on the current and earlier input frames). Features are normalized with SpatialNorm3D conditioned on zq, and an optional timestep embedding can be injected.
- AttnBlock2D: Performs spatial self-attention within each frame. It reshapes the 5D tensor to (B*T, C, H, W), applies scaled dot-product attention over the H*W positions using Q/K/V projections, and then reshapes the result back. Normalization is conditioned on zq.
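The frame-wise reshaping used by AttnBlock2D can be illustrated with a minimal sketch. It simplifies the real block in several ways. FrameAttentionSketch is a hypothetical name. It uses a plain GroupNorm instead of the zq-conditioned SpatialNorm3D, and 1x1 Conv2d layers for the Q, K, V, and output projections. Each frame becomes its own batch element, so attention never mixes information across frames.

```python
import torch
import torch.nn as nn


class FrameAttentionSketch(nn.Module):
    """Hypothetical per-frame spatial self-attention (simplified AttnBlock2D)."""

    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels, eps=1e-6)  # stand-in for SpatialNorm3D
        self.q = nn.Conv2d(channels, channels, kernel_size=1)
        self.k = nn.Conv2d(channels, channels, kernel_size=1)
        self.v = nn.Conv2d(channels, channels, kernel_size=1)
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        b, c, t, h, w = x.shape
        # (B, C, T, H, W) -> (B*T, C, H, W): every frame is treated as an image
        x2d = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        hn = self.norm(x2d)
        q = self.q(hn).flatten(2).transpose(1, 2)  # (B*T, H*W, C)
        k = self.k(hn).flatten(2).transpose(1, 2)
        v = self.v(hn).flatten(2).transpose(1, 2)
        # softmax(Q K^T / sqrt(C)) V over the H*W positions of a single frame
        attn = torch.softmax(q @ k.transpose(1, 2) / c ** 0.5, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b * t, c, h, w)
        out = x2d + self.proj_out(out)  # residual connection
        # (B*T, C, H, W) -> (B, C, T, H, W)
        return out.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4)


block = FrameAttentionSketch(32)
print(block(torch.randn(2, 32, 3, 8, 8)).shape)  # torch.Size([2, 32, 3, 8, 8])
```

Attention is computed per frame, so its cost grows with (H*W)^2 rather than (T*H*W)^2. This keeps the block affordable at the resolutions listed in attn_resolutions.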
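The decoder follows a multi-resolution upsampling path that mirrors the encoder's downsampling stages. Unlike a U-Net, it has no skip connections from the encoder; its conditioning comes from zq. The channel multipliers, number of residual blocks per level, and attention placement are all configurable. Temporal upsampling is controlled by temporal_compress_times: log2(temporal_compress_times) of the upsampling stages use Upsample3D with temporal upsampling enabled, and the remaining stages upsample only spatially.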
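The interaction between ch_mult and temporal_compress_times can be checked with plain arithmetic. The helper below is hypothetical and encodes assumptions rather than the file's code:

- Every level except i_level == 0 upsamples H and W by 2.
- The log2(temporal_compress_times) lowest-resolution levels also upsample time.
- Temporal upsampling is causal. With an odd frame count greater than 1, the first frame is kept and the remaining T - 1 frames are doubled.

```python
import math


def decoder_output_shape(latent_thw, ch_mult=(1, 2, 4, 8), temporal_compress_times=4):
    """Hypothetical helper: predict (T, H, W) after the decoder's upsampling path.

    Returns the output shape and a list of (i_level, upsamples_time) pairs in the
    order the upsamplers run. All rules below are assumptions, not the file's code.
    """
    t, h, w = latent_thw
    num_resolutions = len(ch_mult)
    temporal_levels = int(math.log2(temporal_compress_times))
    plan = []
    # Levels run from lowest resolution (i_level = num_resolutions - 1) down to 0
    for i_level in reversed(range(num_resolutions)):
        if i_level == 0:
            continue  # assumed: the final, full-resolution level has no upsampler
        upsamples_time = i_level >= num_resolutions - temporal_levels
        if upsamples_time and t > 1:
            # Causal rule: keep frame 0 and double the rest (odd T);
            # plain doubling for even T; a single frame stays a single frame
            t = 1 + 2 * (t - 1) if t % 2 == 1 else 2 * t
        h, w = 2 * h, 2 * w
        plan.append((i_level, upsamples_time))
    return (t, h, w), plan


shape, plan = decoder_output_shape((13, 60, 90), ch_mult=(1, 2, 4, 4))
print(shape)  # (49, 480, 720)
print(plan)   # [(3, True), (2, True), (1, False)]
```

Under these assumptions, a 13 x 60 x 90 latent decodes to 1 + 4 * (13 - 1) = 49 frames at 480 x 720. The spatial factor is 2^(len(ch_mult) - 1) = 8 for four levels.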
The file also provides NewDecoder3D, an extended variant that adds an optional post_quant_conv layer (a CausalConv3d applied to the latent before the main decoder) for post-quantization refinement.
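The data flow of the post_quant_conv option can be sketched as a thin wrapper. This is a hypothetical illustration, not NewDecoder3D itself. It assumes the projected latent feeds the decoder body while the unprojected quantized tensor is still used as the zq conditioning signal. It also uses a kernel-size-1 nn.Conv3d as the projection, whereas the file uses a CausalConv3d. PostQuantDecoderSketch and ShapeEchoDecoder are illustrative names.

```python
import torch
import torch.nn as nn


class PostQuantDecoderSketch(nn.Module):
    """Hypothetical wrapper mimicking the assumed NewDecoder3D data flow."""

    def __init__(self, decoder, zq_ch, z_channels, post_quant_conv=True):
        super().__init__()
        # Assumed projection zq_ch -> z_channels (kernel size 1 for simplicity)
        self.post_quant_conv = (
            nn.Conv3d(zq_ch, z_channels, kernel_size=1) if post_quant_conv else None
        )
        self.decoder = decoder

    def forward(self, z):
        zq = z  # the raw quantized latent stays available as the conditioning signal
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)
        return self.decoder(z, zq)


class ShapeEchoDecoder(nn.Module):
    """Toy stand-in for the decoder body: reports the channel counts it receives."""

    def forward(self, z, zq):
        return z.shape[1], zq.shape[1]


wrapped = PostQuantDecoderSketch(ShapeEchoDecoder(), zq_ch=8, z_channels=16)
print(wrapped(torch.randn(1, 8, 3, 4, 4)))  # (16, 8)
```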
Usage
Use MOVQDecoder3D as the decoder component of a 3D VQ-VAE pipeline for video generation. It is designed to work with CausalConv3d-based encoders and vector-quantized latent spaces. Use NewDecoder3D when post-quantization convolution is needed to bridge a mismatch between quantized code channels and the decoder's expected input channels.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py
- Lines: 205-341 (MOVQDecoder3D), 343-505 (NewDecoder3D), 48-84 (SpatialNorm3D), 100-156 (ResnetBlock3D), 159-202 (AttnBlock2D)
Signature
```python
class MOVQDecoder3D(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        zq_ch=None,
        add_conv=False,
        pad_mode="first",
        temporal_compress_times=4,
        **ignorekwargs,
    ):
        ...

    def forward(self, z, use_cp=False):
        ...

    def get_last_layer(self):
        ...


class NewDecoder3D(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        zq_ch=None,
        add_conv=False,
        pad_mode="first",
        temporal_compress_times=4,
        post_quant_conv=False,
        **ignorekwargs,
    ):
        ...

    def forward(self, z):
        ...
```
Import
```python
from sat.sgm.modules.autoencoding.vqvae.movq_dec_3d import MOVQDecoder3D, NewDecoder3D, SpatialNorm3D
```
I/O Contract
MOVQDecoder3D Constructor Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| ch | int | Yes | Base channel count for the decoder |
| out_ch | int | Yes | Number of output channels (e.g., 3 for RGB) |
| ch_mult | tuple | No (default (1,2,4,8)) | Channel multipliers per resolution level |
| num_res_blocks | int | Yes | Number of residual blocks per resolution level |
| attn_resolutions | list | Yes | Resolutions at which to apply spatial attention |
| dropout | float | No (default 0.0) | Dropout rate in residual blocks |
| resamp_with_conv | bool | No (default True) | Whether to use convolution-based upsampling |
| in_channels | int | Yes | Number of input channels to the model |
| resolution | int | Yes | Target spatial resolution of the output |
| z_channels | int | Yes | Number of channels in the latent representation |
| give_pre_end | bool | No (default False) | If True, return features before the final normalization and convolution |
| zq_ch | int | No (default None) | Channel count for the quantized conditioning tensor; defaults to z_channels |
| add_conv | bool | No (default False) | Whether to add an extra convolution in SpatialNorm3D for the zq conditioning |
| pad_mode | str | No (default "first") | Padding mode for CausalConv3d: "first" or "constant" |
| temporal_compress_times | int | No (default 4) | Temporal compression factor; determines how many levels include temporal upsampling (log2 of this value) |
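The effect of pad_mode on a causal convolution can be sketched as follows. CausalConv3dSketch is a hypothetical, simplified module, and the semantics of the two modes are assumptions. "constant" is taken to mean zero-padding the past, and "first" to mean replicating the first frame. In both cases, all kernel_size - 1 temporal padding frames go in front of the sequence and none after it. As a result, output frame t depends only on input frames up to t.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalConv3dSketch(nn.Module):
    """Hypothetical causal 3D convolution: pads only the past along time."""

    def __init__(self, in_ch, out_ch, kernel_size=3, pad_mode="constant"):
        super().__init__()
        if pad_mode not in ("first", "constant"):
            raise ValueError(f"unsupported pad_mode: {pad_mode}")
        self.pad_mode = pad_mode
        self.time_pad = kernel_size - 1  # all temporal padding goes in front
        self.space_pad = kernel_size // 2  # symmetric "same" padding for H and W
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size)  # no built-in padding

    def forward(self, x):
        # x: (B, C, T, H, W)
        p = self.space_pad
        if self.pad_mode == "first":
            # Assumed meaning of "first": repeat frame 0 time_pad times in front
            front = x[:, :, :1].repeat(1, 1, self.time_pad, 1, 1)
            x = torch.cat([front, x], dim=2)
            x = F.pad(x, (p, p, p, p))  # zero-pad W and H only
        else:
            # Assumed meaning of "constant": zero frames in front of the sequence
            x = F.pad(x, (p, p, p, p, self.time_pad, 0))
        return self.conv(x)


conv = CausalConv3dSketch(4, 6, kernel_size=3, pad_mode="first")
print(conv(torch.randn(1, 4, 5, 8, 8)).shape)  # torch.Size([1, 6, 5, 8, 8])
```

Replicating the first frame avoids feeding artificial black frames into the earliest outputs. Zero padding is the simpler default.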
Forward Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| z | torch.Tensor | Yes | Quantized latent tensor of shape (B, C, T, H, W) |
| use_cp | bool | No (default False) | Checkpoint flag (MOVQDecoder3D only) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Reconstructed video tensor of shape (B, out_ch, T_out, H_out, W_out). H_out and W_out are H and W multiplied by 2^(len(ch_mult) - 1), and T_out results from log2(temporal_compress_times) temporal upsampling stages |
| get_last_layer | torch.Tensor | Returns self.conv_out.conv.weight for use in discriminator-based training losses |
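get_last_layer() lets a training loss compare gradients at the decoder's final convolution. The helper below follows the adaptive-weight rule from the taming-transformers VQGAN loss (calculate_adaptive_weight). Whether CogVideo's SAT loss uses exactly this rule is an assumption. The helper rescales the adversarial loss so that its gradient norm at the last layer matches that of the reconstruction loss. The toy parameter at the end stands in for decoder.get_last_layer().

```python
import torch


def calculate_adaptive_weight(nll_loss, g_loss, last_layer, max_weight=1e4):
    """Ratio of reconstruction to adversarial gradient norms at last_layer.

    Follows the VQGAN (taming-transformers) convention; applying it to
    MOVQDecoder3D.get_last_layer() is an assumption about the training loss.
    """
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    # Clamp to avoid exploding weights; detach so it acts as a constant scale
    return torch.clamp(d_weight, 0.0, max_weight).detach()


# Toy check: the adversarial gradient is twice the reconstruction gradient,
# so the weight comes out close to 0.5.
last_layer = torch.nn.Parameter(torch.ones(3))
x = torch.tensor([1.0, 2.0, 3.0])
nll_loss = (last_layer * x).sum()
g_loss = 2.0 * (last_layer * x).sum()
weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer)
```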
Usage Examples
```python
import torch

from sat.sgm.modules.autoencoding.vqvae.movq_dec_3d import MOVQDecoder3D, NewDecoder3D

# Create a 3D MoVQ decoder
decoder = MOVQDecoder3D(
    ch=128,
    out_ch=3,
    ch_mult=(1, 2, 4, 4),
    num_res_blocks=2,
    attn_resolutions=[16],
    dropout=0.0,
    in_channels=256,
    resolution=256,
    z_channels=256,
    zq_ch=256,
    add_conv=False,
    pad_mode="first",
    temporal_compress_times=4,
)

# Decode from a quantized latent (B=2, C=256, T=4, H=32, W=32).
# Four resolution levels give three 2x spatial upsamplings, so 32 -> 256.
z_q = torch.randn(2, 256, 4, 32, 32)
video = decoder(z_q)
# Output: (2, 3, T_out, 256, 256), where T_out depends on T and temporal_compress_times

# NewDecoder3D with post-quantization convolution
decoder_v2 = NewDecoder3D(
    ch=128,
    out_ch=3,
    ch_mult=(1, 2, 4, 4),
    num_res_blocks=2,
    attn_resolutions=[16],
    in_channels=256,
    resolution=256,
    z_channels=256,
    post_quant_conv=True,
)
video = decoder_v2(z_q)
```