Implementation:Zai org CogVideo VQVAE Blocks
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Autoencoding |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Standard 2D encoder and decoder building blocks for VQ-VAE, implementing a multi-resolution convolutional autoencoder with residual blocks, self-attention, and progressive spatial downsampling/upsampling.
Description
This module provides the foundational 2D Encoder and Decoder classes along with their constituent building blocks for the VQ-VAE architecture. These blocks define the standard spatial autoencoder upon which CogVideo's video-specific extensions are built.
The Encoder progressively downsamples input images through the resolution levels defined by ch_mult. At each level, num_res_blocks residual blocks are applied, each using GroupNorm with 32 groups, Swish activation (x * sigmoid(x)), and 3x3 convolutions. When the current resolution appears in attn_resolutions, an AttnBlock follows each residual block at that level. Between levels, the spatial size is halved; with resamp_with_conv enabled this is a stride-2 3x3 convolution with asymmetric zero-padding (on the right and bottom edges only). A middle section at the bottleneck contains two ResnetBlocks with an AttnBlock in between. The output head applies GroupNorm, Swish, and a 3x3 convolution that projects to 2*z_channels when double_z is set (for mean and log-variance), or to z_channels otherwise.
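The level schedule described above can be traced without building any network. The following is a minimal sketch in plain Python; `conv_out_size` and `encoder_plan` are hypothetical helper names, not functions from the repository, and the sketch assumes a single pixel of padding on the right/bottom edge and even input sizes.

```python
def conv_out_size(size, kernel=3, stride=2, pad_before=0, pad_after=1):
    """Output length of a strided convolution along one spatial axis."""
    return (size + pad_before + pad_after - kernel) // stride + 1


def encoder_plan(ch, ch_mult, num_res_blocks, attn_resolutions, resolution):
    """Hypothetical helper: one entry per resolution level of the encoder."""
    plan = []
    curr_res = resolution
    last = len(ch_mult) - 1
    for level, mult in enumerate(ch_mult):
        plan.append({
            "level": level,
            "resolution": curr_res,
            "channels": ch * mult,
            "res_blocks": num_res_blocks,
            "attention": curr_res in attn_resolutions,
            "downsample": level != last,  # assumed: no downsampling after the last level
        })
        if level != last:
            # Pad right/bottom only, then apply a 3x3 convolution with stride 2.
            curr_res = conv_out_size(curr_res)
    return plan


plan = encoder_plan(ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
                    attn_resolutions=[32], resolution=256)
for entry in plan:
    print(entry)
```

With one pixel of padding on a single side, a 3x3 stride-2 convolution maps an even size N to exactly N / 2, which is why the asymmetric padding yields a clean halving at every level.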
The Decoder mirrors the encoder: it upsamples progressively via nearest-neighbor interpolation, optionally followed by a learned 3x3 convolution (resamp_with_conv). It processes from the lowest resolution upward, applying num_res_blocks + 1 residual blocks per level. Both Encoder and Decoder provide a forward_with_features_output method that returns intermediate feature maps at each layer.
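The decoder's traversal order can be sketched the same way. `decoder_plan` below is a hypothetical helper, not repository code; it assumes the decoder starts at resolution / 2^(levels - 1) and upsamples after every level except the final, full-resolution one.

```python
def decoder_plan(ch, ch_mult, num_res_blocks, attn_resolutions, resolution):
    """Hypothetical helper: one entry per decoder level, lowest resolution first."""
    num_levels = len(ch_mult)
    curr_res = resolution // 2 ** (num_levels - 1)  # bottleneck resolution
    plan = []
    for level in reversed(range(num_levels)):
        plan.append({
            "level": level,
            "resolution": curr_res,
            "channels": ch * ch_mult[level],
            "res_blocks": num_res_blocks + 1,  # one more block than the encoder
            "attention": curr_res in attn_resolutions,
            "upsample": level != 0,  # assumed: no upsampling at full resolution
        })
        if level != 0:
            curr_res *= 2  # nearest-neighbor interpolation doubles each side
    return plan


plan = decoder_plan(ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
                    attn_resolutions=[32], resolution=256)
for entry in plan:
    print(entry)
```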
AttnBlock implements single-head spatial self-attention with the scaling applied to queries before the matrix multiplication (rather than after), which improves numerical stability in fp16. The attention computation uses 1x1 convolutions for Q, K, V projections.
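Moving the scale factor onto the queries does not change the attention weights, because scalar multiplication commutes with the matrix product. Assuming the conventional factor of one over the square root of the channel count c (an assumption; the text above does not state the factor), the identity is:

```latex
\operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{c}}\right)
= \operatorname{softmax}\!\left(\left(\frac{Q}{\sqrt{c}}\right) K^{\top}\right)
```

The two sides differ only in their intermediate values. The left-hand form materializes the unscaled product of Q and K, whose entries can exceed the fp16 maximum of 65504 before the scale is applied; the right-hand form keeps the product smaller from the start.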
Usage
Use these blocks as the 2D spatial backbone for VQ-VAE autoencoders. They serve as the base architecture that can be extended with temporal layers for video processing.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: sat/sgm/modules/autoencoding/vqvae/vqvae_blocks.py
- Lines: 1-424
Signature
class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        **ignore_kwargs,
    ):

class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        **ignorekwargs,
    ):
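Both constructors take keyword-only arguments and absorb unknown keys through a catch-all parameter, which suggests that one configuration dictionary can be unpacked into both. The sketch below illustrates that pattern with hypothetical stand-in functions (`encoder_args`, `decoder_args`) that copy only part of the signatures above; they are not part of the repository and build no network.

```python
def encoder_args(*, ch, z_channels, double_z=True, **ignore_kwargs):
    """Hypothetical stand-in for an Encoder-like keyword-only signature."""
    return {"ch": ch, "z_channels": z_channels, "double_z": double_z,
            "ignored": sorted(ignore_kwargs)}


def decoder_args(*, ch, z_channels, give_pre_end=False, **ignorekwargs):
    """Hypothetical stand-in for a Decoder-like keyword-only signature."""
    return {"ch": ch, "z_channels": z_channels, "give_pre_end": give_pre_end,
            "ignored": sorted(ignorekwargs)}


# One shared dictionary; each side silently drops the keys it does not use.
shared = {"ch": 128, "z_channels": 4, "double_z": True, "give_pre_end": False}
enc = encoder_args(**shared)
dec = decoder_args(**shared)
print(enc["ignored"])  # ['give_pre_end']
print(dec["ignored"])  # ['double_z']
```

A side effect of this design is that a misspelled key is also dropped silently, so it may be worth checking configuration keys before unpacking.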
Import
from sat.sgm.modules.autoencoding.vqvae.vqvae_blocks import Encoder, Decoder
I/O Contract
Inputs (Encoder)
| Name | Type | Required | Description |
|---|---|---|---|
| ch | int | Yes | Base channel count for the network |
| out_ch | int | Yes | Number of output channels; required by the signature, but meaningful mainly for the decoder counterpart |
| ch_mult | tuple of int | No | Channel multiplier per resolution level, default (1, 2, 4, 8) |
| num_res_blocks | int | Yes | Number of residual blocks per resolution level |
| attn_resolutions | list of int | Yes | Resolutions at which to apply self-attention |
| dropout | float | No | Dropout probability, default 0.0 |
| resamp_with_conv | bool | No | Use learned convolutions for downsampling, default True |
| in_channels | int | Yes | Number of input channels (e.g. 3 for RGB) |
| resolution | int | Yes | Input spatial resolution |
| z_channels | int | Yes | Number of latent channels |
| double_z | bool | No | Double z_channels for mean/variance output, default True |
Inputs (Decoder)
| Name | Type | Required | Description |
|---|---|---|---|
| ch | int | Yes | Base channel count |
| out_ch | int | Yes | Number of output channels |
| ch_mult | tuple of int | No | Channel multiplier per level, default (1, 2, 4, 8) |
| num_res_blocks | int | Yes | Number of residual blocks per level |
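| attn_resolutions | list of int | Yes | Resolutions at which to apply self-attention |
| dropout | float | No | Dropout probability, default 0.0 |
| resamp_with_conv | bool | No | Apply a learned convolution after upsampling, default True |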
| in_channels | int | Yes | Number of input channels |
| resolution | int | Yes | Target output resolution |
| z_channels | int | Yes | Number of latent input channels |
| give_pre_end | bool | No | Return features before final projection, default False |
Outputs
| Name | Type | Description |
|---|---|---|
| h (Encoder) | torch.Tensor | Latent tensor of shape (B, 2*z_channels or z_channels, H', W'), where H' and W' are the input height and width divided by 2^(len(ch_mult) - 1) |
| h (Decoder) | torch.Tensor | Reconstructed image tensor of shape (B, out_ch, resolution, resolution) |
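The shape contract in the table can be written down as two small functions. This is a sketch only: `encoder_output_shape` and `decoder_output_shape` are hypothetical names, and the arithmetic assumes square inputs whose side is divisible by the total downsampling factor.

```python
def encoder_output_shape(batch, resolution, ch_mult, z_channels, double_z=True):
    """Expected encoder output shape for a square input of side `resolution`."""
    factor = 2 ** (len(ch_mult) - 1)  # one halving between consecutive levels
    side = resolution // factor
    channels = 2 * z_channels if double_z else z_channels
    return (batch, channels, side, side)


def decoder_output_shape(batch, latent_side, ch_mult, out_ch):
    """Expected decoder output shape for a square latent of side `latent_side`."""
    factor = 2 ** (len(ch_mult) - 1)  # one doubling between consecutive levels
    return (batch, out_ch, latent_side * factor, latent_side * factor)


print(encoder_output_shape(1, 256, (1, 2, 4, 4), z_channels=4))
print(decoder_output_shape(1, 32, (1, 2, 4, 4), out_ch=3))
```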
Usage Examples
import torch
from sat.sgm.modules.autoencoding.vqvae.vqvae_blocks import Encoder, Decoder

encoder = Encoder(
    ch=128, out_ch=3, ch_mult=(1, 2, 4, 4),
    num_res_blocks=2, attn_resolutions=[32],
    in_channels=3, resolution=256, z_channels=4,
)
decoder = Decoder(
    ch=128, out_ch=3, ch_mult=(1, 2, 4, 4),
    num_res_blocks=2, attn_resolutions=[32],
    in_channels=3, resolution=256, z_channels=4,
)

# Encode: three downsampling steps take 256 -> 32; double_z doubles the channel count
x = torch.randn(1, 3, 256, 256)
z = encoder(x)  # shape: (1, 8, 32, 32)

# Decode the first z_channels channels (conventionally the mean half of the encoder output)
recon = decoder(z[:, :4, :, :])  # shape: (1, 3, 256, 256)

# Extract intermediate features
z, features = encoder.forward_with_features_output(x)