Implementation: PeterL1n/BackgroundMattingV2 Torch checkpoint ops
| Knowledge Sources | |
|---|---|
| Domains | Training, Model_Persistence |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Concrete wrapper for PyTorch's torch.save and torch.load checkpoint operations as used in BackgroundMattingV2's training and inference scripts.
Description
BackgroundMattingV2 uses standard PyTorch serialization for checkpoint management. During training, torch.save(model.state_dict(), path) saves the model weights at configurable intervals (every N steps and at epoch boundaries). For loading, model.load_state_dict(torch.load(path, map_location=device), strict=False) restores weights, with strict=False allowing partial loading when model architectures differ slightly (e.g., loading MattingBase weights into MattingRefine). Validation is performed at intervals using a valid() function that computes loss on a held-out subset and logs to TensorBoard.
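The strict=False behavior can be illustrated with a minimal sketch. The toy modules below stand in for MattingBase/MattingRefine (they are not the repo's classes): the "refine" model has an extra submodule that the "base" checkpoint lacks, and load_state_dict reports the mismatch instead of raising.

```python
import torch
import torch.nn as nn

# Toy stand-ins: "refine" has an extra head that "base" lacks.
base = nn.ModuleDict({'backbone': nn.Linear(4, 4)})
refine = nn.ModuleDict({'backbone': nn.Linear(4, 4), 'refiner': nn.Linear(4, 1)})

# strict=False loads the overlapping keys and reports the rest
# instead of raising a RuntimeError.
result = refine.load_state_dict(base.state_dict(), strict=False)
print(result.missing_keys)     # keys refine expects but the checkpoint lacks
print(result.unexpected_keys)  # checkpoint keys with no match in refine
```

With strict=True (the PyTorch default), the same call would raise because of the missing refiner.* keys.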
Usage
Use for saving checkpoints during training and loading them for inference, export, or training resumption. The checkpoint interval is controlled by --checkpoint-interval (default 5000 for base, 2000 for refine).
Code Reference
Source Location
- Repository: BackgroundMattingV2
- File: train_base.py (Lines 210-213, 239-258), train_refine.py (Lines 237-241, 278-297)
Signature
# Checkpoint saving (train_base.py:L210-213)
torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')
# Checkpoint loading (inference_images.py:L85)
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
# Validation function (train_base.py:L239-258)
def valid(model, dataloader, writer, step):
    model.eval()
    loss_total = 0
    # ... compute validation loss ...
    writer.add_scalar('valid_loss', loss_total / len(dataloader), step)
    model.train()
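A self-contained version of that skeleton is sketched below. The repo's actual matting loss terms are replaced with a placeholder MSE, and `writer` can be any object with an `add_scalar` method (such as a `SummaryWriter`); the `device` parameter is added here for illustration.

```python
import torch

def valid(model, dataloader, writer, step, device='cpu'):
    """Compute mean loss over a held-out set and log it to TensorBoard."""
    model.eval()
    loss_total = 0.0
    with torch.no_grad():  # no gradients needed during validation
        for inputs, targets in dataloader:
            pred = model(inputs.to(device))
            # Placeholder loss; the repo uses matting-specific terms here.
            loss_total += torch.nn.functional.mse_loss(pred, targets.to(device)).item()
    writer.add_scalar('valid_loss', loss_total / len(dataloader), step)
    model.train()  # restore training mode before returning to the train loop
```

Note the eval()/train() bracketing: it disables dropout and freezes batch-norm statistics for the validation pass, then restores training behavior.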
Import
import torch
from torch.utils.tensorboard import SummaryWriter
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model.state_dict() | Dict[str, Tensor] | Yes | Model parameter dictionary to save |
| path | str | Yes | Filesystem path for checkpoint file (.pth) |
| map_location | str/device | No | Device mapping for loading (default: current device) |
| strict | bool | No | Whether to require exact key matching (PyTorch defaults to True; this repo passes False explicitly) |
Outputs
| Name | Type | Description |
|---|---|---|
| .pth file | File | Serialized model state dict at checkpoint/{model_name}/epoch-{epoch}-iter-{step}.pth |
| TensorBoard log | Scalar | valid_loss logged to log/{model_name}/ |
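One practical caveat about the output path above: torch.save does not create missing parent directories, so the checkpoint/{model_name}/ directory must exist before the first save. A minimal sketch (the save_checkpoint helper name is illustrative, not from the repo):

```python
import os
import torch

def save_checkpoint(model, model_name, epoch, step):
    # torch.save fails if the parent directory is missing,
    # so create checkpoint/{model_name}/ up front.
    os.makedirs(f'checkpoint/{model_name}', exist_ok=True)
    path = f'checkpoint/{model_name}/epoch-{epoch}-iter-{step}.pth'
    torch.save(model.state_dict(), path)
    return path
```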
Usage Examples
Save During Training
import torch
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(f'log/{model_name}')
# Save at interval
if step % checkpoint_interval == 0:
    torch.save(model.state_dict(), f'checkpoint/{model_name}/epoch-{epoch}-iter-{step}.pth')

# Save at epoch end
torch.save(model.state_dict(), f'checkpoint/{model_name}/epoch-{epoch}.pth')
Load for Inference
import torch
from model import MattingRefine
model = MattingRefine(backbone='resnet50')
model.load_state_dict(
    torch.load('checkpoint/model/epoch-10.pth', map_location='cuda'),
    strict=False,
)
model.eval()
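Resuming training follows the same load pattern. Since only model weights are checkpointed (not optimizer state), the optimizer is rebuilt from scratch after restoring. A sketch with a toy model standing in for MattingBase/MattingRefine:

```python
import torch
import torch.nn as nn

# Toy stand-in for the repo's matting models.
model = nn.Linear(4, 1)
torch.save(model.state_dict(), 'epoch-10.pth')  # as a training script would at an epoch boundary

# Resume: rebuild the model, restore weights, then create a fresh optimizer.
resumed = nn.Linear(4, 1)
# map_location='cpu' lets a checkpoint saved on a CUDA device load anywhere.
resumed.load_state_dict(torch.load('epoch-10.pth', map_location='cpu'), strict=False)
optimizer = torch.optim.Adam(resumed.parameters(), lr=1e-4)  # optimizer state is not checkpointed
resumed.train()
```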