Implementation:Zai org CogVideo MoVQ 2D Modules
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Autoencoding |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
A collection of 2D decoder modules for the MoVQ (Modulating Quantized Vectors) VAE that reconstruct images from quantized latent representations using spatially adaptive normalization conditioned on the quantized codes.
Description
This module provides the MOVQDecoder class and its supporting building blocks for image reconstruction in CogVideo's VQ-VAE pipeline. The central architectural innovation is SpatialNorm, a normalization layer that applies standard GroupNorm and then modulates the normalized features using learned affine transforms derived from the quantized code tensor zq. This is conceptually similar to SPADE-style normalization, where external spatial information guides the normalization process.
SpatialNorm interpolates the quantized tensor zq to match the spatial dimensions of the feature map f, optionally applies an additional convolution to it, and then passes it through two convolutions, conv_y and conv_b, which produce a scale map and a bias map with the same channel count as f. The normalized features are then modulated elementwise: output = GroupNorm(f) * conv_y(zq) + conv_b(zq).
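The formula above can be sketched as a small PyTorch module. This is an illustrative re-implementation rather than the repository's class: the name SpatialNormSketch, the GroupNorm hyperparameters (32 groups, eps=1e-6), the nearest-neighbour interpolation mode, and the kernel sizes are all assumptions chosen to keep the example short.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialNormSketch(nn.Module):
    """Hypothetical sketch of GroupNorm(f) * conv_y(zq) + conv_b(zq)."""

    def __init__(self, f_channels, zq_channels, num_groups=32, add_conv=False):
        super().__init__()
        # Assumed GroupNorm settings; f_channels must be divisible by num_groups.
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=f_channels, eps=1e-6)
        # Optional extra convolution on the resized zq (assumed 3x3, shape-preserving).
        self.conv = (
            nn.Conv2d(zq_channels, zq_channels, kernel_size=3, padding=1)
            if add_conv
            else None
        )
        # Assumed 1x1 convolutions mapping zq to a scale map and a bias map.
        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1)

    def forward(self, f, zq):
        # Resize zq to the spatial size of the feature map.
        zq = F.interpolate(zq, size=f.shape[-2:], mode="nearest")
        if self.conv is not None:
            zq = self.conv(zq)
        # Normalize the features, then modulate them elementwise.
        return self.norm(f) * self.conv_y(zq) + self.conv_b(zq)


f = torch.randn(2, 64, 16, 16)   # feature map
zq = torch.randn(2, 4, 8, 8)     # quantized codes at a lower resolution
out = SpatialNormSketch(f_channels=64, zq_channels=4)(f, zq)
```

The output keeps the shape of f, so a layer like this can stand in wherever a plain GroupNorm would be used, provided zq is passed along with the features.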
ResnetBlock uses SpatialNorm for both normalization layers, taking zq as an additional conditioning input alongside the standard feature tensor and optional timestep embedding. AttnBlock similarly uses SpatialNorm before computing single-head spatial self-attention with scaled dot-product attention.
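The attention step can be illustrated with a minimal sketch of single-head scaled dot-product attention over spatial positions. The function name is hypothetical, and it assumes that q, k, and v have already been computed from the normalized features; how AttnBlock projects them is not shown here.

```python
import torch


def single_head_spatial_attention(q, k, v):
    """Attend over all H*W positions; q, k, v are (B, C, H, W) tensors."""
    b, c, h, w = q.shape
    q = q.reshape(b, c, h * w).permute(0, 2, 1)   # (B, HW, C)
    k = k.reshape(b, c, h * w)                    # (B, C, HW)
    v = v.reshape(b, c, h * w).permute(0, 2, 1)   # (B, HW, C)
    # Scaled dot-product scores, softmax over the key positions.
    weights = torch.softmax(torch.bmm(q, k) * c ** -0.5, dim=-1)  # (B, HW, HW)
    out = torch.bmm(weights, v)                   # (B, HW, C)
    return out.permute(0, 2, 1).reshape(b, c, h, w)


x = torch.randn(1, 8, 4, 4)
attended = single_head_spatial_attention(x, x, x)
```

Because every position attends to every other position, memory grows with (H*W)^2, which is one reason attention is typically enabled only at selected resolutions.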
MOVQDecoder constructs a multi-resolution upsampling decoder starting from the lowest resolution. It includes a middle section with two ResnetBlocks and one AttnBlock, followed by progressive upsampling stages. Each stage contains configurable numbers of ResnetBlocks and optional attention blocks at specified resolutions. The decoder also provides forward_with_features_output for extracting intermediate feature maps at every layer, useful for debugging or perceptual loss computation.
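To make the multi-resolution layout concrete, the sketch below computes the resolution and channel width at each decoder level. It assumes the common VQGAN-style convention in which each level except the last doubles the spatial size and level i uses ch * ch_mult[i] channels; the helper decoder_schedule is hypothetical and not part of the module.

```python
def decoder_schedule(ch, ch_mult, resolution):
    """Return the latent resolution and (level, resolution, channels) per stage."""
    num_levels = len(ch_mult)
    # Assumed: num_levels - 1 upsampling steps, each doubling the spatial size.
    latent_res = resolution // 2 ** (num_levels - 1)
    stages = []
    curr_res = latent_res
    # The decoder walks from the lowest resolution (last level) up to level 0.
    for level in reversed(range(num_levels)):
        stages.append((level, curr_res, ch * ch_mult[level]))
        if level != 0:
            curr_res *= 2
    return latent_res, stages


latent_res, stages = decoder_schedule(ch=128, ch_mult=(1, 2, 4, 4), resolution=256)
for level, res, channels in stages:
    print(f"level {level}: {res}x{res}, {channels} channels")
```

Under these assumptions, a 256-pixel target with four levels starts from a 32x32 latent, so attn_resolutions=[32] would place attention at the lowest-resolution stage.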
Usage
Use this decoder for reconstructing images from quantized latent codes in the MoVQ VQ-VAE pipeline. The spatially-adaptive normalization enables the decoder to leverage discrete codebook information throughout the reconstruction process.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: sat/sgm/modules/autoencoding/vqvae/movq_modules.py
- Lines: 1-382
Signature
    class MOVQDecoder(nn.Module):
        def __init__(
            self,
            *,
            ch,
            out_ch,
            ch_mult=(1, 2, 4, 8),
            num_res_blocks,
            attn_resolutions,
            dropout=0.0,
            resamp_with_conv=True,
            in_channels,
            resolution,
            z_channels,
            give_pre_end=False,
            zq_ch=None,
            add_conv=False,
            **ignorekwargs,
        ):

    class SpatialNorm(nn.Module):
        def __init__(
            self,
            f_channels,
            zq_channels,
            norm_layer=nn.GroupNorm,
            freeze_norm_layer=False,
            add_conv=False,
            **norm_layer_params,
        ):
Import
from sat.sgm.modules.autoencoding.vqvae.movq_modules import MOVQDecoder, SpatialNorm
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| ch | int | Yes | Base channel count for the decoder network |
| out_ch | int | Yes | Number of output channels (e.g. 3 for RGB) |
| ch_mult | tuple of int | No | Channel multiplier at each resolution level, default (1, 2, 4, 8) |
| num_res_blocks | int | Yes | Number of residual blocks per resolution level |
| attn_resolutions | list of int | Yes | Resolutions at which to apply spatial self-attention |
| dropout | float | No | Dropout probability, default 0.0 |
| resamp_with_conv | bool | No | Whether to use learned convolutions for upsampling, default True |
| in_channels | int | Yes | Number of latent input channels |
| resolution | int | Yes | Target spatial resolution of the output |
| z_channels | int | Yes | Number of latent channels at the input |
| give_pre_end | bool | No | If True, return features before final norm and convolution, default False |
| zq_ch | int | No | Number of channels in the quantized code tensor used for SpatialNorm conditioning; the signature defaults to None, so pass it explicitly to match zq |
| add_conv | bool | No | Whether to apply additional convolution on zq in SpatialNorm, default False |
Outputs
| Name | Type | Description |
|---|---|---|
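| h | torch.Tensor | Reconstructed image tensor of shape (B, out_ch, resolution, resolution) when the latent is at the decoder's lowest resolution (32x32 for resolution=256 with four levels, as in the usage example); the pre-output features if give_pre_end is True |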
| output_features | dict (optional) | Dictionary of intermediate feature maps keyed by layer name, returned by forward_with_features_output |
Usage Examples
    import torch
    from sat.sgm.modules.autoencoding.vqvae.movq_modules import MOVQDecoder

    decoder = MOVQDecoder(
        ch=128,
        out_ch=3,
        ch_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        attn_resolutions=[32],
        in_channels=256,
        resolution=256,
        z_channels=4,
        zq_ch=4,
        add_conv=False,
    )

    # Decode from latent
    z = torch.randn(1, 4, 32, 32)
    zq = torch.randn(1, 4, 32, 32)  # quantized code tensor
    output = decoder(z, zq)
    # output shape: (1, 3, 256, 256)

    # Get intermediate features
    output, features = decoder.forward_with_features_output(z, zq)
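For debugging, it can help to reduce the returned feature dictionary to a table of shapes. The helper below is a hypothetical sketch that works on any dict of tensors; the keys and shapes in the demo dictionary are invented for illustration and are not the names the decoder actually uses.

```python
import torch


def summarize_features(features):
    """Map each feature name to its tensor shape as a plain tuple."""
    return {name: tuple(tensor.shape) for name, tensor in features.items()}


# Hypothetical keys and shapes, standing in for the decoder's real output.
demo_features = {
    "mid_block_1": torch.zeros(1, 512, 32, 32),
    "up_0_block_0": torch.zeros(1, 128, 256, 256),
}
shapes = summarize_features(demo_features)
for name, shape in shapes.items():
    print(f"{name}: {shape}")
```

Passing the features dictionary from forward_with_features_output to such a helper gives a quick check that each stage produces the expected resolution and channel count.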