Implementation:Zai org CogVideo EMAVectorQuantizer
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Autoencoding, Vector_Quantization |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
EMAVectorQuantizer is a vector quantization module that updates its codebook embeddings via exponential moving averages (EMA) of encoder outputs rather than through gradient descent, providing more stable codebook learning for VQ-VAE architectures.
Description
The EMAVectorQuantizer class extends AbstractQuantizer to implement a vector quantization bottleneck whose codebook is updated by exponential moving averages of cluster assignment counts and embedding sums. During the forward pass, each input feature vector is mapped to its nearest codebook entry by computing squared Euclidean distances between the flattened input and all codebook embeddings. The returned loss is the commitment loss only: beta * MSE between the quantized output and the encoder output, which pulls encoder outputs toward their assigned codes. There is no separate codebook loss term, because the codebook is updated through EMA rather than through gradients.
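The nearest-neighbour assignment and commitment loss can be sketched in NumPy as follows. This is a minimal illustration, not the module's PyTorch code. The function names `nearest_codes` and `commitment_loss` are hypothetical, and the input is assumed to be already flattened to shape (N, D).

```python
import numpy as np


def nearest_codes(z_flat: np.ndarray, codebook: np.ndarray) -> np.ndarray:
    """Index of the closest codebook row for each row of z_flat (hypothetical helper).

    Uses ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e, which avoids building an
    (N, K, D) difference array.
    """
    d = (
        np.sum(z_flat ** 2, axis=1, keepdims=True)  # (N, 1)
        + np.sum(codebook ** 2, axis=1)  # (K,), broadcast to (N, K)
        - 2.0 * z_flat @ codebook.T  # (N, K)
    )
    return np.argmin(d, axis=1)


def commitment_loss(z_flat: np.ndarray, z_q_flat: np.ndarray, beta: float) -> float:
    # beta * mean squared error over all elements; the codes are treated as
    # constants, so only the encoder side would receive a gradient
    return float(beta * np.mean((z_q_flat - z_flat) ** 2))


rng = np.random.default_rng(0)
codebook = rng.normal(size=(8, 4))  # K=8 codes of dimension D=4
z_flat = rng.normal(size=(10, 4))  # N=10 encoder vectors
idx = nearest_codes(z_flat, codebook)
z_q_flat = codebook[idx]  # quantized vectors are copies of codebook rows
loss = commitment_loss(z_flat, z_q_flat, beta=0.25)
```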
The codebook update is handled by the companion EmbeddingEMA module. It maintains running averages of two statistics: the cluster sizes (how many vectors were assigned to each code) and the embedding sums (the sum of the vectors assigned to each code). During training, when the embedding's update flag is set, both statistics are updated with the configured decay rate. Each codebook vector is then recomputed as its averaged embedding sum divided by its averaged cluster size. Laplace smoothing is applied to the cluster sizes to prevent division by zero for unused codes.
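Below is a NumPy sketch of the EMA bookkeeping. It assumes the standard Laplace-smoothing formula (cluster_size + eps) / (n + K * eps) * n, where n is the sum of the EMA cluster sizes and K is the number of codes. The class `EmbeddingEMASketch` and the helper `ema_step` are hypothetical. The method names only mirror those listed for EmbeddingEMA, and this is not the repository implementation.

```python
import numpy as np


class EmbeddingEMASketch:
    """Hypothetical NumPy stand-in for the EMA codebook state."""

    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, seed=0):
        rng = np.random.default_rng(seed)
        self.decay = decay
        self.eps = eps
        self.weight = rng.normal(size=(num_tokens, codebook_dim))  # the codebook
        self.cluster_size = np.zeros(num_tokens)  # EMA of assignment counts
        self.embed_avg = self.weight.copy()  # EMA of summed assigned vectors

    def cluster_size_ema_update(self, new_cluster_size):
        self.cluster_size = self.decay * self.cluster_size + (1 - self.decay) * new_cluster_size

    def embed_avg_ema_update(self, new_embed_avg):
        self.embed_avg = self.decay * self.embed_avg + (1 - self.decay) * new_embed_avg

    def weight_update(self, num_tokens):
        n = self.cluster_size.sum()
        # Laplace smoothing: once n > 0, every smoothed count is strictly
        # positive and the smoothed counts still sum to n
        smoothed = (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
        self.weight = self.embed_avg / smoothed[:, None]


def ema_step(ema, z_flat, indices):
    # one training-time update: counts, then sums, then renormalized weights
    num_tokens = ema.weight.shape[0]
    encodings = np.eye(num_tokens)[indices]  # (N, K) one-hot assignments
    ema.cluster_size_ema_update(encodings.sum(axis=0))
    ema.embed_avg_ema_update(encodings.T @ z_flat)
    ema.weight_update(num_tokens)


z_flat = np.array([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]])
ema = EmbeddingEMASketch(num_tokens=3, codebook_dim=2, decay=0.0)
ema_step(ema, z_flat, np.array([0, 0, 1]))
# with decay=0, each used code lands (almost exactly) on the mean of its vectors
```

With a realistic decay such as 0.99, each step moves the statistics only 1% of the way toward the current batch. This is what smooths the codebook trajectory compared with per-batch reassignment.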
The module also supports index remapping (inherited from AbstractQuantizer), which restricts the effective codebook to a subset of used indices loaded from a numpy file. Indices outside that subset can be mapped to a random used index, to an extra token appended after the used indices, or to a fixed integer index.
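The remapping rule can be sketched as below. `remap_to_used_sketch` is a hypothetical NumPy helper written for illustration. `used` stands for the array of indices loaded from the remap file. The assumption that "random" draws from the positions of the used indices is labeled in the code.

```python
import numpy as np


def remap_to_used_sketch(indices, used, unknown_index="random", rng=None):
    """Map full-codebook indices to positions in `used` (hypothetical helper)."""
    indices = np.asarray(indices)
    used = np.asarray(used)
    match = indices[..., None] == used  # (..., len(used)) boolean matches
    new = match.argmax(-1)  # position of the match (0 when there is none)
    unknown = ~match.any(-1)  # indices not present in `used`
    if unknown_index == "random":
        # assumption: replacements are drawn uniformly from range(len(used))
        if rng is None:
            rng = np.random.default_rng()
        new[unknown] = rng.integers(0, len(used), size=int(unknown.sum()))
    elif unknown_index == "extra":
        new[unknown] = len(used)  # one extra slot appended after the used codes
    else:
        new[unknown] = int(unknown_index)  # a fixed integer index
    return new


used = np.array([3, 7, 9])
print(remap_to_used_sketch([7, 3, 5, 9], used, unknown_index="extra"))  # [1 0 3 2]
```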
Usage
Use EMAVectorQuantizer when training a VQ-VAE or similar discrete latent model and you want the codebook to follow running averages of the encoder outputs assigned to it, instead of being learned through a gradient-based codebook loss. The encoder itself still receives gradients through the straight-through estimator. EMA-based codebook updates are commonly reported to be more stable and to keep more codebook entries in use, which makes this quantizer worth trying when codebook utilization is a concern. It is a reasonable choice when training stability matters more than gradient-based, fine-grained control over individual codebook entries.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: sat/sgm/modules/autoencoding/regularizers/quantize.py
- Lines: 336-424 (EMAVectorQuantizer), 308-333 (EmbeddingEMA)
Signature
```python
class EMAVectorQuantizer(AbstractQuantizer):
    def __init__(
        self,
        n_embed: int,
        embedding_dim: int,
        beta: float,
        decay: float = 0.99,
        eps: float = 1e-5,
        remap: Optional[str] = None,
        unknown_index: str = "random",
        loss_key: str = "loss/vq",
    ):
        ...

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        ...


class EmbeddingEMA(nn.Module):
    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
        ...

    def forward(self, embed_id):
        ...

    def cluster_size_ema_update(self, new_cluster_size):
        ...

    def embed_avg_ema_update(self, new_embed_avg):
        ...

    def weight_update(self, num_tokens):
        ...
```
Import
```python
from sat.sgm.modules.autoencoding.regularizers.quantize import EMAVectorQuantizer, EmbeddingEMA
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| n_embed | int | Yes | Number of codebook entries (vocabulary size) |
| embedding_dim | int | Yes | Dimensionality of each codebook vector |
| beta | float | Yes | Commitment loss weight (scales MSE between z_q and z) |
| decay | float | No (default 0.99) | EMA decay rate for codebook updates |
| eps | float | No (default 1e-5) | Epsilon for Laplace smoothing in weight normalization |
| remap | Optional[str] | No (default None) | Path to numpy file of used indices for codebook restriction |
| unknown_index | str | No (default "random") | How to handle unknown indices: "random", "extra", or an integer |
| loss_key | str | No (default "loss/vq") | Key used in the output loss dictionary |
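The sketch below gives a feel for the `decay` parameter. It applies the same update form used for the EMA statistics (new = decay * old + (1 - decay) * value) to a constant signal, starting from zero. The helper names are hypothetical. The 1 / (1 - decay) horizon is a general rule of thumb for EMAs, not something the module computes.

```python
def ema_track(values, decay, init=0.0):
    """Run an EMA over `values` and return every intermediate average (hypothetical helper)."""
    avg = init
    history = []
    for v in values:
        avg = decay * avg + (1.0 - decay) * v  # same form as the EMA statistic updates
        history.append(avg)
    return history


def effective_horizon(decay):
    # rule of thumb: an EMA mostly reflects the last ~1 / (1 - decay) updates
    return 1.0 / (1.0 - decay)


# A code that receives 10 assignments every step, tracked from zero.
# After t steps the average is 10 * (1 - decay ** t); 0.99 ** 100 is about 0.366,
# so after 100 steps the tracked count has reached only about 63% of 10.
history = ema_track([10.0] * 100, decay=0.99)
```

A smaller decay reacts faster to changing assignments but makes the codebook noisier. A decay closer to 1 averages over more batches.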
Forward Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| z | torch.Tensor | Yes | Encoder output of shape (B, C, H, W); C must equal embedding_dim |
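The quantizer treats every spatial position of z as one vector of length C. The reshaping can be sketched in NumPy as below. It assumes a channels-last flattening (B, C, H, W) -> (B, H, W, C) -> (B*H*W, C). The helper names are hypothetical, and the module itself works on torch tensors. The sketch shows why C has to match `embedding_dim` and how a flat row index corresponds to a (b, h, w) position.

```python
import numpy as np


def to_flat(z: np.ndarray, embedding_dim: int) -> np.ndarray:
    # (B, C, H, W) -> (B, H, W, C) -> (B*H*W, C); each row is one spatial position
    b, c, h, w = z.shape
    if c != embedding_dim:
        raise ValueError(f"channel dim {c} != embedding_dim {embedding_dim}")
    return z.transpose(0, 2, 3, 1).reshape(-1, c)


def from_flat(z_flat: np.ndarray, shape) -> np.ndarray:
    # inverse of to_flat: (B*H*W, C) -> (B, H, W, C) -> (B, C, H, W)
    b, c, h, w = shape
    return z_flat.reshape(b, h, w, c).transpose(0, 3, 1, 2)


z = np.arange(2 * 3 * 4 * 5, dtype=float).reshape(2, 3, 4, 5)  # B=2, C=3, H=4, W=5
flat = to_flat(z, embedding_dim=3)
# row index = b * H * W + h * W + w, so flat indices reshape to a (B, H, W) grid
```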
Outputs
| Name | Type | Description |
|---|---|---|
| z_q | torch.Tensor | Quantized tensor of shape (B, C, H, W), same shape as input, with straight-through gradient |
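| out_dict | Dict | Dictionary containing: loss_key (the commitment loss), "encodings" (one-hot assignment matrix of shape (B*H*W, n_embed)), "encoding_indices" (flat integer codebook indices of shape (B*H*W,)), "perplexity" (exponential of the entropy of the batch's average code usage, between 1 and n_embed) |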
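The perplexity entry summarizes how evenly a batch uses the codebook. The NumPy sketch below assumes the usual definition exp(-sum p log p), where p is the batch-average one-hot assignment. The function name is hypothetical.

```python
import numpy as np


def codebook_perplexity(indices: np.ndarray, n_embed: int) -> float:
    # fraction of vectors routed to each code (average of the one-hot rows)
    avg_probs = np.eye(n_embed)[indices].mean(axis=0)
    # exp(entropy); the 1e-10 guards log(0) for codes that were never chosen
    return float(np.exp(-np.sum(avg_probs * np.log(avg_probs + 1e-10))))


even = codebook_perplexity(np.array([0, 1, 2, 3] * 5), n_embed=8)  # 4 codes used evenly
collapsed = codebook_perplexity(np.zeros(10, dtype=int), n_embed=8)  # a single code
```

A value near the number of codes in use means they are chosen about equally often. A value near 1 means most vectors collapse onto a few codes.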
Usage Examples
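```python
import torch

from sat.sgm.modules.autoencoding.regularizers.quantize import EMAVectorQuantizer

# Initialize the EMA vector quantizer with 512 codebook entries of dimension 256
quantizer = EMAVectorQuantizer(
    n_embed=512,
    embedding_dim=256,
    beta=0.25,
    decay=0.99,
    eps=1e-5,
)
quantizer.train()  # EMA codebook updates only happen during training

# Encoder output (batch=4, channels=256, height=16, width=16);
# a random tensor stands in for encoder(x) here
z = torch.randn(4, 256, 16, 16)
z_q, loss_dict = quantizer(z)  # z_q has the same shape as z

# Access the commitment loss and perplexity
vq_loss = loss_dict["loss/vq"]
perplexity = loss_dict["perplexity"]

# Retrieve codebook entries by index
indices = loss_dict["encoding_indices"]  # flat integer indices, shape (4 * 16 * 16,)
codes = quantizer.embedding(indices)  # EmbeddingEMA lookup, shape (1024, 256)
index_map = indices.reshape(4, 16, 16)  # code index per spatial position
```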