Implementation:Zai org CogVideo RIFE Loss Functions
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Loss_Functions, Optical_Flow |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
A collection of loss function modules used for training RIFE frame interpolation models, including endpoint error (EPE), ternary census loss, Sobel edge loss, and VGG perceptual loss.
Description
This module defines five loss-related classes used in the RIFE training pipeline:
- EPE (Endpoint Error): Computes the per-pixel L2 distance between predicted and ground-truth optical flow. The loss map is the square root of the sum of squared per-channel differences, multiplied by a binary mask that selects the region to supervise.
- Ternary: Implements the census transform loss for structural comparison. It converts each image to grayscale, extracts 7x7 local patches via convolution with an identity kernel, and normalizes each patch's differences from the center pixel with transf / sqrt(0.81 + transf^2). It then measures a robust (soft) Hamming distance between the transformed patches of the two images. A valid mask zeroes out border pixels.
- SOBEL: Applies horizontal and vertical Sobel kernels to both prediction and ground truth and computes L1 differences on the resulting gradient maps. The kernel [[1,0,-1],[2,0,-2],[1,0,-1]] measures the horizontal gradient (responding to vertical edges), and its transpose measures the vertical gradient (responding to horizontal edges).
- MeanShift: A 1x1 convolution layer that normalizes or denormalizes images using ImageNet mean and standard deviation statistics. Used as a preprocessing step for VGG feature extraction.
- VGGPerceptualLoss: Extracts features from 5 layers of a pretrained VGG-19 network (indices 2, 7, 12, 21, 30) and computes weighted L1 differences. Layer weights [1/2.6, 1/4.8, 1/3.7, 1/5.6, 10/1.5] balance contributions from early texture features and deeper semantic features.
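The normalization that MeanShift performs can be illustrated in plain Python. This is a minimal sketch rather than the module's code: it assumes the commonly used ImageNet statistics and relies on the fact that a 1x1 convolution with a diagonal weight reduces to a per-channel scale and bias. The helper names `meanshift_params` and `apply_per_channel` are hypothetical.

```python
# Per-channel scale and bias equivalent to a diagonal 1x1 convolution.
# The ImageNet statistics below are the commonly used values (an assumption here).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def meanshift_params(mean, std, data_range=1.0, norm=True):
    """Return (scale, bias) lists so that y[c] = scale[c] * x[c] + bias[c]."""
    if norm:
        # Normalize: y = (x - data_range * mean) / std
        scale = [1.0 / s for s in std]
        bias = [-data_range * m / s for m, s in zip(mean, std)]
    else:
        # Denormalize: y = x * std + data_range * mean
        scale = list(std)
        bias = [data_range * m for m in mean]
    return scale, bias


def apply_per_channel(pixel, scale, bias):
    """Apply the scale and bias to one pixel given as a list of channel values."""
    return [s * x + b for x, s, b in zip(pixel, scale, bias)]


scale, bias = meanshift_params(IMAGENET_MEAN, IMAGENET_STD, norm=True)
normalized = apply_per_channel(IMAGENET_MEAN, scale, bias)  # the mean pixel maps to ~0

inv_scale, inv_bias = meanshift_params(IMAGENET_MEAN, IMAGENET_STD, norm=False)
restored = apply_per_channel(normalized, inv_scale, inv_bias)  # round trip back to the input
```

With norm=True the layer prepares images for VGG feature extraction; with norm=False the same parameters undo that step.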
Usage
Use these loss functions when training RIFE models. EPE supervises flow accuracy, Ternary and SOBEL enforce structural and edge consistency, and VGGPerceptualLoss encourages perceptual similarity to the target. They are instantiated by the Model class in RIFE.py.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: inference/gradio_composite_demo/rife/loss.py
Signature
class EPE(nn.Module):
    def __init__(self)
    def forward(self, flow: torch.Tensor, gt: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor
class Ternary(nn.Module):
    def __init__(self)
    def transform(self, img: torch.Tensor) -> torch.Tensor
    def rgb2gray(self, rgb: torch.Tensor) -> torch.Tensor
    def hamming(self, t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor
    def valid_mask(self, t: torch.Tensor, padding: int) -> torch.Tensor
    def forward(self, img0: torch.Tensor, img1: torch.Tensor) -> torch.Tensor
class SOBEL(nn.Module):
    def __init__(self)
    def forward(self, pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor
class MeanShift(nn.Conv2d):
    def __init__(self, data_mean: list, data_std: list, data_range=1, norm=True)
class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, rank=0)
    def forward(self, X: torch.Tensor, Y: torch.Tensor, indices=None) -> torch.Tensor
Import
from inference.gradio_composite_demo.rife.loss import EPE, Ternary, SOBEL, VGGPerceptualLoss, MeanShift
I/O Contract
Inputs
EPE.forward:
| Name | Type | Required | Description |
|---|---|---|---|
| flow | torch.Tensor | Yes | Predicted optical flow tensor |
| gt | torch.Tensor | Yes | Ground truth optical flow tensor (detached internally) |
| loss_mask | torch.Tensor | Yes | Binary mask indicating valid flow regions |
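To make the EPE arithmetic concrete, here is a minimal plain-Python sketch for a single pixel. It is an illustration only: the helper name `epe_pixel` is hypothetical, and the small `eps` under the square root is an assumption of this sketch, not something this page specifies.

```python
import math


def epe_pixel(flow, gt, mask, eps=1e-6):
    """Masked endpoint error for one pixel.

    flow, gt: per-channel flow values, e.g. [dx, dy].
    mask: 1.0 where the pixel is supervised, 0.0 where it is ignored.
    eps: small constant that keeps the square root well behaved at zero
         (an assumption of this sketch).
    """
    squared = sum((f - g) ** 2 for f, g in zip(flow, gt))
    return math.sqrt(squared + eps) * mask


valid = epe_pixel([3.0, 4.0], [0.0, 0.0], mask=1.0)    # close to 5.0
ignored = epe_pixel([3.0, 4.0], [0.0, 0.0], mask=0.0)  # exactly 0.0
```

Because the mask multiplies the error, masked-out pixels contribute nothing to the loss map.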
Ternary.forward:
| Name | Type | Required | Description |
|---|---|---|---|
| img0 | torch.Tensor | Yes | First RGB image tensor (B, 3, H, W) |
| img1 | torch.Tensor | Yes | Second RGB image tensor (B, 3, H, W) |
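The census comparison can be sketched in plain Python on a toy patch. This is a hedged illustration, not the module's code: the helper names are hypothetical, the grayscale conversion is omitted, a 3-value list stands in for the 49 values of a 7x7 window, and the `eps=0.1` in the distance is an assumption of this sketch. Only the 0.81 constant comes from the description above.

```python
import math


def census_signature(patch, center):
    """Soft ternary signature: each neighbor difference squashed into (-1, 1)."""
    return [(p - center) / math.sqrt(0.81 + (p - center) ** 2) for p in patch]


def soft_hamming(sig_a, sig_b, eps=0.1):
    """Robust distance between two signatures; eps=0.1 is an assumption of this sketch."""
    dists = [(a - b) ** 2 for a, b in zip(sig_a, sig_b)]
    return sum(d / (eps + d) for d in dists) / len(dists)


# A 3-value "patch" stands in for the 49 values of a real 7x7 window.
patch_a = [0.2, 0.5, 0.9]
patch_b = [v + 0.3 for v in patch_a]  # same structure, uniformly brighter

sig_a = census_signature(patch_a, center=0.5)
sig_b = census_signature(patch_b, center=0.8)
same_structure = soft_hamming(sig_a, sig_b)  # near zero: the brightness offset cancels

sig_flipped = census_signature([0.9, 0.5, 0.2], center=0.5)
different = soft_hamming(sig_a, sig_flipped)  # clearly positive: the structure is reversed
```

The signature depends only on differences from the center pixel, which is why a uniform brightness change leaves the distance near zero.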
SOBEL.forward:
| Name | Type | Required | Description |
|---|---|---|---|
| pred | torch.Tensor | Yes | Predicted image tensor (B, C, H, W) |
| gt | torch.Tensor | Yes | Ground truth image tensor (B, C, H, W) |
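The edge comparison can be sketched in plain Python on a tiny single-channel image. This is an illustration under assumptions, not the module's code: the helper names are hypothetical, the filter is applied as a cross-correlation, and zero padding of one pixel is assumed so the output keeps the input size.

```python
KERNEL_X = [[1, 0, -1], [2, 0, -2], [1, 0, -1]]
KERNEL_Y = [list(row) for row in zip(*KERNEL_X)]  # transpose of KERNEL_X


def correlate3x3(image, kernel):
    """Cross-correlate a 2D list with a 3x3 kernel using zero padding of 1."""
    h, w = len(image), len(image[0])
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            acc = 0.0
            for ky in range(3):
                for kx in range(3):
                    yy, xx = y + ky - 1, x + kx - 1
                    if 0 <= yy < h and 0 <= xx < w:
                        acc += kernel[ky][kx] * image[yy][xx]
            out[y][x] = acc
    return out


def sobel_l1_map(pred, gt):
    """Per-pixel |Gx(pred) - Gx(gt)| + |Gy(pred) - Gy(gt)|."""
    px, gx = correlate3x3(pred, KERNEL_X), correlate3x3(gt, KERNEL_X)
    py, gy = correlate3x3(pred, KERNEL_Y), correlate3x3(gt, KERNEL_Y)
    h, w = len(pred), len(pred[0])
    return [
        [abs(px[y][x] - gx[y][x]) + abs(py[y][x] - gy[y][x]) for x in range(w)]
        for y in range(h)
    ]


# A vertical step edge: dark left half, bright right half.
step = [[0.0, 0.0, 1.0, 1.0] for _ in range(4)]
flat = [[0.0] * 4 for _ in range(4)]

identical = sobel_l1_map(step, step)     # all zeros
edge_vs_flat = sobel_l1_map(step, flat)  # large next to the step
```

Comparing the step image with a flat image produces a large value next to the step, while identical inputs give a map of zeros.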
VGGPerceptualLoss.forward:
| Name | Type | Required | Description |
|---|---|---|---|
| X | torch.Tensor | Yes | Predicted image tensor (B, 3, H, W) |
| Y | torch.Tensor | Yes | Target image tensor (B, 3, H, W) |
| indices | list | No | Ignored in practice: the method always uses the VGG layer indices [2, 7, 12, 21, 30], regardless of the value passed |
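The weighted combination of per-layer feature differences can be sketched without a real network. In this plain-Python illustration, short lists stand in for VGG-19 feature maps, the helper names are hypothetical, and any global scale factor the implementation may apply is left out. Only the five layer weights come from the description above.

```python
LAYER_WEIGHTS = [1 / 2.6, 1 / 4.8, 1 / 3.7, 1 / 5.6, 10 / 1.5]


def mean_abs_diff(a, b):
    """Mean absolute difference between two flat feature lists."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def weighted_feature_l1(feats_x, feats_y, weights=LAYER_WEIGHTS):
    """Weighted sum of per-layer L1 feature distances.

    feats_x, feats_y: one flat list of activations per tapped layer
    (stand-ins for real VGG-19 feature maps).
    """
    return sum(
        w * mean_abs_diff(fx, fy) for w, fx, fy in zip(weights, feats_x, feats_y)
    )


# Toy activations: every layer differs by exactly 1.0 on average.
feats_pred = [[1.0, 2.0]] * 5
feats_target = [[0.0, 1.0]] * 5

loss = weighted_feature_l1(feats_pred, feats_target)  # equals sum(LAYER_WEIGHTS)
zero = weighted_feature_l1(feats_pred, feats_pred)    # 0.0
```

Note that the last weight (10/1.5) is larger than the other four combined, so equal feature differences at every layer are dominated by the deepest one.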
Outputs
| Name | Type | Description |
|---|---|---|
| EPE output | torch.Tensor | Per-pixel L2 flow error map multiplied by loss mask |
| Ternary output | torch.Tensor | Census transform distance map with border mask, shape (B, 1, H, W) |
| SOBEL output | torch.Tensor | Sum of horizontal and vertical edge L1 differences, shape (B*C, 1, H, W) |
| VGGPerceptualLoss output | torch.Tensor | Scalar weighted sum of L1 feature differences across 5 VGG-19 layers |
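Three of these outputs are per-pixel maps, so a training loop still has to reduce them to a scalar. The plain-Python sketch below illustrates a border mask like the one applied to the Ternary output, followed by a mean reduction. The helper `masked_mean`, the padding width, and the choice of a plain mean over all pixels are assumptions for illustration.

```python
def valid_mask(height, width, padding):
    """2D mask: 1.0 in the interior, 0.0 in a border `padding` pixels wide."""
    return [
        [
            1.0 if padding <= y < height - padding and padding <= x < width - padding else 0.0
            for x in range(width)
        ]
        for y in range(height)
    ]


def masked_mean(loss_map, mask):
    """Average the loss over all pixels after zeroing the masked-out border."""
    h, w = len(loss_map), len(loss_map[0])
    total = sum(loss_map[y][x] * mask[y][x] for y in range(h) for x in range(w))
    return total / (h * w)


mask = valid_mask(4, 4, padding=1)        # 2x2 block of ones inside a zero border
loss_map = [[1.0] * 4 for _ in range(4)]  # constant per-pixel loss
scalar = masked_mean(loss_map, mask)      # 4 surviving pixels out of 16
```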
Usage Examples
import torch
from inference.gradio_composite_demo.rife.loss import EPE, Ternary, SOBEL, VGGPerceptualLoss

# These examples assume a CUDA device is available.
device = torch.device("cuda")

# Endpoint Error for flow supervision; the output is a per-pixel map, so reduce it to a scalar
epe = EPE()
pred_flow = torch.randn(2, 2, 256, 256, device=device)
gt_flow = torch.randn(2, 2, 256, 256, device=device)
mask = torch.ones(2, 1, 256, 256, device=device)
flow_loss = epe(pred_flow, gt_flow, mask).mean()

# Ternary census loss for structural similarity (RGB images with values in [0, 1])
ternary = Ternary()
img0 = torch.rand(2, 3, 256, 256, device=device)
img1 = torch.rand(2, 3, 256, 256, device=device)
census_loss = ternary(img0, img1).mean()

# Sobel edge loss; also a per-pixel map
sobel = SOBEL()
edge_loss = sobel(img0, img1).mean()

# VGG perceptual loss; already a scalar. Move the module to the same device as the inputs.
vgg_loss_fn = VGGPerceptualLoss(rank=0).to(device)
perceptual_loss = vgg_loss_fn(img0, img1)