Overview
Defines the Contextnet and Unet refinement modules, which post-process the warped frames produced by the IFNet optical-flow estimator and predict a residual correction for the interpolated frame.
Description
This module contains three neural network classes and two helper factory functions that together form the refinement stage of the RIFE (Real-Time Intermediate Flow Estimation) pipeline:
conv() -- Factory function that returns a sequential block of nn.Conv2d followed by nn.PReLU activation.
deconv() -- Factory function that returns a sequential block of nn.ConvTranspose2d followed by nn.PReLU activation, used for upsampling in the decoder.
Conv2 -- A double-convolution block with configurable stride. The first convolution applies the stride for downsampling, and the second convolution refines features at the new resolution.
Contextnet -- A 4-level feature encoder that progressively downsamples an input image using stride-2 Conv2 blocks (base channel count c=16, doubling at each level). After each block, the optical flow field is bilinearly downsampled by a factor of 2 and its values are multiplied by 0.5, so that displacements stay in pixel units of the smaller grid; the rescaled flow is then used to warp the features via the warp() function. The output is a list of four warped feature maps at decreasing spatial resolutions.
Unet -- A 4-level encoder-decoder network. The encoder takes a 17-channel concatenation of the original frames (3 + 3 channels), the warped frames (3 + 3), the blending mask (1), and the flow field (4). At each of the four scales, the Contextnet features of both source frames are fused with the encoder features via channel-wise concatenation. The decoder uses transposed convolutions with skip connections and ends in a 3-channel sigmoid-activated residual image.
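The spatial bookkeeping behind the stride-2 Conv2 blocks and the deconv() upsampling can be checked with plain arithmetic. The sketch below is not taken from refine.py: the helper names are hypothetical, and it only applies the standard PyTorch output-size formulas for nn.Conv2d and nn.ConvTranspose2d to the default arguments listed in the signatures on this page.

```python
def conv_out_size(size, kernel_size=3, stride=1, padding=1, dilation=1):
    """Output length of nn.Conv2d along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def deconv_out_size(size, kernel_size=4, stride=2, padding=1):
    """Output length of nn.ConvTranspose2d along one axis
    (dilation 1 and output_padding 0 assumed)."""
    return (size - 1) * stride - 2 * padding + kernel_size


def round_trip(size, levels=4):
    """Apply `levels` stride-2 convolutions, then `levels` deconvolutions.

    Returns the list of sizes on the way down and on the way up.
    """
    down = [size]
    for _ in range(levels):
        down.append(conv_out_size(down[-1], stride=2))
    up = [down[-1]]
    for _ in range(levels):
        up.append(deconv_out_size(up[-1]))
    return down, up


down, up = round_trip(256)
print(down)  # [256, 128, 64, 32, 16]
print(up)    # [16, 32, 64, 128, 256]
```

With a size that is not divisible by 16, for example 100, the same helper gives [100, 50, 25, 13, 7] on the way down and [7, 14, 28, 56, 112] on the way up, so the decoder sizes no longer match the encoder sizes they would be concatenated with. This suggests that inputs should be padded to a multiple of 16 before refinement, although this page does not state such a requirement.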
Usage
Used internally by the RIFE IFNet and IFNet_m models. The Contextnet extracts multi-scale context features from the source frames, guided by the estimated optical flow, and the Unet combines all inputs to produce a refined interpolated frame that corrects warping artifacts such as those around occluded regions and blending seams.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: inference/gradio_composite_demo/rife/refine.py
Signature
# Factory functions
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1)
# Conv2 block
class Conv2(nn.Module):
    def __init__(self, in_planes, out_planes, stride=2)
    def forward(self, x)
# Contextnet encoder
class Contextnet(nn.Module):
    def __init__(self)
    def forward(self, x, flow) -> list  # returns [f1, f2, f3, f4]
# Unet refinement decoder
class Unet(nn.Module):
    def __init__(self)
    def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
Import
from rife.refine import Contextnet, Unet, Conv2, conv, deconv
I/O Contract
Inputs (Contextnet.forward)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| x | torch.Tensor | Yes | Input image tensor of shape [B, 3, H, W] |
| flow | torch.Tensor | Yes | Optical flow field of shape [B, 2, H, W], progressively downsampled at each level |
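The reason the flow is both resized and scaled at each level is that its values are displacements measured in pixels: on a grid with half the resolution, the same motion covers half as many pixels. The sketch below illustrates this on a single flow component stored as a nested list. It is a simplification under stated assumptions: the function name is hypothetical, the dimensions are assumed even, and 2x2 averaging stands in for the bilinear resize used by the module.

```python
def downsample_flow(flow_x):
    """Halve the resolution of one flow component and halve its values.

    `flow_x` is a 2-D list (H x W, both even) of pixel displacements.
    2x2 averaging is used here as a stand-in for bilinear resizing.
    """
    h, w = len(flow_x), len(flow_x[0])
    out = []
    for i in range(0, h, 2):
        row = []
        for j in range(0, w, 2):
            mean = (flow_x[i][j] + flow_x[i][j + 1]
                    + flow_x[i + 1][j] + flow_x[i + 1][j + 1]) / 4.0
            # Re-express the displacement in pixels of the smaller grid.
            row.append(mean * 0.5)
        out.append(row)
    return out


# A uniform shift of 8 pixels at full resolution on a 4x4 grid.
flow_x = [[8.0] * 4 for _ in range(4)]
half = downsample_flow(flow_x)     # 2x2 grid
quarter = downsample_flow(half)    # 1x1 grid
print(half)     # [[4.0, 4.0], [4.0, 4.0]]
print(quarter)  # [[2.0]]
```

An 8-pixel shift at full resolution becomes a 4-pixel shift at 1/2 scale and a 2-pixel shift at 1/4 scale, which is what the warp at each level needs.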
Outputs (Contextnet.forward)
| Name | Type | Description |
|------|------|-------------|
| features | list[torch.Tensor] | Four warped feature maps at scales 1/2, 1/4, 1/8, 1/16 of the input resolution with channel counts [16, 32, 64, 128] |
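As a quick check of the output contract above, the expected shapes of the four feature maps can be computed from the input size. This is a minimal sketch with a hypothetical helper name; it assumes H and W are divisible by 16 and uses only the scales and channel counts listed in the table.

```python
def contextnet_output_shapes(batch, height, width, base_channels=16, levels=4):
    """Return the expected [B, C, H, W] shape of each Contextnet feature map."""
    shapes = []
    channels = base_channels
    for _ in range(levels):
        # Each level halves the resolution and doubles the channel count.
        height, width = height // 2, width // 2
        shapes.append((batch, channels, height, width))
        channels *= 2
    return shapes


for shape in contextnet_output_shapes(1, 256, 448):
    print(shape)
# (1, 16, 128, 224)
# (1, 32, 64, 112)
# (1, 64, 32, 56)
# (1, 128, 16, 28)
```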
Inputs (Unet.forward)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| img0 | torch.Tensor | Yes | First source image [B, 3, H, W] |
| img1 | torch.Tensor | Yes | Second source image [B, 3, H, W] |
| warped_img0 | torch.Tensor | Yes | First image warped by estimated flow [B, 3, H, W] |
| warped_img1 | torch.Tensor | Yes | Second image warped by estimated flow [B, 3, H, W] |
| mask | torch.Tensor | Yes | Blending mask [B, 1, H, W] |
| flow | torch.Tensor | Yes | Concatenated bi-directional flow [B, 4, H, W] |
| c0 | list[torch.Tensor] | Yes | Contextnet features for frame 0 (4 scales) |
| c1 | list[torch.Tensor] | Yes | Contextnet features for frame 1 (4 scales) |
Outputs (Unet.forward)
| Name | Type | Description |
|------|------|-------------|
| residual | torch.Tensor | Sigmoid-activated residual image [B, 3, H, W] in range [0, 1] |
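The channel counts in the Unet contract can be tallied directly from the tables above. The sketch below is bookkeeping only, with hypothetical constant names: it sums the channels of the tensors concatenated at the encoder input and computes how many channels the two Contextnet feature lists contribute at each scale. It makes no claim about the internal layer widths of Unet, which this page does not list.

```python
# Channel count of each tensor concatenated at the Unet encoder input.
UNET_INPUT_CHANNELS = {
    "img0": 3,
    "img1": 3,
    "warped_img0": 3,
    "warped_img1": 3,
    "mask": 1,
    "flow": 4,
}

# Channel counts of the Contextnet feature maps at scales 1/2 .. 1/16.
CONTEXT_CHANNELS = [16, 32, 64, 128]

encoder_in_channels = sum(UNET_INPUT_CHANNELS.values())

# c0 and c1 are both concatenated, so each scale adds twice the count.
fused_extra_channels = [2 * c for c in CONTEXT_CHANNELS]

print(encoder_in_channels)   # 17
print(fused_extra_channels)  # [32, 64, 128, 256]
```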
Usage Examples
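import torch
from rife.refine import Contextnet, Unet
# Placeholder inputs for illustration only; in the pipeline these come from the IFNet stage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
B, H, W = 1, 256, 256  # H and W chosen divisible by 16
img0 = torch.rand(B, 3, H, W, device=device)
img1 = torch.rand(B, 3, H, W, device=device)
warped_img0, warped_img1 = img0.clone(), img1.clone()
mask = torch.rand(B, 1, H, W, device=device)
estimated_flow = torch.zeros(B, 4, H, W, device=device)  # zero flow as a stand-in
# Initialize refinement modules
contextnet = Contextnet().to(device)
unet = Unet().to(device)
# Extract multi-scale context features from source frames
flow_01 = estimated_flow[:, :2]  # flow used to warp frame 0 toward the interpolation point
flow_10 = estimated_flow[:, 2:4]  # flow used to warp frame 1 toward the interpolation point
c0 = contextnet(img0, flow_01)
c1 = contextnet(img1, flow_10)
# Produce the sigmoid-activated residual image, shape [B, 3, H, W]
residual = unet(img0, img1, warped_img0, warped_img1, mask, estimated_flow, c0, c1)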
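This page does not say how the caller consumes the Unet output. Because a sigmoid output lies in [0, 1] while a correction usually needs to be able to both brighten and darken a pixel, one plausible reading is that the caller remaps the output to [-1, 1] before adding it to the blended frame and clamping. The sketch below shows that arithmetic on single pixel values; the function name and the remapping itself are assumptions, not something this page confirms.

```python
def apply_residual(merged_value, sigmoid_value):
    """Add a sigmoid-range correction to a blended pixel value.

    Assumption: the [0, 1] output is remapped to a signed residual
    in [-1, 1], added to the blended value, and clamped to [0, 1].
    """
    signed = sigmoid_value * 2.0 - 1.0
    return min(max(merged_value + signed, 0.0), 1.0)


print(apply_residual(0.5, 0.5))    # 0.5   (an output of 0.5 means no correction)
print(apply_residual(0.5, 0.625))  # 0.75  (brightens by 0.25)
print(apply_residual(0.25, 0.25))  # 0.0   (darkens by 0.5, then clamps)
```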
Related Pages