
Implementation:PeterL1n BackgroundMattingV2 Torch checkpoint ops

From Leeroopedia


Knowledge Sources
Domains Training, Model_Persistence
Last Updated 2026-02-09 00:00 GMT

Overview

Concrete wrapper for PyTorch's torch.save and torch.load checkpoint operations as used in BackgroundMattingV2's training and inference scripts.

Description

BackgroundMattingV2 uses standard PyTorch serialization for checkpoint management. During training, torch.save(model.state_dict(), path) saves the model weights at configurable intervals (every N steps and at epoch boundaries). For loading, model.load_state_dict(torch.load(path, map_location=device), strict=False) restores weights, with strict=False allowing partial loading when model architectures differ slightly (e.g., loading MattingBase weights into MattingRefine). Validation is performed at intervals using a valid() function that computes loss on a held-out subset and logs to TensorBoard.
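The effect of strict=False can be seen in isolation with toy modules (hypothetical stand-ins for MattingBase and MattingRefine, not the repo's classes): keys that match are loaded, and load_state_dict reports the rest instead of raising.

```python
import torch
import torch.nn as nn

class Base(nn.Module):
    """Stand-in for MattingBase: backbone only."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)

class Refine(nn.Module):
    """Stand-in for MattingRefine: same backbone plus an extra head."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)  # shared with Base
        self.refiner = nn.Linear(8, 1)   # absent from Base checkpoints

base, refine = Base(), Refine()

# strict=False copies the matching 'backbone.*' keys and returns a
# report of what could not be matched instead of raising an error.
result = refine.load_state_dict(base.state_dict(), strict=False)
print(result.missing_keys)  # refiner.* keys not found in the checkpoint
```

With strict=True (PyTorch's default) the same call would raise a RuntimeError, which is why the repo passes strict=False when loading base weights into the refine model.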

Usage

Use for saving checkpoints during training and loading them for inference, export, or training resumption. The checkpoint interval is controlled by --checkpoint-interval (default 5000 for base, 2000 for refine).

Code Reference

Source Location

  • Repository: BackgroundMattingV2
  • File: train_base.py (Lines 210-213, 239-258), train_refine.py (Lines 237-241, 278-297)

Signature

# Checkpoint saving (train_base.py:L210-213)
torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

# Checkpoint loading (inference_images.py:L85)
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)

# Validation function (train_base.py:L239-258)
def valid(model, dataloader, writer, step):
    model.eval()
    loss_total = 0
    # ... compute validation loss ...
    writer.add_scalar('valid_loss', loss_total / len(dataloader), step)
    model.train()
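The elided body above can be filled in as a minimal sketch. The batch unpacking, the L1 loss, and the added device argument are illustrative assumptions, not the repo's exact code; the repo's version iterates its matting dataloader and sums its own composite loss. Wrapping the loop in torch.no_grad() avoids building autograd graphs during validation.

```python
import torch
import torch.nn.functional as F

def valid(model, dataloader, writer, step, device='cpu'):
    """Sketch of the validation loop; batch format and loss are placeholders."""
    model.eval()
    loss_total = 0
    with torch.no_grad():  # no gradients needed for validation
        for inputs, targets in dataloader:
            pred = model(inputs.to(device))
            loss_total += F.l1_loss(pred, targets.to(device)).item()
    writer.add_scalar('valid_loss', loss_total / len(dataloader), step)
    model.train()  # restore training mode before returning
```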

Import

import torch
from torch.utils.tensorboard import SummaryWriter

I/O Contract

Inputs

Name Type Required Description
model.state_dict() Dict[str, Tensor] Yes Model parameter dictionary to save
path str Yes Filesystem path for checkpoint file (.pth)
map_location str/device No Device mapping for loading (default: current device)
strict bool No Whether state dict keys must match exactly (PyTorch's default is True; this repo passes strict=False)

Outputs

Name Type Description
.pth file File Serialized model state dict at checkpoint/{model_name}/epoch-{N}-iter-{step}.pth
TensorBoard log Scalar valid_loss logged to log/{model_name}/

Usage Examples

Save During Training

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(f'log/{model_name}')

# Save at interval
if step % checkpoint_interval == 0:
    torch.save(model.state_dict(), f'checkpoint/{model_name}/epoch-{epoch}-iter-{step}.pth')

# Save at epoch end
torch.save(model.state_dict(), f'checkpoint/{model_name}/epoch-{epoch}.pth')

Load for Inference

import torch
from model import MattingRefine

model = MattingRefine(backbone='resnet50')
model.load_state_dict(
    torch.load('checkpoint/model/epoch-10.pth', map_location='cuda'),
    strict=False
)
model.eval()
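Resume Training

The repo's checkpoints hold only model.state_dict(), so resuming restores the weights but restarts the optimizer. A generic combined-checkpoint pattern (an extension for illustration, not the repo's code) saves both so a run can pick up where it left off; io.BytesIO stands in here for a .pth file path.

```python
import io
import torch
import torch.nn as nn

# Toy model/optimizer stand-ins for illustration.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epoch = 3

# Save model and optimizer state together, plus the epoch counter.
buf = io.BytesIO()  # stands in for 'checkpoint/model/last.pth'
torch.save({'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()}, buf)

# Resume: rebuild the objects, then restore both state dicts.
buf.seek(0)
ckpt = torch.load(buf, map_location='cpu')
model2 = nn.Linear(4, 2)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-4)
model2.load_state_dict(ckpt['model'])
optimizer2.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['epoch'] + 1  # continue from the next epoch
```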

Related Pages

Implements Principle

Uses Heuristics
