Implementation:VainF Torch Pruning Serialization
| Knowledge Sources | |
|---|---|
| Domains | Pruning, Model_Persistence, Serialization |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Concrete tool for saving and loading structurally pruned PyTorch models with full metadata preservation, provided by the Torch-Pruning library.
Description
The serialization module extends standard PyTorch serialization to capture pruning-related metadata that torch.nn.Module.state_dict() does not preserve. When a model is structurally pruned, attributes such as in_channels, out_channels, num_features, and normalized_shape are modified in place. The standard state_dict() saves only parameters and persistent buffers, so these structural attributes are lost. Loading the pruned tensors into a freshly constructed, unpruned model then fails with size-mismatch errors.
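The failure mode can be reproduced with plain PyTorch alone. The sketch below imitates what structural pruning does to a single Conv2d by editing weight, bias and out_channels by hand. This hand edit is an assumption standing in for a real Torch-Pruning call. The sketch then tries to load the standard state_dict into a freshly built layer.

```python
import torch
import torch.nn as nn

# Simulate structural pruning by hand (assumption: this mirrors the in-place
# edits a pruner makes): drop output channels 0, 1, 2 of a 64-channel conv.
conv = nn.Conv2d(3, 64, 3, padding=1)
keep = torch.arange(3, 64)
conv.weight = nn.Parameter(conv.weight.data[keep].clone())
conv.bias = nn.Parameter(conv.bias.data[keep].clone())
conv.out_channels = len(keep)

# The standard state_dict stores only tensors; out_channels is not in it.
sd = conv.state_dict()
print(sorted(sd))  # ['bias', 'weight']

# Loading into an unpruned layer fails because the tensor shapes differ.
fresh = nn.Conv2d(3, 64, 3, padding=1)
try:
    fresh.load_state_dict(sd)
    load_error = None
except RuntimeError as err:
    load_error = str(err)
print("size mismatch" in load_error)  # True
```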
This module provides:
- state_dict(model): Captures both the full __dict__ of every module and all non-private, non-callable, non-Parameter, non-Tensor attributes. Returns a dictionary with full_state_dict and attributions keys.
- load_state_dict(model, state_dict): Restores both the __dict__ state and the individual attributes of each named module, fully reconstructing the pruned model.
- save / load: Convenience aliases for torch.save and torch.load.
Usage
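Import this module when you need to save a structurally pruned model to disk and later reload it with all pruning modifications intact. Use tp.state_dict() instead of model.state_dict(), and tp.load_state_dict() instead of model.load_state_dict(), whenever you work with pruned models. The saved dictionary holds module objects and other arbitrary Python state rather than plain tensors. On PyTorch 2.6 and later, where weights_only=True is the default, reading it back therefore requires torch.load(..., weights_only=False).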
Code Reference
Source Location
- Repository: VainF_Torch_Pruning
- File: torch_pruning/serialization.py
- Lines: 1-43
Signature
```python
import torch

def state_dict(model: torch.nn.Module) -> dict:
    """Returns a dictionary containing the state and attributions of a module.

    Returns:
        dict with keys:
            'full_state_dict': {name: module.__dict__} for each named module
            'attributions': {name: {attr_name: attr_value}} for non-private,
                non-callable, non-Parameter, non-Tensor attributes
    """

def load_state_dict(model: torch.nn.Module, state_dict: dict) -> torch.nn.Module:
    """Load a model given a state_dict.

    Args:
        model: The model to restore state into.
        state_dict: Dictionary produced by state_dict().

    Returns:
        The model with restored state.
    """

# Convenience aliases
load = torch.load
save = torch.save
```
Import
```python
import torch_pruning as tp
# Or directly:
from torch_pruning.serialization import state_dict, load_state_dict
```
I/O Contract
state_dict
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | torch.nn.Module | Yes | The pruned model to serialize |
Outputs
| Name | Type | Description |
|---|---|---|
| return | dict | Dictionary with 'full_state_dict' (a copy of each module's __dict__, keyed by module name) and 'attributions' (the non-private, non-callable, non-Parameter, non-Tensor attributes of each module) |
load_state_dict
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | torch.nn.Module | Yes | The model to restore state into; must be built from the same architecture (same module names) as the saved model, and may be an unpruned instance |
| state_dict | dict | Yes | Dictionary produced by state_dict() with 'full_state_dict' and 'attributions' keys |
Outputs
| Name | Type | Description |
|---|---|---|
| return | torch.nn.Module | The model with fully restored state including pruning metadata |
Usage Examples
Saving and Loading a Pruned Model
```python
import torch
import torch.nn as nn
import torch_pruning as tp

# 1. Create and prune a model
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3, padding=1),
)
example_input = torch.randn(1, 3, 32, 32)
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_input)
# Remove output channels 0, 1, 2 of the first conv; the dependency graph also
# prunes the coupled BatchNorm2d and the input channels of the second conv.
group = DG.get_pruning_group(model[0], tp.prune_conv_out_channels, idxs=[0, 1, 2])
group.prune()

# 2. Save using tp.state_dict (captures pruning metadata)
state = tp.state_dict(model)
torch.save(state, "pruned_model.pth")

# 3. Load later into a fresh model of the same (unpruned) architecture
model_new = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3, padding=1),
)
# weights_only=False: the file stores module objects, not just tensors
# (required on PyTorch >= 2.6, where weights_only=True is the default).
saved_state = torch.load("pruned_model.pth", weights_only=False)
tp.load_state_dict(model_new, saved_state)

# model_new now has the same pruned structure and weights
print(model_new[0].out_channels)  # 61 (64 - 3 pruned)
```