

Implementation:VainF Torch Pruning Serialization

From Leeroopedia


Knowledge Sources
Domains: Pruning, Model_Persistence, Serialization
Last Updated: 2026-02-08 00:00 GMT

Overview

Concrete tool for saving and loading structurally pruned PyTorch models with full metadata preservation, provided by the Torch-Pruning library.

Description

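The serialization module extends standard PyTorch serialization to capture pruning-related metadata that torch.nn.Module.state_dict() does not preserve. When a model is structurally pruned, attributes such as in_channels, out_channels, num_features, and normalized_shape are modified in place. The standard state_dict() saves only parameters and persistent buffers, so these structural changes are lost: the saved tensors already have their pruned shapes, but a freshly constructed model still has the original layer sizes, and its load_state_dict() rejects the checkpoint with size-mismatch errors.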
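The loss described above is easy to reproduce with plain PyTorch. The sketch below hand-prunes a single Conv2d (a hypothetical stand-in for structural pruning, not Torch-Pruning's own code path), then shows that state_dict() holds only the weight and bias tensors and that loading it into a freshly built layer fails.

```python
import torch
import torch.nn as nn

# Hand-prune output channels 0-2 of a conv layer. This is a hypothetical
# stand-in for structural pruning, not Torch-Pruning's own code path.
conv = nn.Conv2d(3, 64, 3, padding=1)
keep = torch.arange(3, 64)  # indices of the 61 surviving output channels
conv.weight = nn.Parameter(conv.weight.data[keep].clone())
conv.bias = nn.Parameter(conv.bias.data[keep].clone())
conv.out_channels = len(keep)  # structural attribute edited in place

sd = conv.state_dict()
print(sorted(sd.keys()))   # ['bias', 'weight']  (no out_channels entry)
print(sd["weight"].shape)  # torch.Size([61, 3, 3, 3])

# A freshly constructed layer still expects 64 output channels.
fresh = nn.Conv2d(3, 64, 3, padding=1)
load_failed = False
try:
    fresh.load_state_dict(sd)
except RuntimeError as err:
    # PyTorch reports shape conflicts as "size mismatch for weight: ..."
    load_failed = "size mismatch" in str(err)
print(load_failed)         # True
```

This is the gap tp.state_dict() is meant to close, by saving attributes such as out_channels alongside the tensors.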

This module provides:

  • state_dict(model): Captures both the full __dict__ of every module and all non-private, non-callable, non-Parameter, non-Tensor attributes. Returns a dictionary with full_state_dict and attributions keys.
  • load_state_dict(model, state_dict): Restores both the __dict__ state and individual attributes for each named module, fully reconstructing the pruned model state.
  • save / load: Convenience aliases for torch.save and torch.load.
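The two-part snapshot in the list above can be approximated in a few lines. The following is a simplified, hypothetical sketch, not the Torch-Pruning source: sketch_state_dict and sketch_load_state_dict are illustrative names, the attribution pass simply filters each module's instance __dict__, and, unlike the behavior described above, the sketch skips each module's _modules table so the target model keeps its own submodule objects. Pruning is done by hand as a stand-in for group.prune(), and the code assumes a PyTorch version whose torch.load accepts weights_only (1.13 or later).

```python
import io

import torch
import torch.nn as nn


def sketch_state_dict(model: nn.Module) -> dict:
    """Hypothetical simplified sketch, not the Torch-Pruning source."""
    full_state_dict, attributions = {}, {}
    for name, module in model.named_modules():
        # Copy the instance __dict__ (parameters, buffers, plain attributes),
        # skipping the child table so the target keeps its own submodules.
        full_state_dict[name] = {k: v for k, v in vars(module).items() if k != "_modules"}
        # Plain structural attributes such as out_channels or num_features.
        attributions[name] = {
            k: v for k, v in vars(module).items()
            if not k.startswith("_") and not callable(v) and not isinstance(v, torch.Tensor)
        }
    return {"full_state_dict": full_state_dict, "attributions": attributions}


def sketch_load_state_dict(model: nn.Module, state: dict) -> nn.Module:
    """Write saved __dict__ entries and attributes back onto same-named modules."""
    for name, module in model.named_modules():
        if name in state["full_state_dict"]:
            module.__dict__.update(state["full_state_dict"][name])
        for k, v in state["attributions"].get(name, {}).items():
            setattr(module, k, v)
    return model


def build() -> nn.Sequential:
    return nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))


# Hand-prune output channels 6 and 7 (a stand-in for group.prune()).
model = build()
keep = torch.arange(6)
conv, bn = model[0], model[1]
conv.weight = nn.Parameter(conv.weight.data[keep].clone())
conv.bias = nn.Parameter(conv.bias.data[keep].clone())
conv.out_channels = 6
bn.weight = nn.Parameter(bn.weight.data[keep].clone())
bn.bias = nn.Parameter(bn.bias.data[keep].clone())
bn.running_mean = bn.running_mean[keep].clone()
bn.running_var = bn.running_var[keep].clone()
bn.num_features = 6

# Round-trip through an in-memory "file"; weights_only=False is needed because
# the snapshot holds ordinary Python objects, not just tensors.
buffer = io.BytesIO()
torch.save(sketch_state_dict(model), buffer)
buffer.seek(0)
restored = sketch_load_state_dict(build(), torch.load(buffer, weights_only=False))

print(restored[0].out_channels, restored[1].num_features)  # 6 6
print(restored(torch.randn(1, 3, 16, 16)).shape)           # torch.Size([1, 6, 16, 16])
```

Skipping _modules is a deliberate simplification here: it keeps the restore logic to a per-module overwrite and avoids swapping out the target model's submodule objects mid-iteration.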

Usage

Import this module when you need to save a structurally pruned model to disk and later reload it with all pruning modifications intact. Use tp.state_dict() instead of model.state_dict() and tp.load_state_dict() instead of model.load_state_dict() whenever working with pruned models.

Code Reference

Source Location

Signature

def state_dict(model: torch.nn.Module) -> dict:
    """Returns a dictionary containing the state, attributions of a module.

    Returns:
        dict with keys:
            'full_state_dict': {name: module.__dict__} for each named module
            'attributions': {name: {attr_name: attr_value}} for non-private,
                non-callable, non-Parameter, non-Tensor attributes
    """

def load_state_dict(model: torch.nn.Module, state_dict: dict) -> torch.nn.Module:
    """Load a model given a state_dict.

    Args:
        model: The model to restore state into.
        state_dict: Dictionary produced by state_dict().

    Returns:
        The model with restored state.
    """

# Convenience aliases
load = torch.load
save = torch.save
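The attributions filter in the docstring above can be illustrated with a small helper that walks dir(module) and keeps values that are not underscore-prefixed, not callable, and not tensors (nn.Parameter is a Tensor subclass, so one isinstance check covers both). plain_attributes is a hypothetical name, and both the use of dir() and treating any leading underscore as "private" are assumptions; the library may draw those lines differently.

```python
import torch
import torch.nn as nn


def plain_attributes(module: nn.Module) -> dict:
    """Hypothetical helper: keep non-private, non-callable, non-tensor attributes."""
    kept = {}
    for attr_name in dir(module):
        if attr_name.startswith("_"):
            continue  # treat underscore names as private
        value = getattr(module, attr_name)
        if callable(value) or isinstance(value, torch.Tensor):
            continue  # methods, Parameters and buffers are excluded
        kept[attr_name] = value
    return kept


bn = nn.BatchNorm2d(61)  # e.g. a BatchNorm left with 61 channels after pruning
attrs = plain_attributes(bn)
print(attrs["num_features"], attrs["eps"])                             # 61 1e-05
print("weight" in attrs, "running_mean" in attrs, "forward" in attrs)  # False False False
```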

Import

import torch_pruning as tp

# Or directly:
from torch_pruning.serialization import state_dict, load_state_dict

I/O Contract

state_dict

Inputs

Name Type Required Description
model torch.nn.Module Yes The pruned model to serialize

Outputs

Name Type Description
return dict Dictionary with 'full_state_dict' (a copy of each module's __dict__) and 'attributions' (the non-private, non-callable, non-Parameter, non-Tensor attributes of each module)

load_state_dict

Inputs

Name Type Required Description
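model torch.nn.Module Yes The model to restore state into; its named modules must match those of the saved model (typically a freshly constructed instance of the original, unpruned architecture)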
state_dict dict Yes Dictionary produced by state_dict() with 'full_state_dict' and 'attributions' keys

Outputs

Name Type Description
return torch.nn.Module The model with fully restored state including pruning metadata

Usage Examples

Saving and Loading a Pruned Model

import torch
import torch.nn as nn
import torch_pruning as tp

# 1. Create and prune a model
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3, padding=1),
)

example_input = torch.randn(1, 3, 32, 32)
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_input)
group = DG.get_pruning_group(model[0], tp.prune_conv_out_channels, idxs=[0, 1, 2])
group.prune()

# 2. Save using tp.state_dict (captures pruning metadata)
state = tp.state_dict(model)
torch.save(state, "pruned_model.pth")

# 3. Load later into a fresh model of the same architecture
model_new = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3, padding=1),
)
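# PyTorch 2.6+ defaults torch.load to weights_only=True, which rejects the
# pickled module objects in this dict; only disable it for trusted files.
saved_state = torch.load("pruned_model.pth", weights_only=False)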
tp.load_state_dict(model_new, saved_state)

# model_new now has the same pruned structure and weights
print(model_new[0].out_channels)  # 61 (64 - 3 pruned)
