Implementation:Zai org CogVideo MoVQ Encoder3D
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Autoencoding |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
A 3D video encoder that compresses video tensors into a lower-dimensional latent space using causal convolutions with configurable spatial and temporal downsampling.
Description
This module implements the Encoder3D class and its supporting building blocks for encoding video data in CogVideo's VQ-VAE pipeline. The encoder processes 5D tensors of shape (batch, channels, time, height, width) and progressively reduces spatial and temporal resolution to produce a compact latent representation.
The key architectural innovation is the CausalConv3d layer, which wraps standard 3D convolutions with causal temporal padding. This ensures that each output frame depends only on the current and past input frames, never on future frames. Three padding modes are supported: "constant" (zero-padding), "first" (repeating the first frame), and "reflect" (reflecting temporal neighbors). This causal property is critical for autoregressive or causal video generation.
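The three padding modes can be illustrated on a single value per frame. The following is a minimal pure-Python sketch, not the module's code: `causal_pad` and `causal_conv1d` are hypothetical helpers, and the sketch assumes stride 1 and dilation 1 (so `kernel_t - 1` frames are prepended) and that "reflect" mirrors the frames immediately following the first one.

```python
def causal_pad(frames, kernel_t, mode="first"):
    """Left-pad a sequence of per-frame values for a temporal kernel of size kernel_t.

    Assumption: stride 1 and dilation 1, so kernel_t - 1 frames are prepended.
    """
    frames = list(frames)
    pad = kernel_t - 1
    if mode == "constant":
        prefix = [0.0] * pad                      # zero-padding
    elif mode == "first":
        prefix = frames[:1] * pad                 # repeat the first frame
    elif mode == "reflect":
        mirrored = frames[1:pad + 1][::-1]        # mirror the frames after the first
        # Assumption: fall back to zeros when the clip is too short to mirror.
        prefix = [0.0] * (pad - len(mirrored)) + mirrored
    else:
        raise ValueError(f"unknown pad mode: {mode!r}")
    return prefix + frames


def causal_conv1d(frames, weights, mode="first"):
    """Temporal convolution whose output at frame t reads only frames <= t."""
    padded = causal_pad(frames, len(weights), mode)
    return [
        sum(w * padded[t + i] for i, w in enumerate(weights))
        for t in range(len(frames))
    ]


clip = [1.0, 2.0, 3.0, 4.0]
print(causal_pad(clip, 3, "constant"))  # [0.0, 0.0, 1.0, 2.0, 3.0, 4.0]
print(causal_pad(clip, 3, "first"))     # [1.0, 1.0, 1.0, 2.0, 3.0, 4.0]
print(causal_pad(clip, 3, "reflect"))   # [3.0, 2.0, 1.0, 2.0, 3.0, 4.0]
```

Because all padding goes in front of the clip, the output length equals the input length, and with "first" or "constant" padding a change to a later frame leaves every earlier output untouched.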
DownSample3D handles spatial downsampling via strided 2D convolutions applied per frame, with optional temporal downsampling that average-pools all frames except the first (which is always preserved). Upsample3D performs nearest-neighbor interpolation, optionally followed by a learned convolution that refines the result. ResnetBlock3D combines CausalConv3d layers with GroupNorm and Swish activation, and supports residual connections and optional timestep embedding injection. AttnBlock2D applies 2D spatial self-attention within the 3D tensor by reshaping it so that each frame is processed independently.
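The first-frame-preserving temporal pooling in DownSample3D can be sketched on one value per frame. This is a hedged pure-Python illustration, not the module's code: `temporal_downsample` is a hypothetical helper, and the pooling window of 2 with stride 2 is an assumption.

```python
def temporal_downsample(frames):
    """Roughly halve the frame count while keeping the first frame untouched."""
    frames = list(frames)
    first, rest = frames[:1], frames[1:]
    # Average non-overlapping pairs (assumed window 2, stride 2).
    # A trailing unpaired frame is dropped, as in unpadded strided pooling.
    pooled = [(rest[i] + rest[i + 1]) / 2 for i in range(0, len(rest) - 1, 2)]
    return first + pooled


print(temporal_downsample([10.0, 1.0, 3.0, 5.0, 7.0]))  # [10.0, 2.0, 6.0]
print(len(temporal_downsample(range(17))))              # 9
```

Under these assumptions a clip of 2k + 1 frames becomes k + 1 frames, so frame counts of the form 4k + 1 reduce cleanly over two temporal stages (17, then 9, then 5).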
Encoder3D assembles these components into a multi-resolution downsampling network. Temporal compression is controlled by the temporal_compress_times parameter, which determines how many of the early downsampling levels also compress the time axis. The number of temporal compression levels is computed as log2(temporal_compress_times), so the default value of 4 yields two levels.
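The log2 rule can be made concrete with a short sketch. The helper names `temporal_compress_levels` and `downsample_plan` are hypothetical, and the sketch assumes one downsampling stage between consecutive entries of `ch_mult` (that is, none after the final level).

```python
import math


def temporal_compress_levels(temporal_compress_times):
    """Number of downsampling stages that also halve the time axis."""
    levels = math.log2(temporal_compress_times)
    if not levels.is_integer():
        raise ValueError("temporal_compress_times must be a power of 2")
    return int(levels)


def downsample_plan(ch_mult, temporal_compress_times):
    """One flag per downsampling stage: True if the stage also compresses time."""
    n_temporal = temporal_compress_levels(temporal_compress_times)
    n_stages = len(ch_mult) - 1  # assumption: no downsampling after the last level
    return [stage < n_temporal for stage in range(n_stages)]


print(downsample_plan((1, 2, 4, 4), 4))  # [True, True, False]
```

With four resolution levels and a temporal factor of 4, the first two stages reduce both space and time, and the third reduces space only.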
Usage
Use this encoder when you need to compress video data into a spatiotemporal latent representation for the VQ-VAE pipeline. It is the primary encoder for CogVideo's 3D video autoencoding, producing latent codes that can be quantized and later decoded.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py
- Lines: 1-454
Signature
class Encoder3D(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        pad_mode="first",
        temporal_compress_times=4,
        **ignore_kwargs,
    ):

class CausalConv3d(nn.Module):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        pad_mode="constant",
        **kwargs,
    ):
Import
from sat.sgm.modules.autoencoding.vqvae.movq_enc_3d import Encoder3D, CausalConv3d
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| ch | int | Yes | Base channel count for the encoder network |
| out_ch | int | Yes | Number of output channels (not directly used in encoder but required by constructor) |
| ch_mult | tuple of int | No | Channel multiplier at each resolution level, default (1, 2, 4, 8) |
| num_res_blocks | int | Yes | Number of residual blocks per resolution level |
| attn_resolutions | list of int | Yes | Resolutions at which to apply spatial self-attention |
| dropout | float | No | Dropout probability, default 0.0 |
| resamp_with_conv | bool | No | Whether to use learned convolutions for resampling, default True |
| in_channels | int | Yes | Number of input channels (e.g. 3 for RGB video) |
| resolution | int | Yes | Spatial resolution of input (height/width) |
| z_channels | int | Yes | Number of latent channels in the output |
| double_z | bool | No | Whether to double z_channels for mean/variance output, default True |
| pad_mode | str | No | Temporal padding mode for CausalConv3d: "constant", "first", or "reflect", default "first" |
| temporal_compress_times | int | No | Total temporal compression factor (must be power of 2), default 4 |
Outputs
| Name | Type | Description |
|---|---|---|
| h | torch.Tensor | Latent tensor of shape (B, C, T', H', W'), where C is 2*z_channels if double_z is True and z_channels otherwise, and T', H', W' are the compressed temporal and spatial dimensions |
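The expected latent shape can be estimated from the constructor arguments. The following is a rough sketch under stated assumptions rather than the module's own logic: `latent_shape` is a hypothetical helper, each of the `len(ch_mult) - 1` downsampling stages is assumed to halve height and width, and each temporal stage is assumed to keep the first frame and halve the remaining ones.

```python
import math


def latent_shape(batch, frames, height, width, *, z_channels, double_z=True,
                 ch_mult=(1, 2, 4, 8), temporal_compress_times=4):
    """Estimate (B, C, T', H', W') for a given input size."""
    spatial_stages = len(ch_mult) - 1                      # assumed stage count
    temporal_stages = int(math.log2(temporal_compress_times))
    for _ in range(temporal_stages):
        frames = 1 + (frames - 1) // 2                     # first frame kept, rest halved
    factor = 2 ** spatial_stages
    channels = 2 * z_channels if double_z else z_channels  # mean and variance if doubled
    return (batch, channels, frames, height // factor, width // factor)


print(latent_shape(1, 16, 256, 256, z_channels=4, ch_mult=(1, 2, 4, 4)))
# (1, 8, 4, 32, 32)
```

The printed shape follows from the arithmetic in the sketch itself; the real encoder should be run to confirm it for a particular configuration.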
Usage Examples
import torch
from sat.sgm.modules.autoencoding.vqvae.movq_enc_3d import Encoder3D
encoder = Encoder3D(
    ch=128,
    out_ch=3,
    ch_mult=(1, 2, 4, 4),
    num_res_blocks=2,
    attn_resolutions=[],
    in_channels=3,
    resolution=256,
    z_channels=4,
    double_z=True,
    pad_mode="first",
    temporal_compress_times=4,
)
# Encode a batch of 16-frame 256x256 RGB videos
video = torch.randn(1, 3, 16, 256, 256)
latent = encoder(video)
# latent shape: (1, 8, 4, 32, 32) with 4x temporal and 8x spatial compression