

Implementation:Zai org CogVideo RIFE Refinement

From Leeroopedia
Revision as of 17:09, 16 February 2026 by Admin (auto-imported from implementations/Zai_org_CogVideo_RIFE_Refinement.md)


Knowledge Sources
Domains: Video_Generation, Optical_Flow
Last Updated: 2026-02-10 00:00 GMT

Overview

Defines the Contextnet and Unet refinement modules, which take the warped frames, blending mask, and flow produced by the IFNet optical-flow estimator and predict a residual correction for the interpolated frame.

Description

This module contains three neural network classes and two helper factory functions that together form the refinement stage of the RIFE (Real-Time Intermediate Flow Estimation) pipeline:

  • conv() -- Factory function that returns a sequential block of nn.Conv2d followed by nn.PReLU activation.
  • deconv() -- Factory function that returns a sequential block of nn.ConvTranspose2d followed by nn.PReLU activation, used for upsampling in the decoder.
  • Conv2 -- A double-convolution block with configurable stride. The first convolution applies the stride for downsampling, and the second convolution refines features at the new resolution.
  • Contextnet -- A 4-level feature encoder that progressively downsamples an input image using Conv2 blocks (base channel count c=16). At each scale, the optical flow field is bilinearly downsampled by a factor of 2 and its values are halved, so that displacements stay in pixel units of the coarser grid; the rescaled flow is then used to warp the features via the warp() function. The output is a list of four warped feature maps at decreasing spatial resolutions.
  • Unet -- A 4-level encoder-decoder network. The encoder takes a 17-channel concatenation of the two original frames (3 + 3 channels), the two warped frames (3 + 3), the blending mask (1), and the bi-directional flow field (4). At each encoder stage, the Contextnet features of both source frames at the matching scale are fused in by channel-wise concatenation. The decoder upsamples with transposed convolutions and skip connections from the encoder, and ends in a 3-channel sigmoid-activated residual image.
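The three building blocks in the list above can be sketched as follows. This is a minimal reconstruction from the descriptions and signatures on this page, not a copy of refine.py: the per-channel nn.PReLU(out_planes) and the 3x3 kernels inside Conv2 are assumptions.

```python
import torch
import torch.nn as nn


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    # Conv2d followed by PReLU (one learnable slope per channel: an assumption)
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                  stride=stride, padding=padding, dilation=dilation),
        nn.PReLU(out_planes),
    )


def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    # ConvTranspose2d followed by PReLU; the defaults double H and W
    return nn.Sequential(
        nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size,
                           stride=stride, padding=padding),
        nn.PReLU(out_planes),
    )


class Conv2(nn.Module):
    def __init__(self, in_planes, out_planes, stride=2):
        super().__init__()
        # First conv applies the stride (downsampling), second refines features
        self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
        self.conv2 = conv(out_planes, out_planes, 3, 1, 1)

    def forward(self, x):
        return self.conv2(self.conv1(x))


x = torch.randn(1, 3, 64, 64)
down = Conv2(3, 16)(x)    # spatial size halves: 64 -> 32
up = deconv(16, 8)(down)  # spatial size doubles: 32 -> 64
print(down.shape, up.shape)
```

With these defaults a Conv2 block halves the spatial size and a deconv block restores it, which is what lets the Unet decoder line up with its encoder skip connections.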

Usage

Used internally by the RIFE IFNet and IFNet_m models. The Contextnet extracts multi-scale context features from source frames guided by estimated optical flow, and the Unet combines all inputs to produce a refined interpolated frame that corrects warping artifacts such as occlusions and blending seams.
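Both modules rely on backward warping: each output pixel samples the source at its own position displaced by the flow. The sketch below illustrates that idea with torch.nn.functional.grid_sample. The name backward_warp is hypothetical, and the border padding and align_corners=True settings are assumptions; the warp() function used by this module is defined elsewhere in the repository and may differ in detail.

```python
import torch
import torch.nn.functional as F


def backward_warp(img, flow):
    """Sample img at (x + flow_x, y + flow_y) for every output pixel (x, y)."""
    b, _, h, w = img.shape
    # Identity sampling grid in grid_sample's normalized [-1, 1] coordinates
    xs = torch.linspace(-1.0, 1.0, w, device=img.device, dtype=img.dtype)
    ys = torch.linspace(-1.0, 1.0, h, device=img.device, dtype=img.dtype)
    grid_x = xs.view(1, 1, 1, w).expand(b, 1, h, w)
    grid_y = ys.view(1, 1, h, 1).expand(b, 1, h, w)
    # Convert pixel displacements to normalized units
    flow_x = flow[:, 0:1] * 2.0 / (w - 1)
    flow_y = flow[:, 1:2] * 2.0 / (h - 1)
    grid = torch.cat([grid_x + flow_x, grid_y + flow_y], dim=1)
    grid = grid.permute(0, 2, 3, 1)  # grid_sample expects [B, H, W, 2]
    return F.grid_sample(img, grid, mode="bilinear",
                         padding_mode="border", align_corners=True)


# A horizontal ramp makes the effect of the flow easy to read off
img = torch.arange(8.0).view(1, 1, 1, 8).repeat(1, 3, 4, 1)  # [1, 3, 4, 8]
zero_flow = torch.zeros(1, 2, 4, 8)
shift_flow = torch.zeros(1, 2, 4, 8)
shift_flow[:, 0] = 1.0  # every pixel samples one pixel to its right

same = backward_warp(img, zero_flow)
shifted = backward_warp(img, shift_flow)
```

Zero flow reproduces the input, and a constant horizontal flow of one pixel shifts the ramp by one column, with the last column clamped at the border.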

Code Reference

Source Location

  • Repository: Zai_org_CogVideo
  • File: inference/gradio_composite_demo/rife/refine.py

Signature

# Factory functions
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1)

# Conv2 block
class Conv2(nn.Module):
    def __init__(self, in_planes, out_planes, stride=2)
    def forward(self, x)

# Contextnet encoder
class Contextnet(nn.Module):
    def __init__(self)
    def forward(self, x, flow) -> list  # returns [f1, f2, f3, f4]

# Unet refinement decoder
class Unet(nn.Module):
    def __init__(self)
    def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)

Import

from rife.refine import Contextnet, Unet, Conv2, conv, deconv

I/O Contract

Inputs (Contextnet.forward)

Name | Type | Required | Description
x | torch.Tensor | Yes | Input image tensor of shape [B, 3, H, W]
flow | torch.Tensor | Yes | Optical flow field of shape [B, 2, H, W] at input resolution; Contextnet downsamples and rescales it internally at each level
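The per-level flow handling can be sketched as below. Flow vectors are measured in pixels, so when the grid is downsampled by 2 the vectors are also multiplied by 0.5 to describe the same motion on the coarser grid. The helper name flow_pyramid and the align_corners=False setting are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def flow_pyramid(flow, levels=4):
    """Return the flow at 1/2, 1/4, ... resolution, rescaled at each level."""
    pyramid = []
    for _ in range(levels):
        # Halve the resolution, then halve the displacement magnitudes
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        pyramid.append(flow)
    return pyramid


# A constant 8-pixel displacement at full resolution
flow = torch.full((1, 2, 64, 64), 8.0)
pyr = flow_pyramid(flow)
shapes = [tuple(f.shape) for f in pyr]
values = [round(f.mean().item(), 4) for f in pyr]
print(shapes)  # [(1, 2, 32, 32), (1, 2, 16, 16), (1, 2, 8, 8), (1, 2, 4, 4)]
print(values)  # [4.0, 2.0, 1.0, 0.5]
```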

Outputs (Contextnet.forward)

Name | Type | Description
features | list[torch.Tensor] | Four warped feature maps at scales 1/2, 1/4, 1/8, 1/16 of the input resolution with channel counts [16, 32, 64, 128]
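As a quick sanity check, the expected output shapes follow directly from the scales and channel counts listed above. The helper below is a hypothetical illustration in plain Python and assumes H and W are multiples of 16, so that each halving is exact.

```python
def contextnet_output_shapes(batch, height, width, base_channels=16, levels=4):
    """Expected [B, C, H, W] shapes of the four Contextnet feature maps."""
    shapes = []
    for level in range(levels):
        scale = 2 ** (level + 1)                # 2, 4, 8, 16
        channels = base_channels * 2 ** level   # 16, 32, 64, 128
        shapes.append((batch, channels, height // scale, width // scale))
    return shapes


shapes = contextnet_output_shapes(1, 256, 448)
for s in shapes:
    print(s)
# (1, 16, 128, 224)
# (1, 32, 64, 112)
# (1, 64, 32, 56)
# (1, 128, 16, 28)
```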

Inputs (Unet.forward)

Name | Type | Required | Description
img0 | torch.Tensor | Yes | First source image [B, 3, H, W]
img1 | torch.Tensor | Yes | Second source image [B, 3, H, W]
warped_img0 | torch.Tensor | Yes | First image warped by estimated flow [B, 3, H, W]
warped_img1 | torch.Tensor | Yes | Second image warped by estimated flow [B, 3, H, W]
mask | torch.Tensor | Yes | Blending mask [B, 1, H, W]
flow | torch.Tensor | Yes | Concatenated bi-directional flow [B, 4, H, W]
c0 | list[torch.Tensor] | Yes | Contextnet features for frame 0 (4 scales)
c1 | list[torch.Tensor] | Yes | Contextnet features for frame 1 (4 scales)
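The six tensor inputs add up to the 17 channels that the Unet encoder expects (3 + 3 + 3 + 3 + 1 + 4). A minimal sketch with random placeholder tensors, assuming the concatenation order follows the argument order of Unet.forward:

```python
import torch

B, H, W = 1, 64, 64
img0 = torch.rand(B, 3, H, W)
img1 = torch.rand(B, 3, H, W)
warped_img0 = torch.rand(B, 3, H, W)
warped_img1 = torch.rand(B, 3, H, W)
mask = torch.rand(B, 1, H, W)
flow = torch.zeros(B, 4, H, W)

# Channel-wise concatenation of every full-resolution input
unet_input = torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), dim=1)
print(unet_input.shape)  # torch.Size([1, 17, 64, 64])
```

All six tensors must share the same batch size and spatial size for the concatenation to succeed.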

Outputs (Unet.forward)

Name | Type | Description
residual | torch.Tensor | Sigmoid-activated residual image [B, 3, H, W] in range [0, 1]
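This page does not describe how the caller turns the [0, 1] output into a final frame. One plausible pattern, assumed here purely for illustration and not confirmed by this file, is to re-centre the output to [-1, 1], add it to the mask-blended warped frames, and clamp the result. The function name apply_residual is hypothetical.

```python
import torch


def apply_residual(warped_img0, warped_img1, mask, residual):
    # Blend the two warped frames with the mask (assumed convention)
    blended = warped_img0 * mask + warped_img1 * (1 - mask)
    # Map the sigmoid output from [0, 1] to a signed correction in [-1, 1]
    correction = residual * 2 - 1
    return torch.clamp(blended + correction, 0, 1)


B, H, W = 1, 32, 32
warped_img0 = torch.rand(B, 3, H, W)
warped_img1 = torch.rand(B, 3, H, W)
mask = torch.rand(B, 1, H, W)

# A residual of 0.5 everywhere maps to a zero correction
neutral = torch.full((B, 3, H, W), 0.5)
refined = apply_residual(warped_img0, warped_img1, mask, neutral)
blended = warped_img0 * mask + warped_img1 * (1 - mask)
```

Under this convention a network output of 0.5 leaves the blended frame unchanged, and values above or below 0.5 brighten or darken it.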

Usage Examples

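import torch
from rife.refine import Contextnet, Unet

# Run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize refinement modules
contextnet = Contextnet().to(device)
unet = Unet().to(device)

# Placeholder inputs for illustration only; in the real pipeline the warped
# frames, mask, and flow come from IFNet. H and W are multiples of 16.
B, H, W = 1, 256, 448
img0, img1 = (torch.rand(B, 3, H, W, device=device) for _ in range(2))
warped_img0, warped_img1 = (torch.rand(B, 3, H, W, device=device) for _ in range(2))
mask = torch.rand(B, 1, H, W, device=device)
estimated_flow = torch.zeros(B, 4, H, W, device=device)

# Extract multi-scale context features from source frames
flow_01 = estimated_flow[:, :2]  # flow channels that warp frame 0 toward the interpolation point
flow_10 = estimated_flow[:, 2:]  # flow channels that warp frame 1 toward the interpolation point
c0 = contextnet(img0, flow_01)
c1 = contextnet(img1, flow_10)

# Produce the sigmoid-activated residual used to refine the interpolated frame
residual = unet(img0, img1, warped_img0, warped_img1, mask, estimated_flow, c0, c1)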
