Overview
Utility module providing model loading, tiled image upscaling with Real-ESRGAN, video saving, and progress tracking for the composite Gradio demo's post-processing pipeline.
Description
This module contains a collection of utility functions and a progress bar class that support the video post-processing workflow:
load_torch_file() -- Loads model weights from either safetensors (.safetensors, .sft) or PyTorch checkpoint files. If the weights are nested under a "state_dict" or "params_ema" key, that inner dictionary is returned. Every tensor is then cast to dtype (float16 by default).
state_dict_prefix_replace() -- Renames state-dict keys by swapping prefixes according to replace_prefix, a dict that maps each old prefix to its replacement. For example, {"module.": ""} strips the "module." prefix that DataParallel-wrapped models add to every key. With filter_keys=True, only the renamed entries are returned.
module_size() -- Computes the total memory footprint of a module's parameters in bytes.
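get_tiled_scale_steps() -- Computes how many tiles an image of the given width and height is split into for a given tile size and overlap. The result is used to size the progress bar.
tiled_scale_multidim() -- Core tiling function that processes an image in overlapping tiles and blends the tile outputs with feathering. Supports an arbitrary number of spatial dimensions. Each tile's output is weighted by a mask that ramps linearly up to 1 across the overlap region at the tile borders, and overlapping contributions are averaged according to these weights, which suppresses seam artifacts.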
tiled_scale() -- Convenience wrapper around tiled_scale_multidim that accepts explicit tile_x and tile_y parameters.
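load_sd_upscale() -- Loads a super-resolution model (Real-ESRGAN in the demo) through spandrel's ModelLoader, first stripping the "module." key prefix when the checkpoint was saved from a DataParallel-style wrapper.
upscale() -- Upscales a tensor of frames with tiled_scale, using 512x512 tiles and a 32-pixel overlap. It estimates the memory required, moves the model to inf_device for inference, and reports progress through a ProgressBar.
upscale_batch_and_concatenate() -- Applies upscale() to each video in a batch (each element along the first dimension, i.e., a whole [F, C, H, W] clip) and stacks the results into a single tensor.
save_video() -- Exports a list of PIL images or numpy arrays to a timestamped MP4 file using diffusers' export_to_video and returns the path of the written file.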
ProgressBar -- Wraps tqdm for tracking progress across tiling and inference operations.
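The key handling described for load_torch_file() and state_dict_prefix_replace() can be sketched on plain dictionaries. This is a minimal illustration, not the module's code: unwrap_checkpoint and prefix_replace are hypothetical names, file I/O and the dtype cast are left out, the order in which "state_dict" and "params_ema" are checked is an assumption, and replace_prefix is assumed to be a dict mapping old prefixes to new ones.

```python
from typing import Any, Dict


def unwrap_checkpoint(ckpt: Dict[str, Any]) -> Dict[str, Any]:
    """Return the weight dict of a loaded checkpoint.

    Assumption: weights may be nested under "state_dict" or "params_ema"
    (checked in that order); otherwise the checkpoint is the weight dict.
    """
    if "state_dict" in ckpt:
        return ckpt["state_dict"]
    if "params_ema" in ckpt:
        return ckpt["params_ema"]
    return ckpt


def prefix_replace(state_dict: Dict[str, Any], replace_prefix: Dict[str, str],
                   filter_keys: bool = False) -> Dict[str, Any]:
    """Rename keys starting with an old prefix so they start with the new one.

    With filter_keys=True only the renamed entries are returned; otherwise
    keys that match no prefix are kept unchanged.
    """
    src = dict(state_dict)  # work on a copy so the caller's dict is not mutated
    out = {} if filter_keys else src
    for old, new in replace_prefix.items():
        # Snapshot the matching keys before popping them from the working copy.
        for key in [k for k in src if k.startswith(old)]:
            out[new + key[len(old):]] = src.pop(key)
    return out


# Hypothetical DataParallel-style checkpoint with a "module." prefix.
ckpt = {"params_ema": {"module.conv.weight": 1, "module.conv.bias": 2}}
print(prefix_replace(unwrap_checkpoint(ckpt), {"module.": ""}))
```

Working on a copy is a choice made for the sketch so the example is easy to test; it says nothing about whether the module's own function modifies its input.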
Usage
Used as the shared utility layer for the composite Gradio demo. The upscaling functions are called when users request super-resolution on generated video frames. The model loading utilities support both the upscaler and RIFE models. The progress bar is used throughout the demo for user feedback.
Code Reference
Source Location
Signature
```python
def load_torch_file(ckpt, device=None, dtype=torch.float16) -> dict
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False) -> dict
def module_size(module) -> int
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap) -> int

@torch.inference_mode()
def tiled_scale_multidim(
    samples, function, tile=(64, 64), overlap=8,
    upscale_amount=4, out_channels=3, output_device="cpu", pbar=None
) -> torch.Tensor

def tiled_scale(
    samples, function, tile_x=64, tile_y=64, overlap=8,
    upscale_amount=4, out_channels=3, output_device="cpu", pbar=None
) -> torch.Tensor

def load_sd_upscale(ckpt, inf_device) -> spandrel.ModelDescriptor
def upscale(upscale_model, tensor, inf_device, output_device="cpu") -> torch.Tensor
def upscale_batch_and_concatenate(
    upscale_model, latents, inf_device, output_device="cpu"
) -> torch.Tensor
def save_video(tensor: Union[List[np.ndarray], List[PIL.Image.Image]], fps: int = 8) -> str

class ProgressBar:
    def __init__(self, total, desc=None)
    def update(self, value)
```
Import
```python
from utils import (
    load_sd_upscale, upscale, upscale_batch_and_concatenate,
    tiled_scale, save_video, ProgressBar, load_torch_file,
)
```
I/O Contract
Inputs (tiled_scale_multidim)
| Name | Type | Required | Description |
|---|---|---|---|
| samples | torch.Tensor | Yes | Input tensor of shape [B, C, ...], where ... are the spatial dimensions |
| function | Callable | Yes | Upscaling function called on each tile; it must return the tile scaled by upscale_amount with out_channels channels |
| tile | tuple[int, ...] | No | Tile size for each spatial dimension, in the same order as the spatial dimensions of samples. Default: (64, 64) |
| overlap | int | No | Overlap between adjacent tiles, in input pixels. Default: 8 |
| upscale_amount | float | No | Upscaling factor. Default: 4 |
| out_channels | int | No | Number of output channels. Default: 3 |
| output_device | str | No | Device for the output tensor. Default: "cpu" |
| pbar | ProgressBar | No | Progress bar instance for tracking. Default: None |
Outputs (tiled_scale_multidim)
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Upscaled tensor of shape [B, out_channels, ...], where each spatial dimension is scaled by upscale_amount |
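The tiling and feathered blending performed by tiled_scale_multidim() can be sketched in pure Python for a single-channel 2-D image. This is an illustrative re-implementation under stated assumptions, not the module's code: tiled_scale_2d, feather_weights and nearest_upscale are hypothetical names, nearest-neighbour repetition stands in for the Real-ESRGAN model, tiles are assumed to start every tile − overlap input pixels, and each tile's output is assumed to be weighted by a mask that ramps up over its first and last overlap × scale output pixels, with the weighted sum divided by the accumulated weights.

```python
import itertools
from typing import Callable, List

Image = List[List[float]]


def nearest_upscale(patch: Image, scale: int) -> Image:
    """Toy stand-in for the model: repeat each pixel scale x scale times."""
    return [[v for v in row for _ in range(scale)] for row in patch for _ in range(scale)]


def feather_weights(length: int, feather: int) -> List[float]:
    """1-D weights ramping 1/feather, 2/feather, ..., 1 in from both ends
    (for lengths of at least 2 * feather)."""
    w = [1.0] * length
    for t in range(min(feather, length)):
        ramp = (t + 1) / feather
        w[t] *= ramp
        w[length - 1 - t] *= ramp
    return w


def tiled_scale_2d(image: Image, function: Callable[[Image], Image],
                   tile=(4, 4), overlap=2, scale=2) -> Image:
    """Upscale `image` tile by tile; tile = (rows, cols) and must exceed overlap."""
    h, w = len(image), len(image[0])
    acc = [[0.0] * (w * scale) for _ in range(h * scale)]   # weighted sum of tile outputs
    norm = [[0.0] * (w * scale) for _ in range(h * scale)]  # summed weights per output pixel
    feather = overlap * scale  # the overlap measured in output pixels
    starts = itertools.product(range(0, h, tile[0] - overlap), range(0, w, tile[1] - overlap))
    for y0, x0 in starts:
        # Clamp so a tile never starts inside the last `overlap` input pixels.
        y, x = max(0, min(h - overlap, y0)), max(0, min(w - overlap, x0))
        th, tw = min(tile[0], h - y), min(tile[1], w - x)
        up = function([row[x:x + tw] for row in image[y:y + th]])
        wy, wx = feather_weights(len(up), feather), feather_weights(len(up[0]), feather)
        for i, row in enumerate(up):
            for j, v in enumerate(row):
                m = wy[i] * wx[j]
                acc[y * scale + i][x * scale + j] += v * m
                norm[y * scale + i][x * scale + j] += m
    # Dividing by the summed weights turns overlapping contributions into a weighted average.
    return [[a / n for a, n in zip(ar, nr)] for ar, nr in zip(acc, norm)]


image = [[float(10 * r + c) for c in range(5)] for r in range(6)]
result = tiled_scale_2d(image, lambda p: nearest_upscale(p, 2))
```

Because the stand-in model is purely local, the tiled result should match upscaling the whole image at once, up to floating-point rounding; with a real network the feathered average is what hides the differences between neighbouring tiles.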
Inputs (upscale)
| Name | Type | Required | Description |
|---|---|---|---|
| upscale_model | spandrel.ModelDescriptor | Yes | Real-ESRGAN model loaded with load_sd_upscale() (a spandrel model descriptor wrapping the torch module) |
| tensor | torch.Tensor | Yes | Input image frames of shape [F, C, H, W] |
| inf_device | str | Yes | Device for inference (e.g., "cuda") |
| output_device | str | No | Device for the output tensor. Default: "cpu" |
Outputs (upscale)
| Name | Type | Description |
|---|---|---|
| result | torch.Tensor | Upscaled image tensor of shape [F, C, H*scale, W*scale] |
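With 512x512 tiles and a 32-pixel overlap, the progress-bar total for one call to upscale() can be estimated as frames × tiles per frame. The sketch below assumes get_tiled_scale_steps() counts tiles with a stride of tile − overlap along each axis (ceil(height / stride) × ceil(width / stride)); tiled_scale_steps and upscale_progress_total are hypothetical names, and the 720x480, 49-frame clip is only an illustrative input.

```python
import math


def tiled_scale_steps(width: int, height: int, tile_x: int, tile_y: int, overlap: int) -> int:
    """Tiles per frame, assuming tiles start every (tile - overlap) pixels."""
    return math.ceil(height / (tile_y - overlap)) * math.ceil(width / (tile_x - overlap))


def upscale_progress_total(frames: int, width: int, height: int,
                           tile: int = 512, overlap: int = 32) -> int:
    """Progress-bar steps for a [F, C, H, W] tensor: one step per tile per frame."""
    return frames * tiled_scale_steps(width, height, tile, tile, overlap)


# Stride is 512 - 32 = 480, so a 720x480 frame needs ceil(720/480) * ceil(480/480) = 2 tiles.
print(upscale_progress_total(frames=49, width=720, height=480))  # prints 98
```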
Usage Examples
```python
from utils import load_sd_upscale, save_video, upscale_batch_and_concatenate

# Load the Real-ESRGAN upscaler
upscale_model = load_sd_upscale("RealESRGAN_x4.pth", inf_device="cuda")

# Upscale a batch of videos
# video_frames: tensor of shape [batch, frames, channels, height, width]
upscaled = upscale_batch_and_concatenate(
    upscale_model, video_frames, inf_device="cuda", output_device="cpu"
)

# Save one video to MP4 (pil_frames: list of PIL images)
video_path = save_video(pil_frames, fps=16)
```
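upscale_batch_and_concatenate() iterates over the first (batch) dimension, so each call to upscale() receives a whole clip of shape [F, C, H, W], and the per-clip results are stacked back along a new first dimension. The pure-Python sketch below shows only that shape bookkeeping. Nested lists stand in for tensors, nearest-neighbour repetition stands in for the model, and every name in it (shape, upscale_clip, upscale_batch, zeros) is hypothetical.

```python
from typing import Any, List, Tuple


def shape(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular, non-empty nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)


def upscale_clip(clip: List, scale: int) -> List:
    """Stand-in for upscale(): [F][C][H][W] -> [F][C][H*scale][W*scale]."""
    return [
        [[[v for v in row for _ in range(scale)] for row in channel for _ in range(scale)]
         for channel in frame]
        for frame in clip
    ]


def upscale_batch(batch: List, scale: int) -> List:
    """Upscale each clip along the batch dimension and collect the results."""
    return [upscale_clip(clip, scale) for clip in batch]


def zeros(*dims: int) -> Any:
    """Nested list of zeros with the given shape."""
    return 0.0 if not dims else [zeros(*dims[1:]) for _ in range(dims[0])]


batch = zeros(2, 3, 3, 4, 5)  # [B, F, C, H, W]
print(shape(upscale_batch(batch, 4)))  # prints (2, 3, 3, 16, 20)
```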
Related Pages