
Principle:VainF Torch Pruning Pruned Model Serialization

From Leeroopedia


Knowledge Sources
Domains: Pruning, Model_Persistence, Serialization
Last Updated: 2026-02-08 00:00 GMT

Overview

A serialization strategy that captures both parameter tensors and structural metadata of pruned neural network modules, enabling faithful save/load of models whose architecture has been modified by structural pruning.

Description

Standard PyTorch serialization (model.state_dict()) saves only parameter tensors and persistent buffers. This is sufficient when the architecture is fixed, because the model can be rebuilt from its constructor before the tensors are loaded. Structural pruning, however, modifies the architecture itself: convolution layers have fewer channels, normalization layers have different feature counts, and module attributes such as in_channels, out_channels, num_features and normalized_shape are changed in place. Rebuilding the model from its original constructor therefore produces layers whose shapes no longer match the saved tensors.
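The failure mode can be reproduced in a few lines. The sketch below assumes a toy two-layer model and prunes it by hand; the helper prune_first_channels is hypothetical and only slices tensors, it is not Torch-Pruning's API. The pruned state_dict() holds 4-channel tensors, but a model rebuilt from the original constructor still expects 8 channels, so load_state_dict() raises a size-mismatch RuntimeError.

```python
import torch
import torch.nn as nn

def build_model() -> nn.Sequential:
    # Hypothetical original (unpruned) architecture: 8 output channels
    return nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))

def prune_first_channels(model: nn.Sequential, keep: int) -> None:
    # Hypothetical manual pruning: keep only the first `keep` output channels
    conv, bn = model[0], model[1]
    conv.weight = nn.Parameter(conv.weight.data[:keep].clone())
    conv.bias = nn.Parameter(conv.bias.data[:keep].clone())
    conv.out_channels = keep  # structural attribute, not part of state_dict()
    bn.weight = nn.Parameter(bn.weight.data[:keep].clone())
    bn.bias = nn.Parameter(bn.bias.data[:keep].clone())
    bn.running_mean = bn.running_mean[:keep].clone()
    bn.running_var = bn.running_var[:keep].clone()
    bn.num_features = keep

pruned = build_model()
prune_first_channels(pruned, keep=4)
saved = pruned.state_dict()  # tensors only, first dimension is now 4

fresh = build_model()  # the constructor rebuilds 8-channel layers
try:
    fresh.load_state_dict(saved)
    load_failed = False
except RuntimeError:
    load_failed = True  # size mismatch for 0.weight, 0.bias, 1.weight, ...
print("load failed:", load_failed)
```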

Pruned Model Serialization solves this by extending the serialization scope to include:

  1. Full module state: The entire __dict__ of each module, capturing all internal state including modified attributes.
  2. Non-parameter attributes: All public, non-callable, non-Parameter, non-Tensor attributes discovered via introspection, which captures structural metadata that standard serialization misses.

This ensures that a pruned model can be saved to disk and later loaded with all structural modifications intact, without requiring the pruning operation to be replayed.

Usage

Use this principle whenever you need to persist a structurally pruned model. Saving with torch.save(model.state_dict()) discards the structural metadata, so loading the tensors into a freshly constructed (unpruned) model fails with shape mismatches. Use this extended serialization approach instead of standard PyTorch serialization for any model that has undergone structural pruning. Weight-level (unstructured) pruning only zeroes entries and leaves tensor shapes unchanged, so standard serialization remains sufficient there.

Theoretical Basis

The core insight is that a PyTorch module's state has two layers:

  1. Tensor state: Parameters and buffers (captured by standard state_dict())
  2. Structural state: Non-tensor attributes that define the module's architecture (NOT captured by standard state_dict())
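For a single BatchNorm2d layer, the two layers can be inspected directly. This sketch hand-picks a few structural attribute names for illustration; the principle itself discovers them by introspection rather than from a fixed list.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16)

# Layer 1: tensor state, exactly what state_dict() captures
tensor_state = set(bn.state_dict().keys())

# Layer 2: structural state, plain Python attributes that define the layer
structural_state = {
    attr: getattr(bn, attr)
    for attr in ("num_features", "eps", "momentum", "affine", "track_running_stats")
}

print(sorted(tensor_state))
# ['bias', 'num_batches_tracked', 'running_mean', 'running_var', 'weight']
print(structural_state["num_features"])  # 16
```

The feature count exists only implicitly in the tensor shapes; the attribute that the layer's code actually reads (num_features) lives outside state_dict().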

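Logic sketch (simplified, runnable Python; not the library's exact implementation):

import torch
import torch.nn as nn

def is_structural_attribute(module, attr):
    # Apply the attribute filter listed below: skip private names and T_destination,
    # then skip callables (methods, child modules) and tensors (including Parameters)
    if attr.startswith('_') or attr == 'T_destination':
        return False
    value = getattr(module, attr)
    return not callable(value) and not isinstance(value, torch.Tensor)

# Example model standing in for a pruned network
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

# Standard serialization (insufficient for pruned models): parameters and buffers only
standard_state = model.state_dict()

# Extended serialization for pruned models:
extended_state = {}
for name, module in model.named_modules():
    extended_state[name] = {
        # Shallow copy of everything in the module's __dict__
        'full_dict': module.__dict__.copy(),
        # Public, non-callable, non-tensor attributes found by introspection
        'attributes': {
            attr: getattr(module, attr)
            for attr in dir(module)
            if is_structural_attribute(module, attr)
        }
    }

# Restoration (here into the same model; normally into a freshly constructed one):
for name, module in model.named_modules():
    module.__dict__.update(extended_state[name]['full_dict'])
    for attr, value in extended_state[name]['attributes'].items():
        setattr(module, attr, value)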
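The logic above restores a model in memory. A fuller round trip, under the same assumptions (a toy model, hand-written pruning standing in for a real pruner, and the hypothetical function names extended_state_dict and load_extended_state_dict), saves the extended state with torch.save, loads it with torch.load, and restores it into a freshly constructed, unpruned model. Because the saved object contains pickled module state rather than plain tensors, torch.load needs weights_only=False (the default became True in PyTorch 2.6), which should only be used on trusted files.

```python
import io
import torch
import torch.nn as nn

def is_structural_attribute(module: nn.Module, attr: str) -> bool:
    # Public, non-callable, non-tensor attributes, minus the special T_destination
    if attr.startswith("_") or attr == "T_destination":
        return False
    value = getattr(module, attr)
    return not callable(value) and not isinstance(value, torch.Tensor)

def extended_state_dict(model: nn.Module) -> dict:
    # Capture each module's __dict__ plus its introspected structural attributes
    return {
        name: {
            "full_dict": module.__dict__.copy(),
            "attributes": {a: getattr(module, a) for a in dir(module)
                           if is_structural_attribute(module, a)},
        }
        for name, module in model.named_modules()
    }

def load_extended_state_dict(model: nn.Module, state: dict) -> None:
    # Overwrite each module's state, then re-apply the structural attributes
    for name, module in model.named_modules():
        module.__dict__.update(state[name]["full_dict"])
        for attr, value in state[name]["attributes"].items():
            setattr(module, attr, value)

def build_model() -> nn.Sequential:
    # Hypothetical original architecture with 8 channels
    return nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))

# Hypothetical manual pruning: keep the first 4 output channels
pruned = build_model()
conv, bn = pruned[0], pruned[1]
conv.weight = nn.Parameter(conv.weight.data[:4].clone())
conv.bias = nn.Parameter(conv.bias.data[:4].clone())
conv.out_channels = 4
bn.weight = nn.Parameter(bn.weight.data[:4].clone())
bn.bias = nn.Parameter(bn.bias.data[:4].clone())
bn.running_mean = bn.running_mean[:4].clone()
bn.running_var = bn.running_var[:4].clone()
bn.num_features = 4

# Save to an in-memory "file" and read it back
buffer = io.BytesIO()
torch.save(extended_state_dict(pruned), buffer)
buffer.seek(0)
state = torch.load(buffer, weights_only=False)  # trusted data only

fresh = build_model()  # unpruned architecture: 8 channels
load_extended_state_dict(fresh, state)
out = fresh(torch.randn(2, 3, 8, 8))
print(fresh[0].out_channels, fresh[1].num_features, tuple(out.shape))
```

No pruning step is replayed after loading: the restored model already has 4-channel layers and runs a forward pass with the pruned shapes.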

The attribute filter must exclude:

  • Private attributes (prefixed with _)
  • Callable attributes (methods)
  • Parameter and Tensor attributes (already handled by tensor state)
  • Special class-level attributes such as T_destination (a typing.TypeVar declared on nn.Module), which hold no module state and should not be serialized
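Applying such a filter to a stock Conv2d shows what it keeps. The body of is_structural_attribute here is an assumed implementation of the rules above (only the name comes from the pseudo-code), and the exact list printed varies with the PyTorch version.

```python
import torch
import torch.nn as nn

def is_structural_attribute(module: nn.Module, attr: str) -> bool:
    if attr.startswith("_"):        # private attributes
        return False
    if attr == "T_destination":     # special attribute, never serialized
        return False
    value = getattr(module, attr)
    if callable(value):             # methods (and child modules, which are callable)
        return False
    # nn.Parameter subclasses torch.Tensor, so one check excludes both
    return not isinstance(value, torch.Tensor)

conv = nn.Conv2d(3, 8, kernel_size=3)
kept = sorted(a for a in dir(conv) if is_structural_attribute(conv, a))
print(kept)
```

Hyperparameters such as in_channels, out_channels and kernel_size survive, while weight and bias (Parameters), forward (a method) and _parameters (private) are dropped.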
