Implementation:Zai org CogVideo FSQ
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Quantization, Autoencoding |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
FSQ is a PyTorch module that implements Finite Scalar Quantization, a codebook-free alternative to the vector-quantization step of a VQ-VAE that discretizes each latent dimension independently into a fixed number of levels.
Description
The FSQ class provides a quantization layer for autoencoder bottlenecks that needs no learned codebook. Instead of maintaining explicit embedding vectors as traditional vector quantization does, FSQ passes each latent dimension through a tanh-based bounding function so that rounding yields a fixed number of discrete levels, and it applies that rounding with a straight-through estimator (STE) so that gradients still flow during backpropagation.
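As a rough illustration of the bound, round, and normalize steps, the pure-Python sketch below quantizes a single scalar for one dimension with L levels. It assumes the bounding formula used by the reference FSQ implementation (half-width (L - 1)(1 + eps) / 2, a 0.5 offset for even L, and an atanh shift that keeps 0 mapped to 0). The helper names bound_scalar and quantize_scalar are hypothetical, and the straight-through estimator is only noted in a comment because plain Python floats carry no gradients.

```python
import math


def bound_scalar(z: float, levels: int, eps: float = 1e-3) -> float:
    """Squash z into a range that rounds to exactly `levels` integers (assumed formula)."""
    half_l = (levels - 1) * (1 + eps) / 2
    # Even level counts need a half-step offset so the rounded grid has `levels` points
    offset = 0.5 if levels % 2 == 0 else 0.0
    # Shift chosen so that bound_scalar(0) == 0 (up to float error)
    shift = math.atanh(offset / half_l)
    return math.tanh(z + shift) * half_l - offset


def quantize_scalar(z: float, levels: int) -> float:
    """Bound, round to the nearest integer, then rescale so the code lies in [-1, 1]."""
    # In PyTorch the rounding would be wrapped in a straight-through estimator,
    # e.g. z + (torch.round(z) - z).detach(); here we only compute the forward value.
    q = round(bound_scalar(z, levels))
    return q / (levels // 2)
```

Sweeping z across a wide range yields exactly L distinct outputs: for L = 8 these are -1, -0.75, ..., 0.75 (the grid is asymmetric for even L), and for L = 5 they are -1, -0.5, 0, 0.5, 1.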
The total codebook size is implicitly determined as the product of all per-dimension levels. For example, levels=[8, 6, 5] yields an implicit codebook of 8 * 6 * 5 = 240 entries. The module supports optional linear projections when the input feature dimension differs from the codebook dimension, multiple codebooks via the num_codebooks parameter, and automatic reshaping for image and video tensors (4D and 5D inputs).
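The size and projection rules in the paragraph above reduce to simple arithmetic. The sketch below uses a hypothetical helper, fsq_layout, that derives these quantities from constructor-style arguments; it only restates the rules described here (dim defaulting to len(levels) * num_codebooks, projections added when they differ) and is not part of the repository.

```python
from math import prod
from typing import Dict, List, Optional


def fsq_layout(
    levels: List[int], dim: Optional[int] = None, num_codebooks: int = 1
) -> Dict[str, object]:
    """Hypothetical helper: derive FSQ sizes from constructor-style arguments."""
    codebook_dim = len(levels)  # one quantized scalar per entry in `levels`
    effective_codebook_dim = codebook_dim * num_codebooks
    dim = effective_codebook_dim if dim is None else dim  # default input width
    return {
        "codebook_size": prod(levels),  # implicit size of each codebook
        "codebook_dim": codebook_dim,
        "effective_codebook_dim": effective_codebook_dim,
        "dim": dim,
        # Linear input/output projections are needed only when the widths differ
        "has_projections": dim != effective_codebook_dim,
    }


print(fsq_layout([8, 6, 5])["codebook_size"])  # 240
```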
Index computation is performed via a cumulative-product basis, enabling efficient bidirectional conversion between quantized codes and flat codebook indices through the codes_to_indices and indices_to_codes methods.
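The cumulative-product basis amounts to a mixed-radix number system: each dimension is one digit whose radix is its level count. The pure-Python sketch below shows the round trip for a single code vector. It assumes that normalized codes are mapped to digits in [0, L) via zhat * (L // 2) + L // 2, which inverts the division by L // 2 in quantization; the names make_basis, codes_to_index, and index_to_codes are hypothetical and operate on plain lists rather than tensors.

```python
from itertools import accumulate
from operator import mul
from typing import List


def make_basis(levels: List[int]) -> List[int]:
    """Cumulative-product basis [1, L0, L0*L1, ...]: mixed-radix place values."""
    return list(accumulate([1] + levels[:-1], mul))


def codes_to_index(zhat: List[float], levels: List[int]) -> int:
    """Map one normalized code vector (values in [-1, 1]) to a flat index."""
    basis = make_basis(levels)
    # Undo the [-1, 1] normalization to get integer digits in [0, L)
    digits = [round(z * (l // 2) + l // 2) for z, l in zip(zhat, levels)]
    return sum(d * b for d, b in zip(digits, basis))


def index_to_codes(index: int, levels: List[int]) -> List[float]:
    """Inverse mapping: flat index -> per-dimension digits -> normalized codes."""
    basis = make_basis(levels)
    digits = [(index // b) % l for b, l in zip(basis, levels)]
    return [(d - l // 2) / (l // 2) for d, l in zip(digits, levels)]


levels = [8, 6, 5]
print(make_basis(levels))  # [1, 8, 48]
```

Because every index in [0, prod(levels)) decodes to a unique digit vector, the two functions are exact inverses over the whole implicit codebook.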
Usage
Use FSQ when you need a discrete bottleneck for a variational autoencoder but want to avoid codebook collapse, auxiliary commitment losses, and the complexity of learned codebook management. It is particularly suited for video autoencoders where simple, stable quantization is preferred.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: sat/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py
Signature
class FSQ(Module):
    def __init__(
        self,
        levels: List[int],
        dim: Optional[int] = None,
        num_codebooks=1,
        keep_num_codebooks_dim: Optional[bool] = None,
        scale: Optional[float] = None,
    ):
Import
from sat.sgm.modules.autoencoding.regularizers.finite_scalar_quantization import FSQ
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| levels | List[int] | Yes | Number of discrete quantization levels per latent dimension. The product of all levels defines the implicit codebook size. |
| dim | Optional[int] | No | Input feature dimension. Defaults to len(levels) * num_codebooks. When it differs from that effective codebook dimension, linear input and output projections are added. |
| num_codebooks | int | No | Number of independent codebooks to use (default: 1). |
| keep_num_codebooks_dim | Optional[bool] | No | Whether to keep a codebook axis in the output indices. Defaults to True when num_codebooks > 1 and False otherwise. |
| scale | Optional[float] | No | Optional scaling factor for the quantization. |
Forward Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| z | Tensor | Yes | Input tensor of shape (batch, dim, ...) for images/video or (batch, seq, dim) for sequences. |
Outputs
| Name | Type | Description |
|---|---|---|
| out | Tensor | Quantized output tensor with the same shape as the input z. |
| indices | Tensor | Flat codebook indices corresponding to the quantized codes. |
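The shape rules in the tables above can be summarized with a shape-only sketch. It is an assumption-based reading of this page (for 4D/5D inputs the channel axis is moved last and the spatial or temporal axes are flattened, then restored; indices drop the channel axis and gain a trailing codebook axis only when keep_num_codebooks_dim is True), not code from the repository. fsq_output_shapes is a hypothetical name, and no tensors are created.

```python
from typing import Optional, Sequence, Tuple


def fsq_output_shapes(
    z_shape: Sequence[int],
    num_codebooks: int = 1,
    keep_num_codebooks_dim: Optional[bool] = None,
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical helper: predict (out.shape, indices.shape) for an FSQ forward pass."""
    if keep_num_codebooks_dim is None:
        keep_num_codebooks_dim = num_codebooks > 1
    if num_codebooks > 1 and not keep_num_codebooks_dim:
        # With several codebooks the codebook axis cannot be squeezed away
        raise ValueError("keep_num_codebooks_dim must be True when num_codebooks > 1")
    out_shape = tuple(z_shape)  # the quantized output matches the input shape
    if len(z_shape) >= 4:
        # Image/video (batch, dim, *spatial): one index per spatial position
        positions = (z_shape[0], *z_shape[2:])
    else:
        # Sequence (batch, seq, dim): one index per sequence element
        positions = tuple(z_shape[:-1])
    if keep_num_codebooks_dim:
        return out_shape, positions + (num_codebooks,)
    return out_shape, positions


print(fsq_output_shapes((2, 5, 16, 16)))  # ((2, 5, 16, 16), (2, 16, 16))
```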
Key Methods
bound(z, eps=1e-3)
Squashes each dimension with tanh into a range that rounds to exactly that dimension's level count; a per-dimension offset keeps the rounding grid correct for both even and odd level counts.
quantize(z)
Applies bound, rounds the result with a straight-through estimator, then divides by levels // 2 so that each code lies in [-1, 1].
codes_to_indices(zhat)
Converts quantized code vectors to flat integer indices using the cumulative-product basis.
indices_to_codes(indices, project_out=True)
Converts flat codebook indices back to quantized code vectors, optionally applying the output projection.
Usage Examples
import torch
from sat.sgm.modules.autoencoding.regularizers.finite_scalar_quantization import FSQ
# Create FSQ with 5 latent dimensions: 8 levels for the first, 5 for each of the others
# Implicit codebook size = 8 * 5 * 5 * 5 * 5 = 5000
fsq = FSQ(levels=[8, 5, 5, 5, 5])
# Quantize a batch of image feature maps (B, C, H, W); C must equal dim (here len(levels) = 5)
z = torch.randn(2, 5, 16, 16)
quantized, indices = fsq(z)
# quantized.shape: (2, 5, 16, 16)
# indices.shape: (2, 16, 16)
# Convert indices back to codes
reconstructed = fsq.indices_to_codes(indices)