Implementation:Zai org CogVideo LatentLPIPS
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Perceptual_Loss |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Implements a composite loss module that combines latent-space L2 loss with LPIPS perceptual loss computed in decoded image space, enabling training in latent space while optimizing for perceptual quality.
Description
The LatentLPIPS class bridges latent-space training objectives with pixel-space perceptual quality metrics. It computes a weighted combination of:
- Latent L2 loss: Direct mean squared error between predicted and target latent representations, weighted by latent_weight.
- Decoded perceptual loss: The module maintains a frozen decoder that converts latent predictions and targets back to image space. LPIPS perceptual distance is computed between the decoded images, weighted by perceptual_weight.
- Input perceptual loss (optional): When perceptual_weight_on_inputs > 0, an additional LPIPS term compares decoded predictions against the original input images. This supports scenarios where the input image resolution differs from the reconstruction resolution, with configurable bicubic rescaling in either direction (scale_input_to_tgt_size or scale_tgt_to_input_size).
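As a rough illustration of how the three terms combine, the sketch below operates on already-reduced scalar loss values rather than tensors. The helper name combine_latent_lpips_terms and its arguments are hypothetical and not part of the repository; only the weight names and log keys are taken from this page.

```python
from typing import Dict, Optional, Tuple


def combine_latent_lpips_terms(
    latent_l2: float,
    perceptual: float,
    perceptual_on_inputs: Optional[float] = None,
    latent_weight: float = 1.0,
    perceptual_weight: float = 1.0,
    perceptual_weight_on_inputs: float = 0.0,
    split: str = "train",
) -> Tuple[float, Dict[str, float]]:
    """Hypothetical helper: weighted sum of pre-computed scalar loss terms."""
    total = latent_weight * latent_l2 + perceptual_weight * perceptual
    log = {
        f"{split}/latent_l2_loss": latent_l2,
        f"{split}/perceptual_loss": perceptual,
    }
    # The input-space term only contributes when its weight is positive.
    if perceptual_weight_on_inputs > 0.0:
        if perceptual_on_inputs is None:
            raise ValueError("perceptual_on_inputs is required when its weight is > 0")
        total += perceptual_weight_on_inputs * perceptual_on_inputs
        log[f"{split}/perceptual_loss_on_inputs"] = perceptual_on_inputs
    return total, log


# Illustrative values only; real terms would come from the loss module.
total, log = combine_latent_lpips_terms(
    latent_l2=0.5,
    perceptual=0.25,
    perceptual_on_inputs=1.0,
    latent_weight=0.5,
    perceptual_weight=1.0,
    perceptual_weight_on_inputs=0.5,
)
```

Keeping the per-term values in the log dictionary, separate from the weighted total, makes it possible to see which term dominates during training regardless of the chosen weights.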
The decoder is initialized from a config and has its encoder component removed to save memory, since only decoding is needed for loss computation.
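The encoder removal can be pictured with a toy object. This is a minimal sketch assuming the instantiated model exposes its encoder as an attribute named encoder; ToyAutoencoder and strip_encoder are hypothetical stand-ins, not the repository's classes.

```python
class ToyAutoencoder:
    """Stand-in for an autoencoder with separate encoder and decoder parts."""

    def __init__(self):
        self.encoder = object()  # placeholder for encoder weights
        self.decoder = object()  # placeholder for decoder weights


def strip_encoder(model):
    """Drop the encoder attribute, if present, keeping only what decoding needs."""
    if hasattr(model, "encoder"):
        del model.encoder
    return model


model = strip_encoder(ToyAutoencoder())
```

Guarding with hasattr keeps the helper safe to call on decoder-only objects that never had an encoder in the first place.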
Usage
Used as a training loss for latent diffusion or two-stage autoencoder models where the primary training occurs in latent space but perceptual quality in pixel space must also be optimized. Particularly useful when fine-tuning a latent-space model to better preserve perceptual features visible in decoded outputs.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: sat/sgm/modules/autoencoding/losses/lpips.py
Signature
class LatentLPIPS(nn.Module):
    def __init__(
        self,
        decoder_config,
        perceptual_weight=1.0,
        latent_weight=1.0,
        scale_input_to_tgt_size=False,
        scale_tgt_to_input_size=False,
        perceptual_weight_on_inputs=0.0,
    ): ...

    def init_decoder(self, config): ...

    def forward(
        self,
        latent_inputs,
        latent_predictions,
        image_inputs,
        split="train",
    ) -> tuple[torch.Tensor, dict]: ...
Import
from sat.sgm.modules.autoencoding.losses.lpips import LatentLPIPS
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| latent_inputs | torch.Tensor | Yes | Target latent representations [B, C, H, W] |
| latent_predictions | torch.Tensor | Yes | Predicted latent representations [B, C, H, W] |
| image_inputs | torch.Tensor | Yes | Original input images [B, 3, H', W'] (used when perceptual_weight_on_inputs > 0) |
| split | str | No | Log key prefix for loss tracking. Default: "train" |
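Because image_inputs may have a different spatial size [H', W'] than the decoded images, the two rescaling flags decide which side is resized before the input-space LPIPS term is computed. The sketch below only tracks sizes and performs no interpolation. The helper resolve_lpips_sizes is hypothetical, and it assumes that scale_input_to_tgt_size takes priority if both flags are set, which this page does not specify.

```python
from typing import Tuple

Size = Tuple[int, int]  # (height, width)


def resolve_lpips_sizes(
    input_hw: Size,
    recon_hw: Size,
    scale_input_to_tgt_size: bool = False,
    scale_tgt_to_input_size: bool = False,
) -> Tuple[Size, Size]:
    """Hypothetical helper: sizes of (inputs, reconstructions) after rescaling."""
    if scale_input_to_tgt_size:
        # Inputs are resized to match the decoded reconstructions.
        input_hw = recon_hw
    elif scale_tgt_to_input_size:
        # Reconstructions are resized to match the original inputs.
        recon_hw = input_hw
    return input_hw, recon_hw


# 512x512 inputs compared against 256x256 decoded reconstructions.
sizes = resolve_lpips_sizes((512, 512), (256, 256), scale_tgt_to_input_size=True)
```

With neither flag set, the sizes are returned unchanged, so in that case the inputs and reconstructions would already need to share a resolution for a pixelwise comparison to make sense.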
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.Tensor | Scalar weighted combination of latent L2 and perceptual losses |
| log | dict | Dictionary containing detached loss components: {split}/latent_l2_loss, {split}/perceptual_loss, and optionally {split}/perceptual_loss_on_inputs |
Usage Examples
from sat.sgm.modules.autoencoding.losses.lpips import LatentLPIPS

# Initialize with decoder config and loss weights
loss_fn = LatentLPIPS(
    decoder_config={"target": "my_decoder.Decoder", "params": {...}},
    perceptual_weight=1.0,
    latent_weight=0.5,
    perceptual_weight_on_inputs=0.1,
    scale_tgt_to_input_size=True,
)

# Compute composite loss.
# z_target and z_pred are latent tensors of shape [B, C, H, W];
# original_images is an image tensor of shape [B, 3, H', W'].
loss, log_dict = loss_fn(
    latent_inputs=z_target,
    latent_predictions=z_pred,
    image_inputs=original_images,
    split="train",
)
loss.backward()