Implementation:PeterL1n BackgroundMattingV2 MattingRefine TorchScriptWrapper
| Knowledge Sources | |
|---|---|
| Domains | Model_Deployment, Optimization |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Concrete tool, provided by export_torchscript.py, for exporting the matting model to TorchScript with runtime-configurable parameters.
Description
MattingRefine_TorchScriptWrapper wraps MattingRefine and hoists its configurable attributes (backbone_scale, refine_mode, refine_sample_pixels, refine_threshold, refine_prevent_oversampling) from the inner model and its nested sub-modules to the top level of the wrapper. Hoisting lets users change these settings directly on the loaded TorchScript model rather than reaching into nested modules. Before each inference call, the wrapper's forward() copies the top-level attributes back to the inner model, so the latest values always take effect.
The export pipeline is: create the wrapper → load the state dict → disable gradients → optionally convert to float16 → torch.jit.script → save.
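The hoisting pattern described above can be illustrated without PyTorch. The following is a minimal plain-Python sketch, not the repository's code: `ToyRefiner`, `ToyModel`, and `ToyWrapper` are hypothetical stand-ins, and the default values are arbitrary. It only shows how top-level attributes are copied back to the inner objects on every call.

```python
class ToyRefiner:
    """Hypothetical stand-in for the nested refiner sub-module."""

    def __init__(self):
        self.mode = "sampling"   # arbitrary default for illustration
        self.threshold = 0.5     # arbitrary default for illustration


class ToyModel:
    """Hypothetical stand-in for the inner matting model."""

    def __init__(self):
        self.backbone_scale = 0.25
        self.refiner = ToyRefiner()

    def __call__(self, x):
        # Return the settings that were active for this call.
        return (x * self.backbone_scale, self.refiner.mode, self.refiner.threshold)


class ToyWrapper:
    """Hoists the inner settings to the top level, as the wrapper is described to do."""

    def __init__(self):
        self.model = ToyModel()
        # Hoist: mirror the nested attributes on the wrapper itself.
        self.backbone_scale = self.model.backbone_scale
        self.refine_mode = self.model.refiner.mode
        self.refine_threshold = self.model.refiner.threshold

    def forward(self, x):
        # Copy the top-level values back before delegating to the inner model.
        self.model.backbone_scale = self.backbone_scale
        self.model.refiner.mode = self.refine_mode
        self.model.refiner.threshold = self.refine_threshold
        return self.model(x)


wrapper = ToyWrapper()
wrapper.refine_mode = "thresholding"  # change only the top-level attribute
wrapper.refine_threshold = 0.1
result = wrapper.forward(8.0)
```

Because the copy happens inside forward(), a caller never has to touch `wrapper.model.refiner` directly; setting the top-level attribute is enough for the next call.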
Usage
Use this wrapper when deploying the matting model via TorchScript for C++ or mobile inference. After loading the saved TorchScript model, users can adjust the refinement parameters without re-exporting.
Code Reference
Source Location
- Repository: BackgroundMattingV2
- File: export_torchscript.py
- Lines: 33-83
Signature
class MattingRefine_TorchScriptWrapper(nn.Module):
    """
    Wraps MattingRefine with hoisted configurable attributes for TorchScript.
    """
    def __init__(self, *args, **kwargs):
        """
        Args: Same as MattingRefine (backbone, etc.)
        Hoisted Attributes:
            backbone_scale: float
            refine_mode: str
            refine_sample_pixels: int
            refine_threshold: float
            refine_prevent_oversampling: bool
        """
    def forward(
        self,
        src: Tensor,  # (B, 3, H, W) source, RGB, 0-1
        bgr: Tensor   # (B, 3, H, W) background, RGB, 0-1
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Copies hoisted attributes to inner model, then forwards."""
    def load_state_dict(self, *args, **kwargs):
        """Delegates to inner model's load_state_dict."""
Import
# Defined in export_torchscript.py
from model import MattingRefine # Used internally
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| backbone | str | Yes | Encoder backbone ('resnet50', 'resnet101', 'mobilenetv2') |
| model-checkpoint | str | Yes | Path to trained .pth checkpoint |
| precision | str | No | 'float32' or 'float16' (default 'float32') |
| output | str | Yes | Output TorchScript .pth file path |
Outputs
| Name | Type | Description |
|---|---|---|
| .pth file | File | TorchScript serialized model with configurable attributes |
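The inputs above read like command-line options. As a rough sketch under that assumption, they could be parsed as follows. The flag names are taken directly from the table and are assumptions; the real script's flags may be spelled differently, and `build_parser` is a hypothetical helper, not part of the repository.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser mirroring the I/O contract table (flag names are assumed)."""
    parser = argparse.ArgumentParser(description="Export TorchScript (illustrative sketch)")
    parser.add_argument("--backbone", type=str, required=True,
                        choices=["resnet50", "resnet101", "mobilenetv2"])
    parser.add_argument("--model-checkpoint", type=str, required=True)
    parser.add_argument("--precision", type=str, default="float32",
                        choices=["float32", "float16"])
    parser.add_argument("--output", type=str, required=True)
    return parser


# Parse an explicit argument list so the sketch runs without a real command line.
args = build_parser().parse_args([
    "--backbone", "resnet50",
    "--model-checkpoint", "checkpoint.pth",
    "--output", "matting_torchscript.pth",
])
```

Omitting the precision option falls back to 'float32', matching the default listed in the table.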
Usage Examples
Export TorchScript
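import torch
# Assumes the class can be imported without side effects. If export_torchscript.py
# parses command-line arguments at module level, run the script from the command
# line or copy the class definition instead.
from export_torchscript import MattingRefine_TorchScriptWrapper
# Create the wrapper and load the trained weights
model = MattingRefine_TorchScriptWrapper('resnet50').eval()
model.load_state_dict(torch.load('checkpoint.pth', map_location='cpu'))
# Disable gradients
for p in model.parameters():
    p.requires_grad = False
# Optional: convert to float16 before scripting
# model = model.half()
# Script and save
scripted = torch.jit.script(model)
scripted.save('matting_torchscript.pth')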
Load and Configure
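import torch
# Load the exported model
model = torch.jit.load('matting_torchscript.pth')
# Runtime configuration (no re-export needed)
model.backbone_scale = 0.25
model.refine_mode = 'thresholding'
model.refine_threshold = 0.1
# Placeholder inputs: (B, 3, H, W) RGB tensors in [0, 1]; replace with real frames
src = torch.rand(1, 3, 1080, 1920)
bgr = torch.rand(1, 3, 1080, 1920)
# Inference: keep only the first two of the six outputs
with torch.no_grad():
    pha, fgr = model(src, bgr)[:2]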