Implementation:PeterL1n BackgroundMattingV2 Load pretrained deeplabv3 state dict
| Knowledge Sources | |
|---|---|
| Domains | Transfer_Learning, Computer_Vision |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Concrete tool for loading pretrained DeepLabV3 weights into the matting model provided by model/model.py.
Description
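The load_pretrained_deeplabv3_state_dict method on the Base class converts a pretrained DeepLabV3 state dict to the matting model's naming scheme and loads it. It first renames the ASPP keys (classifier.classifier.0 → aspp). It then calls load_matched_state_dict from model/utils.py, which loads only the entries whose key and shape both match the model and skips the rest without raising an error. When print_stats is True, it reports only the matched/total counts, not which keys were skipped. In practice, the pretrained checkpoint supplies the backbone (encoder) and ASPP weights. Modules with no DeepLabV3 counterpart, such as the matting decoder, keep their initialization. For MobileNetV2 backbones, the method temporarily restructures the backbone's attributes to match DeepLabV3's naming convention while loading, then restores them afterwards.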
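The ASPP renaming step described above is a plain substring substitution on the keys. Here is a minimal sketch of that step on its own. So that it runs without PyTorch, shape tuples stand in for tensors, and the function name `remap_aspp_keys` and the example key names are illustrative assumptions, not copied from a real checkpoint.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def remap_aspp_keys(state_dict: Dict[str, V]) -> Dict[str, V]:
    """Return a copy of state_dict with the DeepLabV3 ASPP prefix renamed to 'aspp'."""
    # Substring substitution: classifier.classifier.0 -> aspp.
    # Keys without that substring (backbone, final classifier conv) pass through unchanged.
    return {k.replace("classifier.classifier.0", "aspp"): v for k, v in state_dict.items()}


# Toy input: shape tuples stand in for tensors (assumption, so no PyTorch is needed).
pretrained = {
    "backbone.layer1.0.conv1.weight": (64, 64, 1, 1),
    "classifier.classifier.0.project.0.weight": (256, 1280, 1, 1),
    "classifier.classifier.4.weight": (21, 256, 1, 1),
}
remapped = remap_aspp_keys(pretrained)
for key in remapped:
    print(key)
```

Keys such as the final segmentation classifier survive the rename unchanged. Because they have no counterpart in the matting model, the matched loader simply never uses them.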
Usage
Call this method once, before the first training run, to initialize the model from pretrained segmentation weights. It is not needed when resuming from a matting checkpoint; use load_state_dict() instead.
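The choice between the two loading paths can be sketched as a small dispatcher. The helper `initialize_weights` and its return labels are hypothetical and not part of the repository. The sketch assumes that a resume checkpoint is a plain matting state dict and that a DeepLabV3 checkpoint nests its weights under 'model_state', as the I/O contract states.

```python
from typing import Any, Mapping, Optional, Protocol


class MattingModel(Protocol):
    # Only the two loading methods this sketch relies on.
    def load_state_dict(self, state_dict: Mapping[str, Any]) -> Any: ...

    def load_pretrained_deeplabv3_state_dict(
        self, state_dict: Mapping[str, Any], print_stats: bool = True
    ) -> None: ...


def initialize_weights(
    model: MattingModel,
    matting_state_dict: Optional[Mapping[str, Any]] = None,
    deeplab_checkpoint: Optional[Mapping[str, Any]] = None,
) -> str:
    """Hypothetical helper: resume if possible, else start from DeepLabV3 weights."""
    if matting_state_dict is not None:
        # Resuming: every key should exist, so use the normal strict loader.
        model.load_state_dict(matting_state_dict)
        return "resumed"
    if deeplab_checkpoint is not None:
        # Initial training: transfer only matching weights from the segmentation model.
        model.load_pretrained_deeplabv3_state_dict(deeplab_checkpoint["model_state"])
        return "pretrained"
    return "random"  # neither source given: keep the default initialization
```

Resuming takes precedence because a matting checkpoint would overwrite any DeepLabV3 initialization anyway.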
Code Reference
Source Location
- Repository: BackgroundMattingV2
- File: model/model.py
- Lines: 37-58
Signature
```python
from typing import Dict

from torch import Tensor, nn


class Base(nn.Module):
    def load_pretrained_deeplabv3_state_dict(
        self,
        state_dict: Dict[str, Tensor],
        print_stats: bool = True,
    ) -> None:
        """
        Converts and loads pretrained DeepLabV3 state_dict to match model structure.

        Args:
            state_dict: Pretrained DeepLabV3 state dict, typically loaded via
                torch.load(path)['model_state']
            print_stats: Whether to print match statistics (matched/total keys)
        """
```
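For MobileNetV2 backbones, the temporary restructuring changes parameter names, not weights. The real method reassigns submodules on the backbone and restores them after loading. This sketch expresses only the resulting key mapping as a pure rename. It assumes the split used in the repository's MobileNetV2 branch: the first four blocks of backbone.features become low_level_features and the remainder become high_level_features. The function name `to_deeplab_mobilenet_key` and the default split index of 4 are illustrative assumptions.

```python
import re

# Matches keys of the form backbone.features.<index>.<rest-of-key>
_FEATURES_KEY = re.compile(r"^backbone\.features\.(\d+)\.(.+)$")


def to_deeplab_mobilenet_key(key: str, split: int = 4) -> str:
    """Map a matting-model MobileNetV2 key to its assumed DeepLabV3 name."""
    match = _FEATURES_KEY.match(key)
    if match is None:
        return key  # not a backbone feature key (e.g. aspp.*, decoder.*): unchanged
    index, rest = int(match.group(1)), match.group(2)
    # Blocks before `split` are low-level features, the rest are high-level.
    # The original block index is kept, since slicing an nn.Sequential keeps child names.
    group = "low_level_features" if index < split else "high_level_features"
    return f"backbone.{group}.{index}.{rest}"


for k in ["backbone.features.0.0.weight",
          "backbone.features.4.conv.0.0.weight",
          "aspp.project.0.weight"]:
    print(k, "->", to_deeplab_mobilenet_key(k))
```

Renaming modules in place, as the method does, avoids rewriting every key by hand. Restoring them afterwards keeps the model's normal forward pass and load_state_dict() naming intact.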
Import
```python
from model import MattingBase
# or
from model.model import Base
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| state_dict | Dict[str, Tensor] | Yes | Pretrained DeepLabV3 state dict from torch.load(path)['model_state'] |
| print_stats | bool | No | Print matched/total key counts (default True) |
Outputs
| Name | Type | Description |
|---|---|---|
| (in-place) | None | Model weights are modified in-place; matched keys loaded, unmatched skipped |
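The in-place, match-or-skip behavior in the Outputs table can be sketched without PyTorch. In this sketch, a `FakeTensor` records only a shape and the origin of its values, and `load_matched` is a hypothetical stand-in for load_matched_state_dict. Two details are assumptions: the total counts only the model's own keys, and the sketch returns the counts (for inspection), whereas the documented method returns None.

```python
from typing import Dict, NamedTuple, Tuple


class FakeTensor(NamedTuple):
    """Stand-in for a tensor: only its shape and where its values came from."""
    shape: Tuple[int, ...]
    source: str


def load_matched(current: Dict[str, FakeTensor],
                 incoming: Dict[str, FakeTensor],
                 print_stats: bool = True) -> Tuple[int, int]:
    """Copy entries whose key and shape both match into `current`, in place."""
    num_matched = 0
    num_total = len(current)  # totals count the model's own keys only
    for key, param in list(current.items()):
        candidate = incoming.get(key)
        if candidate is not None and candidate.shape == param.shape:
            current[key] = candidate  # overwrite the model's entry in place
            num_matched += 1
        # Missing keys and shape mismatches are skipped without raising.
    if print_stats:
        print(f"Loaded state_dict: {num_matched}/{num_total} matched")
    return num_matched, num_total


model_params = {
    "backbone.conv1.weight": FakeTensor((64, 6, 7, 7), "init"),
    "aspp.project.0.weight": FakeTensor((256, 1280, 1, 1), "init"),
    "decoder.conv1.weight": FakeTensor((128, 512, 3, 3), "init"),
}
pretrained = {
    "backbone.conv1.weight": FakeTensor((64, 3, 7, 7), "deeplab"),
    "aspp.project.0.weight": FakeTensor((256, 1280, 1, 1), "deeplab"),
    "classifier.classifier.4.weight": FakeTensor((21, 256, 1, 1), "deeplab"),
}
load_matched(model_params, pretrained)  # prints "Loaded state_dict: 1/3 matched"
```

The toy shapes are hypothetical. They show the three outcomes: a shape mismatch is kept at its initial value, a match is overwritten, and a model key absent from the checkpoint is left alone. Extra checkpoint keys are never added to the model.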
Usage Examples
Loading Pretrained DeepLabV3 Weights
```python
import torch

from model import MattingBase

# 1. Create the matting model (use the same backbone as the DeepLabV3 checkpoint)
model = MattingBase(backbone='resnet50')

# 2. Load the pretrained DeepLabV3 checkpoint; map_location='cpu' lets a GPU-saved file load on any device
deeplabv3_checkpoint = torch.load('best_deeplabv3_resnet50_voc_os16.pth', map_location='cpu')

# 3. Transfer the weights whose keys and shapes match
model.load_pretrained_deeplabv3_state_dict(deeplabv3_checkpoint['model_state'])
# Prints: "Loaded state_dict: 267/308 matched"
```
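The printed counts do not say which entries were skipped. For example, 267/308 leaves 41 model entries at their initial values. A hypothetical diagnostic, `explain_skipped`, can list them by reason. With real PyTorch and a ResNet backbone, the inputs could be built as `{k: tuple(v.shape) for k, v in model.state_dict().items()}`, with the same expression applied to the pretrained dict *after* renaming classifier.classifier.0 to aspp. Otherwise, every ASPP key would show up as missing. The toy shapes below are illustrative.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def explain_skipped(model_shapes: Dict[str, Shape],
                    pretrained_shapes: Dict[str, Shape]) -> Dict[str, List[str]]:
    """Hypothetical helper: group the model keys a matched load would skip."""
    report: Dict[str, List[str]] = {"missing": [], "shape_mismatch": []}
    for key, shape in model_shapes.items():
        if key not in pretrained_shapes:
            report["missing"].append(key)  # no counterpart in the checkpoint
        elif pretrained_shapes[key] != shape:
            report["shape_mismatch"].append(
                f"{key}: model {shape} vs pretrained {pretrained_shapes[key]}"
            )
    return report


report = explain_skipped(
    {"backbone.conv1.weight": (64, 6, 7, 7),
     "backbone.bn1.weight": (64,),
     "decoder.conv1.weight": (128, 512, 3, 3)},
    {"backbone.conv1.weight": (64, 3, 7, 7),
     "backbone.bn1.weight": (64,)},
)
print(report)
```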