Principle:Tensorflow Tfjs Pretrained Weight Loading
Summary
Pretrained weight loading is the process of loading previously trained weights from stored artifacts into a constructed model architecture. This is a library-agnostic concept: the loader deserializes binary weight data and assigns the values to the corresponding layers and parameters in an existing model graph.
Theory
Weight loading is a critical step in transfer learning and model deployment. It enables reuse of weights trained on large datasets and expensive hardware by loading them into a matching architecture at inference time.
The process involves five key steps:
- Fetch weight manifest: Download or read a JSON file describing weight names, shapes, and data types.
- Download binary weight shards: Retrieve one or more binary files containing the serialized weight values.
- Deserialize binary data: Convert the raw binary data into typed arrays (e.g., Float32Array) matching the declared dtypes.
- Match weight names to model parameters: Align each weight tensor from the manifest with the corresponding layer parameter in the model graph.
- Assign values to layer variables: Set the model's trainable and non-trainable variables to the loaded weight values.
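Steps 3–5 above can be sketched in plain JavaScript, independent of any framework. This is a minimal illustration, not TensorFlow.js's actual implementation; the weight names and the assumption that all dtypes are 4 bytes wide are simplifications for the example.

```javascript
// Deserialize a binary buffer into named typed arrays according to a
// weight manifest's specs (illustrative sketch, not the tfjs internals).
const DTYPE_BYTES = { float32: 4, int32: 4 };

function decodeWeights(buffer, specs) {
  // specs: [{name, shape, dtype}] in the order the values appear in the buffer.
  const out = {};
  let offset = 0;
  for (const { name, shape, dtype } of specs) {
    const numElements = shape.reduce((a, b) => a * b, 1);
    // Create a typed view over this weight's slice of the buffer.
    const view = dtype === 'float32'
      ? new Float32Array(buffer, offset, numElements)
      : new Int32Array(buffer, offset, numElements);
    // Copy out of the shared buffer so each weight owns its values;
    // a real loader would assign these to the model's layer variables.
    out[name] = { shape, dtype, values: view.slice() };
    offset += numElements * DTYPE_BYTES[dtype];
  }
  return out;
}
```

The manifest's declared shapes and dtypes fully determine how the flat byte stream is partitioned, which is why no per-weight framing is needed inside the binary shards themselves.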
Weight Manifest Format
The weight manifest is a JSON structure that describes the layout of weights across binary shard files:
| Field | Description |
|---|---|
| weightsManifest | Array of weight groups, each specifying a set of binary file paths and weight specs |
| paths | List of binary shard file names for this group |
| weights | Array of weight specifications (name, shape, dtype) in this group |
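A hypothetical manifest with one weight group, shown as a JavaScript object for illustration. The file names, weight names, and shapes are invented for the example; real manifests are produced by the model converter.

```javascript
// Illustrative weightsManifest structure (names and shapes are made up).
const weightsManifest = [
  {
    // Binary shard files for this group, downloaded and concatenated in order.
    paths: ['group1-shard1of2.bin', 'group1-shard2of2.bin'],
    // Specs describing how to slice the concatenated bytes into tensors.
    weights: [
      { name: 'dense/kernel', shape: [784, 128], dtype: 'float32' },
      { name: 'dense/bias', shape: [128], dtype: 'float32' }
    ]
  }
];
```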
Strict vs. Non-Strict Loading
| Mode | Behavior | Use Case |
|---|---|---|
| Strict (default) | All model weights must have matching entries in the loaded data; extra or missing weights cause errors | Full model loading, ensuring integrity |
| Non-strict | Allows partial loading where some weights may be missing or extra | Transfer learning, fine-tuning from a related model |
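The difference between the two modes comes down to how name mismatches are handled. A minimal sketch of the matching logic, assuming simple flat name lists (not TensorFlow.js's real matching code):

```javascript
// Compare a model's variable names against the loaded weight names.
// In strict mode, any missing or extra weight is an error; in non-strict
// mode, only the intersection is assigned (e.g., for transfer learning).
function matchWeights(modelVarNames, loadedWeights, strict = true) {
  const loadedNames = Object.keys(loadedWeights);
  const missing = modelVarNames.filter(n => !loadedNames.includes(n));
  const extra = loadedNames.filter(n => !modelVarNames.includes(n));
  if (strict && (missing.length > 0 || extra.length > 0)) {
    throw new Error(
      `Weight mismatch. Missing: [${missing}] Extra: [${extra}]`);
  }
  // Names that can actually be assigned.
  return modelVarNames.filter(n => loadedNames.includes(n));
}
```

In TensorFlow.js itself, this choice is exposed via the load options (e.g., a `strict` flag when loading a layers model), so partial checkpoints from a related architecture can still be applied.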
Key Properties
- Architecture-agnostic storage: Weights are stored as named tensors independent of the framework's internal graph representation.
- Sharded loading: A large model's weights can be split across multiple binary files so the shards can be downloaded in parallel.
- Progress tracking: Loading progress can be monitored via callback functions for user feedback.
- Cross-platform: Weight files can be generated on one platform (e.g., Python TensorFlow) and loaded on another (e.g., TensorFlow.js in the browser).
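Sharded loading and progress tracking can be sketched together: shard buffers are concatenated in manifest order into one contiguous buffer, reporting progress after each shard. This is an illustration assuming the shards are already in memory; a real loader would fetch them over the network first.

```javascript
// Concatenate shard ArrayBuffers in order and report fractional progress.
// (Illustrative sketch; real loaders fetch shards before this step.)
function concatenateShards(shardBuffers, onProgress) {
  const total = shardBuffers.reduce((n, b) => n + b.byteLength, 0);
  const out = new Uint8Array(total);
  let offset = 0;
  shardBuffers.forEach((buf, i) => {
    out.set(new Uint8Array(buf), offset);
    offset += buf.byteLength;
    // Invoke the user-supplied callback with a fraction in (0, 1].
    if (onProgress) onProgress((i + 1) / shardBuffers.length);
  });
  return out.buffer;
}
```

Because the shards are concatenated in manifest order, the resulting buffer can be sliced directly using the manifest's declared shapes and dtypes.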
Implementation
Implementation:Tensorflow_Tfjs_Tf_LoadLayersModel