

Implementation:Zai org CogVideo Upscale Utils



Knowledge Sources
Domains: Video_Generation, Image_Super_Resolution
Last Updated: 2026-02-10 00:00 GMT

Overview

Utility module providing model loading, tiled image upscaling with Real-ESRGAN, video saving, and progress tracking for the composite Gradio demo's post-processing pipeline.

Description

This module contains a collection of utility functions and a progress bar class that support the video post-processing workflow:

  • load_torch_file() -- Loads model weights from either safetensors (.safetensors, .sft) or PyTorch checkpoint files. If the checkpoint nests its weights under a "state_dict" or "params_ema" key, that entry is used; every tensor is then cast to dtype (torch.float16 by default).
  • state_dict_prefix_replace() -- Renames state dict keys by replacing specified prefixes, which is useful when loading models saved through module wrappers (e.g., stripping the "module." prefix that DataParallel adds).
  • module_size() -- Computes the total memory footprint of a module's parameters in bytes.
  • get_tiled_scale_steps() -- Calculates the total number of tile processing steps for progress bar initialization.
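  • tiled_scale_multidim() -- Core tiling function that processes the input in overlapping tiles and blends them with feathered weights. Supports an arbitrary number of spatial dimensions. Each tile's output is weighted by a mask that ramps linearly up to 1 across the overlap band at the tile borders, and the accumulated result is divided by the summed weights, which suppresses seam artifacts.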
  • tiled_scale() -- Convenience wrapper around tiled_scale_multidim that accepts explicit tile_x and tile_y parameters.
  • load_sd_upscale() -- Loads a Real-ESRGAN super-resolution model via the spandrel ModelLoader, handling prefix normalization.
  • upscale() -- Upscales a stack of frames of shape [F, C, H, W] using 512x512 tiles with a 32-pixel overlap. It also estimates the memory requirement and handles device placement of the model.
  • upscale_batch_and_concatenate() -- Applies upscale() to each video (an [F, C, H, W] stack of frames) along the batch dimension and stacks the results into a [B, F, C, H*scale, W*scale] tensor.
  • save_video() -- Exports a list of PIL images or numpy arrays to a timestamped MP4 file using diffusers' export_to_video and returns the path of the written file.
  • ProgressBar -- Wraps tqdm for tracking progress across tiling and inference operations.
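The checkpoint unwrapping described for load_torch_file() and the key renaming done by state_dict_prefix_replace() can be sketched with plain dictionaries standing in for tensors. This is a minimal illustration, not the module's code: the helper names unwrap_checkpoint and replace_prefixes are hypothetical, string values stand in for weight tensors, the dtype cast is omitted, and the filter_keys behavior (keep only renamed keys) is an assumption.

```python
def unwrap_checkpoint(checkpoint):
    """Return the weight mapping stored in a loaded checkpoint dict (hypothetical helper)."""
    # Lookup order described above: "state_dict", then "params_ema", else the dict itself.
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    if "params_ema" in checkpoint:
        return checkpoint["params_ema"]
    return checkpoint


def replace_prefixes(state_dict, replace_prefix, filter_keys=False):
    """Rename keys by prefix, e.g. {"module.": ""} strips a DataParallel prefix.

    Hypothetical stand-in for state_dict_prefix_replace(). Assumption: with
    filter_keys=True only the renamed entries are returned.
    """
    out = {} if filter_keys else dict(state_dict)  # copy; the input is left untouched
    for old, new in replace_prefix.items():
        for key in [k for k in state_dict if k.startswith(old)]:
            out.pop(key, None)  # drop the old name if present
            out[new + key[len(old):]] = state_dict[key]
    return out


# String values stand in for weight tensors.
ckpt = {"params_ema": {"module.conv_first.weight": "w0", "module.conv_first.bias": "b0"}}
sd = replace_prefixes(unwrap_checkpoint(ckpt), {"module.": ""})
print(sorted(sd))  # ['conv_first.bias', 'conv_first.weight']
```

Copying instead of mutating keeps the sketch easy to test; a real loader may prefer in-place renaming to avoid holding two references to large tensors.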

Usage

Used as the shared utility layer for the composite Gradio demo. The upscaling functions are called when users request super-resolution on generated video frames. The model loading utilities support both the upscaler and RIFE models. The progress bar is used throughout the demo for user feedback.

Code Reference

Source Location

Signature

def load_torch_file(ckpt, device=None, dtype=torch.float16) -> dict

def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False) -> dict

def module_size(module) -> int

def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap) -> int

@torch.inference_mode()
def tiled_scale_multidim(
    samples, function, tile=(64, 64), overlap=8,
    upscale_amount=4, out_channels=3, output_device="cpu", pbar=None
) -> torch.Tensor

def tiled_scale(
    samples, function, tile_x=64, tile_y=64, overlap=8,
    upscale_amount=4, out_channels=3, output_device="cpu", pbar=None
) -> torch.Tensor

def load_sd_upscale(ckpt, inf_device) -> nn.Module

def upscale(upscale_model, tensor, inf_device, output_device="cpu") -> torch.Tensor

def upscale_batch_and_concatenate(
    upscale_model, latents, inf_device, output_device="cpu"
) -> torch.Tensor

def save_video(tensor: Union[List[np.ndarray], List[PIL.Image.Image]], fps: int = 8) -> str

class ProgressBar:
    def __init__(self, total, desc=None)
    def update(self, value)
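The progress-bar total from get_tiled_scale_steps() follows from the tiling stride: tile origins advance by tile - overlap pixels, so each axis needs ceil(size / (tile - overlap)) tiles. A minimal sketch assuming that formula; tiled_steps is a hypothetical name, and multiplying by the frame count to get a per-clip total is an assumption about how upscale() sizes its ProgressBar.

```python
import math


def tiled_steps(width, height, tile_x, tile_y, overlap):
    """Number of tiles visited for one [C, H, W] image (hypothetical helper)."""
    # Origins sit at 0, stride, 2 * stride, ... while below each dimension's size.
    rows = math.ceil(height / (tile_y - overlap))
    cols = math.ceil(width / (tile_x - overlap))
    return rows * cols


# A 720x480 frame with the 512-pixel tiles and 32-pixel overlap used by upscale():
# the stride is 480, giving 1 row x 2 columns of tiles.
per_frame = tiled_steps(720, 480, tile_x=512, tile_y=512, overlap=32)
# Assumption: the clip total is the per-frame count times the number of frames
# (49 here is only an example frame count).
total = 49 * per_frame
print(per_frame, total)  # 2 98
```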

Import

from utils import (
    load_sd_upscale, upscale, upscale_batch_and_concatenate,
    tiled_scale, save_video, ProgressBar, load_torch_file,
)

I/O Contract

Inputs (tiled_scale_multidim)

| Name | Type | Required | Description |
|---|---|---|---|
| samples | torch.Tensor | Yes | Input tensor of shape [B, C, ...], where ... are the spatial dimensions |
| function | Callable | Yes | Upscaling function applied to each tile |
| tile | tuple[int, ...] | No | Tile size for each spatial dimension. Default: (64, 64) |
| overlap | int | No | Overlap between adjacent tiles, in input pixels. Default: 8 |
| upscale_amount | float | No | Upscaling factor. Default: 4 |
| out_channels | int | No | Number of output channels. Default: 3 |
| output_device | str | No | Device for the output tensor. Default: "cpu" |
| pbar | ProgressBar | No | Progress bar instance for tracking. Default: None |

Outputs (tiled_scale_multidim)

| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Upscaled tensor of shape [B, out_channels, ...], with each spatial dimension scaled by upscale_amount |
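The overlapping-tile scheme can be sketched in NumPy for the 2-D case. This is an illustration under stated assumptions, not the module's implementation: it uses NumPy arrays instead of torch tensors, handles only [B, C, H, W] inputs (the real function accepts any number of spatial dimensions), the helper names tiled_scale_2d and edge_ramp are hypothetical, and a nearest-neighbour upscaler stands in for Real-ESRGAN. Each tile's output is weighted by a mask that ramps linearly toward its borders, and the accumulated sum is divided by the accumulated weights.

```python
import itertools

import numpy as np


def edge_ramp(n, feather):
    """1-D blend weights: (t + 1) / feather at distance t from either edge, else 1."""
    ramp = np.ones(n)
    for t in range(min(feather, n)):
        ramp[t] *= (t + 1) / feather
        ramp[n - 1 - t] *= (t + 1) / feather
    return ramp


def tiled_scale_2d(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3):
    """Upscale [B, C, H, W] in overlapping tiles and blend them with feathered weights."""
    batch, _, h, w = samples.shape
    out_h, out_w = round(h * upscale_amount), round(w * upscale_amount)
    feather = round(overlap * upscale_amount)  # overlap band measured in output pixels
    output = np.empty((batch, out_channels, out_h, out_w))
    stride_y, stride_x = tile[0] - overlap, tile[1] - overlap

    for b in range(batch):
        acc = np.zeros((1, out_channels, out_h, out_w))  # weighted sum of tile outputs
        weight = np.zeros_like(acc)  # sum of blend weights per output pixel
        for y0, x0 in itertools.product(range(0, h, stride_y), range(0, w, stride_x)):
            # Clamp the origin so the last tile still has at least `overlap` pixels.
            y, x = max(0, min(h - overlap, y0)), max(0, min(w - overlap, x0))
            up = function(samples[b : b + 1, :, y : y + tile[0], x : x + tile[1]])
            # The 2-D mask is the outer product of the row and column ramps.
            mask = np.outer(edge_ramp(up.shape[2], feather), edge_ramp(up.shape[3], feather))
            oy, ox = round(y * upscale_amount), round(x * upscale_amount)
            region = (slice(None), slice(None), slice(oy, oy + up.shape[2]), slice(ox, ox + up.shape[3]))
            acc[region] += up * mask
            weight[region] += mask
        output[b : b + 1] = acc / weight  # normalize so overlapping tiles average out
    return output


def nearest_x4(a):
    # Stand-in "model": nearest-neighbour 4x upscaling of a [N, C, h, w] array.
    return a.repeat(4, axis=2).repeat(4, axis=3)


rng = np.random.default_rng(0)
image = rng.random((2, 3, 50, 70))
tiled = tiled_scale_2d(image, nearest_x4, tile=(16, 16), overlap=4, upscale_amount=4)
direct = nearest_x4(image)
# Every tile agrees on shared pixels, so the blended result matches a single pass.
print(tiled.shape, np.allclose(tiled, direct))  # (2, 3, 200, 280) True
```

Because the weights are normalized, a pixel covered by only one tile keeps that tile's value, while pixels in an overlap band become a smooth mix of the neighbouring tiles. With a real network the tiles disagree slightly near their borders, and the ramp is what hides those seams.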

Inputs (upscale)

| Name | Type | Required | Description |
|---|---|---|---|
| upscale_model | nn.Module | Yes | Loaded Real-ESRGAN model (via spandrel) |
| tensor | torch.Tensor | Yes | Input image frames of shape [F, C, H, W] |
| inf_device | str | Yes | Device for inference (e.g., "cuda") |
| output_device | str | No | Device for the output tensor. Default: "cpu" |

Outputs (upscale)

| Name | Type | Description |
|---|---|---|
| result | torch.Tensor | Upscaled image tensor of shape [F, C, H*scale, W*scale], where scale is the model's upscaling factor |
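upscale_batch_and_concatenate() handles the extra batch axis by calling upscale() once per video and stacking the results, so a [B, F, C, H, W] input becomes [B, F, C, H*scale, W*scale]. A minimal NumPy sketch of that shape contract, assuming a per-video function with the same contract as upscale(); batch_and_stack and fake_upscale are hypothetical names, and nearest-neighbour repetition stands in for the model.

```python
import numpy as np


def batch_and_stack(per_video_fn, videos):
    """Apply per_video_fn to each [F, C, H, W] entry of a [B, F, C, H, W] array and restack."""
    return np.stack([per_video_fn(videos[i]) for i in range(videos.shape[0])])


def fake_upscale(frames, scale=4):
    # Stand-in for upscale(): nearest-neighbour scaling of [F, C, H, W] frames.
    return frames.repeat(scale, axis=2).repeat(scale, axis=3)


videos = np.zeros((2, 8, 3, 60, 90), dtype=np.float32)
result = batch_and_stack(fake_upscale, videos)
print(result.shape)  # (2, 8, 3, 240, 360)
```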

Usage Examples

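import torch
from utils import load_sd_upscale, upscale_batch_and_concatenate, save_video

# Load the Real-ESRGAN x4 upscaler (checkpoint path is illustrative)
upscale_model = load_sd_upscale("RealESRGAN_x4.pth", inf_device="cuda")

# Upscale a batch of video frames with values in [0, 1]
# video_frames shape: [batch, frames, channels, height, width]
video_frames = torch.rand(1, 8, 3, 120, 180, device="cuda")  # placeholder input
upscaled = upscale_batch_and_concatenate(
    upscale_model, video_frames, inf_device="cuda", output_device="cpu"
)
# upscaled shape: [batch, frames, channels, height * 4, width * 4]

# Save the first video to MP4 as a list of [H, W, C] float arrays in [0, 1]
frames = [f.permute(1, 2, 0).clamp(0, 1).cpu().float().numpy() for f in upscaled[0]]
video_path = save_video(frames, fps=16)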
