Principle:VainF Torch Pruning Pruned Model Serialization
| Knowledge Sources | |
|---|---|
| Domains | Pruning, Model_Persistence, Serialization |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
A serialization strategy that captures both parameter tensors and structural metadata of pruned neural network modules, enabling faithful save/load of models whose architecture has been modified by structural pruning.
Description
Standard PyTorch serialization (model.state_dict()) only saves parameter tensors and registered buffers. This is sufficient for models whose architecture is fixed, but structural pruning modifies the architecture itself: convolution layers have fewer channels, normalization layers have different feature counts, and various module attributes (in_channels, out_channels, num_features, normalized_shape) are changed in-place.
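The sketch below shows the problem on a single `nn.Conv2d`. It cuts the output channels from 16 to 8 by hand, which is a hypothetical stand-in for a real pruning step. After the cut, `state_dict()` holds only the resized `weight` and `bias`. Nothing in it records the new `out_channels`, so loading it into a freshly built layer fails.

```python
import torch
import torch.nn as nn

# Hypothetical manual pruning: keep the first 8 of 16 output channels.
conv = nn.Conv2d(3, 16, kernel_size=3)
keep = torch.arange(8)
conv.weight = nn.Parameter(conv.weight.data[keep].clone())
conv.bias = nn.Parameter(conv.bias.data[keep].clone())
conv.out_channels = 8  # structural attribute, updated in place

sd = conv.state_dict()
print(sorted(sd.keys()))   # ['bias', 'weight'] (no out_channels entry)
print(sd["weight"].shape)  # torch.Size([8, 3, 3, 3])

# A freshly constructed layer still expects 16 output channels.
fresh = nn.Conv2d(3, 16, kernel_size=3)
load_failed = False
try:
    fresh.load_state_dict(sd)
except RuntimeError as err:  # size mismatches are reported as RuntimeError
    load_failed = True
    print("load_state_dict failed:", str(err).splitlines()[0])
```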
Pruned Model Serialization solves this by extending the serialization scope to include:
- Full module state: The entire `__dict__` of each module, capturing all internal state, including modified attributes.
- Non-parameter attributes: All public, non-callable, non-Parameter, non-Tensor attributes discovered via introspection, which captures structural metadata that standard serialization misses.
This ensures that a pruned model can be saved to disk and later loaded with all structural modifications intact, without requiring the pruning operation to be replayed.
Usage
Use this principle whenever you need to persist a structurally pruned model. Saving with `torch.save(model.state_dict())` stores the resized tensors but not the structural metadata. Loading those tensors into a freshly constructed, unpruned model therefore fails with size-mismatch errors, because the fresh model still has the original channel counts. Use this extended serialization instead of standard PyTorch serialization for any model that has undergone structural (not just weight-level) pruning.
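A minimal persistence round trip might look like the sketch below. It rests on three assumptions. First, it uses a hypothetical `build_model()` architecture. Second, a hand-written channel cut stands in for pruning. Third, for brevity it persists only the per-module `__dict__` copies (the first of the two layers listed above) and omits the separately introspected attributes. On built-in layers such as `nn.Conv2d` and `nn.BatchNorm2d`, those attributes already live in `__dict__`.

```python
import io

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical architecture used only for this illustration.
    return nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())


model = build_model()

# Hypothetical pruning step: keep the first 8 channels of the conv and the BN.
keep = torch.arange(8)
conv, bn = model[0], model[1]
conv.weight = nn.Parameter(conv.weight.data[keep].clone())
conv.bias = nn.Parameter(conv.bias.data[keep].clone())
conv.out_channels = 8
bn.weight = nn.Parameter(bn.weight.data[keep].clone())
bn.bias = nn.Parameter(bn.bias.data[keep].clone())
bn.running_mean = bn.running_mean[keep].clone()
bn.running_var = bn.running_var[keep].clone()
bn.num_features = 8

# Save: a shallow copy of every module's __dict__ (tensor containers plus
# structural attributes such as out_channels and num_features).
state = {name: m.__dict__.copy() for name, m in model.named_modules()}
buffer = io.BytesIO()
torch.save(state, buffer)

# Load: the payload holds arbitrary Python objects (including nn.Module
# instances), so weights_only=True (the default since PyTorch 2.6) would
# reject it. Only load such files from sources you trust.
buffer.seek(0)
loaded = torch.load(buffer, weights_only=False)

restored = build_model()  # still has 16 channels at this point
for name, m in restored.named_modules():
    if name in loaded:
        m.__dict__.update(loaded[name])

out = restored(torch.randn(2, 3, 8, 8))
print(restored[0].out_channels, restored[1].num_features)  # 8 8
print(out.shape)  # torch.Size([2, 8, 6, 6])
```

Restoring the root's `__dict__` also swaps in the saved child modules. That is why the restored `Sequential` ends up with the 8-channel layers even though `build_model()` created 16-channel ones.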
Theoretical Basis
The core insight is that a PyTorch module's state has two layers:
- Tensor state: Parameters and buffers (captured by standard `state_dict()`)
- Structural state: Non-tensor attributes that define the module's architecture (NOT captured by standard `state_dict()`)
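The sketch below makes the two layers concrete for an `nn.BatchNorm2d`. It separates what `state_dict()` records from the public, non-callable, non-tensor entries of the module's `__dict__`. Reading `__dict__` through `vars()` instead of `dir()` is a simplification chosen for this illustration.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8)

# Tensor state: parameters and buffers, as recorded by state_dict().
tensor_state = sorted(bn.state_dict().keys())

# Structural state: public, non-callable, non-tensor entries of __dict__.
structural_state = {
    key: value
    for key, value in vars(bn).items()
    if not key.startswith("_")
    and not callable(value)
    and not isinstance(value, torch.Tensor)
}

print(tensor_state)
# ['bias', 'num_batches_tracked', 'running_mean', 'running_var', 'weight']
print(structural_state)  # num_features, eps, momentum, affine, ...
```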
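Pseudo-code Logic. The version below is written as a minimal runnable PyTorch sketch. The function names are illustrative, not a library API.
```python
import torch
import torch.nn as nn

def is_structural_attribute(module: nn.Module, attr: str) -> bool:
    # Filter rules listed below: public, non-callable, not a Tensor
    # (Parameter subclasses Tensor), and not the special T_destination.
    if attr.startswith("_") or attr == "T_destination":
        return False
    value = getattr(module, attr)
    return not callable(value) and not isinstance(value, torch.Tensor)

def extended_state_dict(model: nn.Module) -> dict:
    # Standard serialization, model.state_dict(), keeps only parameters and
    # buffers, which is insufficient for structurally pruned models.
    extended_state = {}
    for name, module in model.named_modules():
        extended_state[name] = {
            # Shallow copy of everything in the module's __dict__.
            "full_dict": module.__dict__.copy(),
            # Structural attributes discovered via introspection.
            "attributes": {
                attr: getattr(module, attr)
                for attr in dir(module)
                if is_structural_attribute(module, attr)
            },
        }
    return extended_state

def load_extended_state_dict(model: nn.Module, extended_state: dict) -> nn.Module:
    # Restoration: overwrite each module's __dict__, then re-apply attributes.
    for name, module in model.named_modules():
        if name in extended_state:
            module.__dict__.update(extended_state[name]["full_dict"])
            for attr, value in extended_state[name]["attributes"].items():
                setattr(module, attr, value)
    return model
```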
The attribute filter must exclude:
- Private attributes (prefixed with `_`)
- Callable attributes (methods; child modules are excluded here too, because `nn.Module` instances are callable)
- Parameter and Tensor attributes (already handled by tensor state)
- Special attributes (like `T_destination`) that should not be serialized