Implementation:Zai org CogVideo VectorQuantizer2
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Quantization, Autoencoding |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
VectorQuantizer2 is an improved vector quantization module for VQ-VAE that avoids costly one-hot matrix multiplications and supports optional index remapping to restrict the codebook to a subset of entries.
Description
The VectorQuantizer2 class implements nearest-neighbor vector quantization. It computes squared Euclidean distances in the expanded form ||z||^2 + ||e||^2 - 2 z·e, which produces the (N, n_e) distance matrix with a single matrix multiplication instead of materializing every pairwise difference vector z - e. Each spatial position of the encoder output is then assigned to its nearest codebook entry, and the quantized vector is fetched with an embedding lookup.
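The expanded distance trick can be illustrated with a minimal sketch. The helper name `nearest_codebook_indices` is hypothetical, and it uses a plain matrix multiplication where the repository code may use `einsum`; a brute-force distance matrix is computed alongside it only to check that the chosen entries really are the nearest ones.

```python
import torch


def nearest_codebook_indices(z_flat: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Index of the nearest codebook row for every row of z_flat (hypothetical helper).

    z_flat:   (N, e_dim) flattened encoder outputs
    codebook: (n_e, e_dim) embedding weights
    """
    # ||z||^2 (N, 1) + ||e||^2 (n_e,) - 2 z.e (N, n_e): one matmul, no (N, n_e, e_dim) tensor
    d = (
        z_flat.pow(2).sum(dim=1, keepdim=True)
        + codebook.pow(2).sum(dim=1)
        - 2 * z_flat @ codebook.t()
    )
    return d.argmin(dim=1)


torch.manual_seed(0)
codebook = torch.randn(16, 8)                     # n_e=16, e_dim=8
z = torch.randn(2, 8, 4, 4)                       # (B, C, H, W) with C == e_dim
z_flat = z.permute(0, 2, 3, 1).reshape(-1, 8)     # (B*H*W, C), channels last
idx = nearest_codebook_indices(z_flat, codebook)  # one index per spatial position

# Brute-force reference that does materialize every z - e difference vector
brute_d = ((z_flat[:, None, :] - codebook[None, :, :]) ** 2).sum(dim=-1)
```

The memory saving grows with e_dim: the brute-force version allocates an (N, n_e, e_dim) intermediate, while the expanded form never exceeds (N, n_e).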
Gradient flow through the discrete bottleneck is preserved using the straight-through estimator: z_q = z + (z_q - z).detach(). The training loss consists of two terms: a commitment loss that encourages the encoder to produce outputs close to codebook entries, and an embedding loss that pulls codebook entries toward encoder outputs. A legacy flag controls which term the beta coefficient is applied to, maintaining backward compatibility with an earlier implementation bug.
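The loss terms and the straight-through estimator described above can be sketched as follows. The function name `vq_loss_and_ste` is hypothetical and `F.mse_loss` stands in for the explicit mean of squared differences; the sketch only mirrors the behavior described in this section, not the module's exact code.

```python
import torch
import torch.nn.functional as F


def vq_loss_and_ste(z, z_q, beta, legacy=True):
    """Return (straight-through output, loss) for encoder output z and codebook vectors z_q."""
    commitment = F.mse_loss(z, z_q.detach())  # gradient reaches the encoder only
    embedding = F.mse_loss(z_q, z.detach())   # gradient reaches the codebook only
    if legacy:
        loss = commitment + beta * embedding  # legacy: beta scales the embedding term
    else:
        loss = beta * commitment + embedding  # beta scales the commitment term
    # Forward value equals z_q; backward passes the gradient of z_st straight to z
    z_st = z + (z_q - z).detach()
    return z_st, loss


torch.manual_seed(0)
z = torch.randn(2, 4, requires_grad=True)
z_q = torch.randn(2, 4, requires_grad=True)  # stand-in for looked-up codebook vectors
z_st, loss = vq_loss_and_ste(z, z_q, beta=0.25)
z_st.sum().backward()  # z.grad becomes all ones; z_q receives no gradient from z_st

# Compare gradient routing with legacy=True vs legacy=False
z1 = torch.randn(3, 5, requires_grad=True)
q1 = torch.randn(3, 5, requires_grad=True)
_, loss_legacy = vq_loss_and_ste(z1, q1, beta=0.25, legacy=True)
loss_legacy.backward()
grad_legacy = z1.grad.clone()
z1.grad = None
_, loss_fixed = vq_loss_and_ste(z1, q1, beta=0.25, legacy=False)
loss_fixed.backward()
grad_fixed = z1.grad.clone()
```

Because the two terms have the same forward value, both settings report the same loss; the legacy flag changes only whether the encoder or the codebook receives the beta-scaled gradient.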
The module also supports optional index remapping via an externally provided mapping file, which restricts the effective codebook to a subset of used indices. The same source file (quantize.py) also defines GumbelQuantize, a Gumbel-Softmax quantizer that offers a differentiable alternative to hard nearest-neighbor assignment through categorical reparameterization.
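For contrast with nearest-neighbor assignment, here is a minimal sketch of Gumbel-Softmax quantization. It is not the repository's GumbelQuantize: the class name `GumbelQuantizeSketch`, the 1x1 convolution that produces per-position logits, and the parameter names are assumptions for illustration, and any auxiliary prior or KL term the real module may add is omitted. It relies on `torch.nn.functional.gumbel_softmax`; with `hard=True` the forward pass uses a one-hot sample while gradients flow through the soft relaxation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GumbelQuantizeSketch(nn.Module):
    """Hypothetical minimal Gumbel-Softmax quantizer (not the repository's GumbelQuantize)."""

    def __init__(self, num_hiddens: int, n_e: int, e_dim: int):
        super().__init__()
        self.proj = nn.Conv2d(num_hiddens, n_e, kernel_size=1)  # logits over codes per position
        self.embed = nn.Embedding(n_e, e_dim)

    def forward(self, z: torch.Tensor, tau: float = 1.0, hard: bool = False):
        logits = self.proj(z)                                          # (B, n_e, H, W)
        one_hot = F.gumbel_softmax(logits, tau=tau, hard=hard, dim=1)  # sample over codes
        # Weighted sum of codebook vectors -> (B, e_dim, H, W)
        z_q = torch.einsum("bnhw,nd->bdhw", one_hot, self.embed.weight)
        return z_q, one_hot.argmax(dim=1)                              # indices: (B, H, W)


torch.manual_seed(0)
q = GumbelQuantizeSketch(num_hiddens=32, n_e=16, e_dim=8)
z = torch.randn(2, 32, 4, 4)
z_q, ind = q(z, hard=True)
```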
Usage
Use VectorQuantizer2 as the discrete bottleneck in VQ-VAE architectures for video or image compression. It is appropriate when you need a learned codebook with explicit nearest-neighbor assignment and want efficient distance computation for large codebook sizes.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: sat/sgm/modules/autoencoding/vqvae/quantize.py
Signature
```python
class VectorQuantizer2(nn.Module):
    def __init__(
        self,
        n_e,
        e_dim,
        beta,
        remap=None,
        unknown_index="random",
        sane_index_shape=False,
        legacy=True,
    ):
```
Import
```python
from sat.sgm.modules.autoencoding.vqvae.quantize import VectorQuantizer2
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| n_e | int | Yes | Number of codebook entries (vocabulary size). |
| e_dim | int | Yes | Dimension of each codebook embedding vector. |
| beta | float | Yes | Weighting factor for the commitment/embedding loss term. |
| remap | str or None | No | Path to a NumPy file containing the used codebook indices for remapping. Default: None. |
| unknown_index | str or int | No | Strategy for handling indices not in the remap set. Options: "random", "extra", or an integer. Default: "random". |
| sane_index_shape | bool | No | If True, returns indices reshaped to (B, H, W) instead of flattened. Default: False. |
| legacy | bool | No | If True, applies beta to the embedding loss (buggy legacy behavior). If False, applies beta to the commitment loss. Default: True. |
Forward Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| z | Tensor | Yes | Encoder output tensor of shape (B, C, H, W), where C must equal e_dim. |
| temp | None or float | No | Temperature parameter, accepted only for interface compatibility with GumbelQuantize; must be None or 1.0. |
| rescale_logits | bool | No | Must be False (interface compatibility). |
| return_logits | bool | No | Must be False (interface compatibility). |
Outputs
| Name | Type | Description |
|---|---|---|
| z_q | Tensor | Quantized tensor of shape (B, C, H, W), the same as the input. |
| loss | Tensor | Scalar combining the commitment and embedding losses. |
| (perplexity, min_encodings, min_encoding_indices) | Tuple | Perplexity (always None), min encodings (always None), and the integer indices of the selected codebook entries, flattened by default or shaped (B, H, W) when sane_index_shape=True. |
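Since the returned perplexity is always None, codebook usage can be estimated from the returned indices instead. The helper below is a hypothetical sketch, not part of the module: it computes the exponential of the entropy of the empirical code-usage distribution, which ranges from 1 (a single code used) to n_e (all codes used equally).

```python
import torch


def codebook_perplexity(indices: torch.Tensor, n_e: int) -> torch.Tensor:
    """exp(entropy) of the empirical code-usage distribution (hypothetical helper)."""
    counts = torch.bincount(indices.flatten(), minlength=n_e).float()
    probs = counts / counts.sum()
    # The 1e-10 guards log(0) for unused codes, whose contribution is 0 anyway
    return torch.exp(-(probs * torch.log(probs + 1e-10)).sum())


uniform_ppl = codebook_perplexity(torch.arange(4).repeat(10), n_e=8)       # 4 codes, equal use
collapsed_ppl = codebook_perplexity(torch.zeros(20, dtype=torch.long), n_e=8)  # one code only
```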
Key Methods
get_codebook_entry(indices, shape)
Decodes codebook indices back to embedding vectors. Accepts a flat index tensor and the target shape (B, H, W, C), returning a tensor of shape (B, C, H, W).
remap_to_used(inds)
Maps raw codebook indices to the restricted subset of used indices, handling unknown indices via the configured strategy.
unmap_to_all(inds)
Inverse of remap_to_used: converts restricted indices back to the full codebook index space.
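The remapping pair can be sketched as follows. The function names are hypothetical, `used` is assumed to be a 1-D int64 tensor of unique full-codebook indices (as loaded from the remap file), and unknown indices receive a fixed integer; this assumes the "extra" strategy behaves like an integer equal to len(used), and the "random" strategy is not shown. Out-of-range positions are sent back to position 0 when unmapping.

```python
import torch


def remap_to_used_sketch(inds: torch.Tensor, used: torch.Tensor, unknown_index: int) -> torch.Tensor:
    """Map full-codebook indices (B, N) to positions in `used`; unknowns get unknown_index."""
    match = inds[:, :, None] == used[None, None, :]  # (B, N, K) boolean
    new = match.long().argmax(dim=-1)                # position of the match (0 if none)
    unknown = ~match.any(dim=-1)
    new[unknown] = unknown_index
    return new


def unmap_to_all_sketch(inds: torch.Tensor, used: torch.Tensor) -> torch.Tensor:
    """Map positions in `used` back to full-codebook indices."""
    inds = inds.clone()
    inds[inds >= used.shape[0]] = 0  # "extra" slot has no full-codebook entry
    return used[inds]


used = torch.tensor([2, 5, 7])               # retained codebook entries
inds = torch.tensor([[5, 2, 7, 3]])          # 3 is not in the used set
remapped = remap_to_used_sketch(inds, used, unknown_index=len(used))
restored = unmap_to_all_sketch(remapped, used)
```

The round trip is lossless for indices in the used set, while an unknown index comes back as the first used entry.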
Usage Examples
```python
import torch

from sat.sgm.modules.autoencoding.vqvae.quantize import VectorQuantizer2

# Create a quantizer with 1024 codebook entries of dimension 256
quantizer = VectorQuantizer2(n_e=1024, e_dim=256, beta=0.25)

# Quantize an encoder output of shape (B, C, H, W); C must equal e_dim
z = torch.randn(4, 256, 16, 16)
z_q, loss, (perplexity, min_encodings, indices) = quantizer(z)
# z_q.shape: (4, 256, 16, 16)

# Decode indices back to embeddings; shape is given as (B, H, W, C)
z_reconstructed = quantizer.get_codebook_entry(
    indices.flatten(), shape=(4, 16, 16, 256)
)
# z_reconstructed.shape: (4, 256, 16, 16)
```