Implementation:Zai org CogVideo RIFE Model
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Frame_Interpolation, Optical_Flow |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
The RIFE Model class wraps IFNet in a high-level interface for Real-Time Intermediate Flow Estimation (RIFE) frame interpolation. It covers training, inference, evaluation, and checkpoint management.
Description
The Model class orchestrates the RIFE frame interpolation pipeline. On construction it initializes the flow estimation network (IFNet for fixed-timestep interpolation, or IFNet_m for arbitrary-timestep interpolation), an AdamW optimizer with a weight decay of 1e-3, and three loss modules: EPE (endpoint error for flow supervision), LapLoss (Laplacian pyramid loss for multi-scale reconstruction), and SOBEL (edge-aware loss). When local_rank is not -1, the class supports distributed training through PyTorch's DistributedDataParallel (DDP).
The inference method concatenates two input frames, runs them through the flow network at configurable scales, and returns the merged interpolated frame. It supports optional test-time augmentation (TTA) by running inference on horizontally and vertically flipped inputs and averaging the results with the unflipped output.
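The flip-and-average idea behind TTA can be sketched as follows. This is a minimal illustration, not the repository's code: `infer_with_tta` and `blend` are hypothetical names, `blend` is a toy stand-in for the flow network, and the sketch assumes a single extra pass with both spatial axes flipped at once.

```python
import torch


def infer_with_tta(flow_fn, img0, img1):
    """Average a plain pass with a flipped pass (hypothetical helper).

    flow_fn stands in for the flow network: it takes a (B, 6, H, W)
    tensor and returns a (B, 3, H, W) interpolated frame.
    """
    imgs = torch.cat((img0, img1), dim=1)  # (B, 6, H, W)
    plain = flow_fn(imgs)
    # Flip height (dim 2) and width (dim 3), infer, then undo the flip
    # so the result lines up with the unflipped output.
    flipped = flow_fn(imgs.flip(2).flip(3)).flip(2).flip(3)
    return (plain + flipped) / 2


def blend(imgs):
    # Toy stand-in for the network: averages the two stacked frames.
    return (imgs[:, :3] + imgs[:, 3:]) / 2


img0 = torch.rand(1, 3, 32, 32)
img1 = torch.rand(1, 3, 32, 32)
out = infer_with_tta(blend, img0, img1)
```

Because the flipped result is flipped back before averaging, the two passes are spatially aligned. TTA therefore costs roughly one extra forward pass per interpolated frame.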
The update method performs one training step. It passes the concatenated input frames and the ground truth through the flow network, then computes a combined loss: the Laplacian pyramid loss on the student output, plus the Laplacian pyramid loss on the teacher output, plus the distillation loss weighted by 0.01. When training is True, it backpropagates this loss and steps the optimizer. Checkpoint loading strips the "module." prefix that DDP adds to parameter names, using a convert helper.
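DDP stores the wrapped network under a `module` attribute, so parameter names saved from a DDP model carry a `module.` prefix that a plain model does not expect. The sketch below shows one way to strip it. `strip_ddp_prefix` is a hypothetical name and is not the repository's convert helper, whose exact filtering behavior may differ.

```python
def strip_ddp_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DDP prefix removed.

    Keys without the prefix are kept unchanged (an assumption of this
    sketch, not necessarily what the repository's helper does).
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Plain integers stand in for weight tensors.
saved = {"module.block0.weight": 1, "module.block0.bias": 2}
cleaned = strip_ddp_prefix(saved)
```

The cleaned dictionary can then be passed to `load_state_dict` on a network that is not wrapped in DDP.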
Usage
Use the RIFE Model as the main entry point for frame interpolation tasks, both for training new models and running inference. It is the primary model class used by the Gradio composite demo.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: inference/gradio_composite_demo/rife/RIFE.py
Signature
class Model:
    def __init__(self, local_rank=-1, arbitrary=False)
    def train(self)
    def eval(self)
    def device(self)
    def load_model(self, path, rank=0)
    def save_model(self, path, rank=0)
    def inference(self, img0, img1, scale=1, scale_list=[4, 2, 1], TTA=False, timestep=0.5)
    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None)
Import
from inference.gradio_composite_demo.rife.RIFE import Model
I/O Contract
Inputs
__init__:
| Name | Type | Required | Description |
|---|---|---|---|
| local_rank | int | No | GPU rank for DDP, default -1 (no DDP) |
| arbitrary | bool | No | If True, use IFNet_m for arbitrary timestep interpolation; default False |
inference:
| Name | Type | Required | Description |
|---|---|---|---|
| img0 | torch.Tensor | Yes | First input frame, shape (B, 3, H, W) |
| img1 | torch.Tensor | Yes | Second input frame, shape (B, 3, H, W) |
| scale | int | No | Global scale divisor applied to scale_list, default 1 |
| scale_list | list[int] | No | Multi-scale factors for flow estimation, default [4, 2, 1] |
| TTA | bool | No | Enable test-time augmentation via flipping, default False |
| timestep | float | No | Interpolation position between frames, default 0.5 |
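The interaction between scale and scale_list can be illustrated with a small sketch. It assumes that "divisor" means each entry of scale_list is divided by scale. `effective_scales` is a hypothetical helper, not part of the Model class.

```python
def effective_scales(scale=1, scale_list=(4, 2, 1)):
    """Divide every multi-scale factor by the global scale (assumed rule).

    A tuple default and a fresh result list keep the caller's list and
    the default value from being modified in place.
    """
    return [s / scale for s in scale_list]


default_scales = effective_scales()        # scale=1 leaves the factors as-is
halved_scales = effective_scales(scale=2)  # every factor is divided by 2
```

Building a new list matters because the signature above uses a list as a default argument. In Python, a list default is shared across calls, so code that rescales it in place would change the default for later calls.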
update:
| Name | Type | Required | Description |
|---|---|---|---|
| imgs | torch.Tensor | Yes | Concatenated input frames (B, 6, H, W) |
| gt | torch.Tensor | Yes | Ground truth intermediate frame (B, 3, H, W) |
| learning_rate | float | No | Learning rate for this step, default 0 |
| mul | int | No | Multiplier (unused in current implementation), default 1 |
| training | bool | No | Training mode flag, default True |
| flow_gt | torch.Tensor | No | Ground truth optical flow (unused in current implementation) |
Outputs
inference:
| Name | Type | Description |
|---|---|---|
| result | torch.Tensor | Interpolated frame of shape (B, 3, H, W), values in [0, 1] |
update:
| Name | Type | Description |
|---|---|---|
| merged | torch.Tensor | Interpolated frame at finest scale, shape (B, 3, H, W) |
| info | dict | Dictionary containing: merged_tea, mask, mask_tea, flow, flow_tea, loss_l1, loss_tea, loss_distill |
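The three loss entries in info correspond to the terms of the combined training loss. As a minimal sketch, assuming loss_l1 is the student Laplacian loss and loss_tea is the teacher Laplacian loss (a mapping inferred from the key names), the total can be written as follows. `combine_losses` is a hypothetical helper.

```python
def combine_losses(loss_l1, loss_tea, loss_distill, distill_weight=0.01):
    """Sum the student and teacher losses with a down-weighted distillation term.

    Works on plain floats here; the same arithmetic applies to scalar tensors.
    """
    return loss_l1 + loss_tea + distill_weight * loss_distill


# Illustrative values only, not measured results.
total = combine_losses(0.20, 0.15, 3.0)
```

With a weight of 0.01, the distillation term contributes little to the total unless it is much larger than the two reconstruction losses.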
Usage Examples
import torch
from inference.gradio_composite_demo.rife.RIFE import Model

# Initialize and load pretrained weights
model = Model(local_rank=-1, arbitrary=False)
model.load_model("/path/to/rife_weights/", rank=0)
model.eval()

# Interpolate between two frames (random placeholder frames with values in [0, 1])
img0 = torch.rand(1, 3, 480, 640).cuda()
img1 = torch.rand(1, 3, 480, 640).cuda()
with torch.no_grad():
    result = model.inference(img0, img1, scale=1, TTA=False, timestep=0.5)
# result shape: (1, 3, 480, 640)

# With test-time augmentation for higher quality
with torch.no_grad():
    result_tta = model.inference(img0, img1, TTA=True)