Implementation:Tensorflow Tfjs LayersModel Compile And Fit For Transfer
Metadata
| Field | Value |
|---|---|
| Implementation Name | Tensorflow Tfjs LayersModel Compile And Fit For Transfer |
| Library | TensorFlow.js |
| Domains | Transfer_Learning, Optimization |
| Type | API Doc (transfer learning with EarlyStopping) |
| Implements | Principle:Tensorflow_Tfjs_Fine_Tuning |
| Source | TensorFlow.js |
| Last Updated | 2026-02-10 00:00 GMT |
Environment:Tensorflow_Tfjs_Browser_Runtime
Overview
LayersModel.compile and LayersModel.fit are the TensorFlow.js APIs for configuring and executing model training. In the context of transfer learning, these methods are used with specific configurations: low learning rates, early stopping callbacks, and validation monitoring to fine-tune the transfer model without destroying pretrained representations. The tf.callbacks.earlyStopping factory provides a built-in callback that halts training when a monitored metric stops improving.
Description
The compile method configures the model for training by specifying the optimizer, loss function, and metrics. The fit method executes the training loop, iterating over the data for the specified number of epochs. For transfer learning, these methods are used with careful configuration:
- compile is called with a low learning rate optimizer (e.g., tf.train.adam(0.0001)) and an appropriate loss function.
- fit is called with early stopping to prevent overfitting on small target datasets.
- After optional unfreezing of base layers, compile must be called again to reconfigure the optimizer with the updated set of trainable parameters.
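Adam's first update shows why the learning rate matters so much here. The sketch below works out that first step by hand for a single scalar weight, using the standard Adam update rule. The defaults β1 = 0.9, β2 = 0.999 and ε = 1e-7 are assumptions for illustration, not values read from TensorFlow.js. After bias correction the step comes out close to the learning rate whatever the gradient's magnitude. So the learning rate, not the gradient scale, sets how far pretrained weights move in each early step.

```typescript
// Minimal sketch: the first Adam update for one scalar weight.
// Standard Adam formulas; the hyperparameter defaults are assumptions.
function firstAdamStep(
  grad: number,
  learningRate: number,
  beta1 = 0.9,
  beta2 = 0.999,
  epsilon = 1e-7,
): number {
  // First and second moment estimates after one step, starting from zero.
  const m = (1 - beta1) * grad;
  const v = (1 - beta2) * grad * grad;
  // Bias correction for step t = 1.
  const mHat = m / (1 - beta1);
  const vHat = v / (1 - beta2);
  // Amount that would be subtracted from the weight.
  return (learningRate * mHat) / (Math.sqrt(vHat) + epsilon);
}

// Gradients that differ by a factor of 1000 give steps close to the learning rate.
for (const g of [0.01, 10]) {
  console.log(g, firstAdamStep(g, 0.001), firstAdamStep(g, 0.0001));
}
```

Lowering the learning rate from 0.001 to 0.0001 therefore shrinks each early step by roughly a factor of ten, independent of how large the target-task gradients are.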
Code Reference
Source files:
- compile: tfjs-layers/src/training.ts (Lines 583-657)
- fit: tfjs-layers/src/training.ts (Lines 1464-1667)
- EarlyStopping: tfjs-layers/src/callbacks.ts (Lines 101-206)
- tf.callbacks.earlyStopping: tfjs-layers/src/callbacks.ts (Lines 251-253)
API Signatures
```typescript
// Compile the model for training
compile(args: ModelCompileArgs): void

// Train the model
async fit(
  x: Tensor | Tensor[] | {[inputName: string]: Tensor},
  y: Tensor | Tensor[] | {[inputName: string]: Tensor},
  args?: ModelFitArgs
): Promise<History>

// Early stopping callback factory
tf.callbacks.earlyStopping(args?: EarlyStoppingCallbackArgs): EarlyStopping

interface EarlyStoppingCallbackArgs {
  monitor?: string;               // default 'val_loss'
  minDelta?: number;              // default 0
  patience?: number;              // default 0
  verbose?: number;
  mode?: 'auto' | 'min' | 'max';  // default 'auto'
  baseline?: number;
  restoreBestWeights?: boolean;   // true is not supported yet (throws NotImplementedError)
}
```
Parameters
compile
| Parameter | Type | Required | Description |
|---|---|---|---|
| optimizer | string \| Optimizer | Yes | The optimizer instance or name. For transfer learning, use tf.train.adam(0.0001) or a similar optimizer with a low learning rate. |
| loss | string \| string[] \| LossOrMetricFn | Yes | Loss function. Common choices: categoricalCrossentropy (multi-class), binaryCrossentropy (binary), meanSquaredError (regression). |
| metrics | MetricsSpec | No | Metrics to monitor during training. Common: ['accuracy'], which TensorFlow.js logs under the keys acc and val_acc. |
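Metric names feed directly into the monitor option of early stopping, so it helps to know which keys fit() will record. The hypothetical helper below (historyKeys, not a TensorFlow.js function) encodes the naming rule assumed here for a single-output model. The loss is logged as loss, and the string metrics 'accuracy' and 'acc' are both logged as acc. Validation adds a val_ prefix to each key. Other metric strings are assumed to keep their own name.

```typescript
// Hypothetical helper (not part of TensorFlow.js): predicts the History/log keys
// for a single-output model compiled with the given string metrics.
// 'accuracy' and 'acc' both map to 'acc'; other names are assumed unchanged.
function historyKeys(metrics: string[], hasValidation: boolean): string[] {
  const train = ['loss', ...metrics.map((m) => (m === 'accuracy' || m === 'acc' ? 'acc' : m))];
  // Validation metrics reuse the training keys with a 'val_' prefix.
  const val = hasValidation ? train.map((k) => `val_${k}`) : [];
  return [...train, ...val];
}

console.log(historyKeys(['accuracy'], true));
// → ['loss', 'acc', 'val_loss', 'val_acc']
```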
fit
| Parameter | Type | Required | Description |
|---|---|---|---|
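| x | Tensor \| Tensor[] \| {[inputName: string]: Tensor} | Yes | Training input data. |
| y | Tensor \| Tensor[] \| {[inputName: string]: Tensor} | Yes | Training target data. |
| args.epochs | number | No | Maximum number of training epochs. Default: 1. For transfer learning, set this to 50 or more and let early stopping end training sooner. |
| args.batchSize | number | No | Number of samples per gradient update. Default: 32. |
| args.validationSplit | number | No | Fraction of the training data (between 0 and 1) held out for validation. The held-out samples are taken from the end of x and y, before shuffling. Produces the val_* metrics that early stopping monitors. |
| args.validationData | [Tensor \| Tensor[], Tensor \| Tensor[]] | No | Explicit validation data as [xVal, yVal] (see Example 3). Overrides validationSplit. |
| args.callbacks | Callback[] | No | Callbacks for training events. A single callback instance is also accepted, as in Examples 1 to 3. Use tf.callbacks.earlyStopping() for transfer learning. |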
tf.callbacks.earlyStopping
| Parameter | Type | Required | Description |
|---|---|---|---|
| monitor | string | No | Metric to monitor. Default: val_loss. To monitor validation accuracy, use val_acc (TensorFlow.js logs the 'accuracy' metric under the key acc). |
| minDelta | number | No | Minimum change to qualify as an improvement. Default: 0. |
| patience | number | No | Number of epochs with no improvement after which training stops. Default: 0, which stops at the first epoch without improvement. For transfer learning, use 3-10. |
| verbose | number | No | Verbosity level (0 = silent, 1 = messages). |
| mode | 'auto' \| 'min' \| 'max' | No | Whether the monitored metric should be minimized (min) or maximized (max). Default: auto, which maximizes metrics whose name contains "acc" and minimizes all others. |
| baseline | number | No | Baseline value for the monitored metric. If set, the metric must improve on this value, and training stops if it fails to do so within patience epochs. |
| restoreBestWeights | boolean | No | Whether to restore model weights from the epoch with the best monitored value. Not supported yet in TensorFlow.js: passing true throws a NotImplementedError. Leave it unset; the model keeps the weights from the last completed epoch. |
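The sketch below replays the stopping rule over a list of per-epoch monitor values, showing how patience, minDelta, mode and baseline interact. It reimplements the behaviour described in the table and is not the library source. In auto mode it maximizes when the monitor name contains "acc". An epoch counts as an improvement only if it beats the best value by more than minDelta. When baseline is given, it replaces the initial best value. restoreBestWeights is ignored. The names simulateEarlyStopping and EarlyStoppingSketchArgs are hypothetical.

```typescript
type Mode = 'auto' | 'min' | 'max';

interface EarlyStoppingSketchArgs {
  monitor?: string;
  minDelta?: number;
  patience?: number;
  mode?: Mode;
  baseline?: number;
}

// Returns the 0-based index of the epoch after which training would stop,
// or null if the rule never fires within the given values.
function simulateEarlyStopping(values: number[], args: EarlyStoppingSketchArgs = {}): number | null {
  const monitor = args.monitor ?? 'val_loss';
  const patience = args.patience ?? 0;
  const minDelta = Math.abs(args.minDelta ?? 0);
  const mode = args.mode ?? 'auto';
  // 'auto' maximizes metrics whose name contains "acc" and minimizes everything else.
  const maximize = mode === 'max' || (mode === 'auto' && monitor.indexOf('acc') !== -1);
  // Without a baseline, the first epoch always counts as an improvement.
  let best = args.baseline ?? (maximize ? -Infinity : Infinity);
  let wait = 0;
  for (let epoch = 0; epoch < values.length; epoch++) {
    const current = values[epoch];
    const improved = maximize ? current - minDelta > best : current + minDelta < best;
    if (improved) {
      best = current;
      wait = 0;
    } else {
      wait++;
      if (wait >= patience) {
        return epoch;
      }
    }
  }
  return null;
}

const valLoss = [0.9, 0.7, 0.65, 0.66, 0.655, 0.67];
console.log(simulateEarlyStopping(valLoss, { patience: 2 }));                 // → 4
console.log(simulateEarlyStopping(valLoss, { patience: 2, minDelta: 0.06 })); // → 3
```

With the default patience of 0, the first epoch without improvement ends training. That is usually too aggressive for the noisy validation metrics of small transfer-learning datasets.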
Return Values
| Method | Return Type | Description |
|---|---|---|
| compile | void | Configures the model in place. No return value. |
| fit | Promise<History> | A History object containing training and validation loss/metric values for each epoch. |
| tf.callbacks.earlyStopping | EarlyStopping | A callback instance to be passed to fit(). |
I/O Contract
| Direction | Description |
|---|---|
| Inputs | A compiled transfer model, task-specific training data (tensors), low learning rate configuration, and early stopping parameters. |
| Outputs | A Promise<History> containing per-epoch training and validation metrics (loss, accuracy, etc.). The model's trainable weights are updated in-place. |
| Side Effects | Modifies the model's trainable weights through gradient descent. GPU memory is allocated for intermediate tensors during training. Early stopping may terminate training before the specified number of epochs. |
| Errors | Throws if the model has not been compiled, if input/output tensor shapes do not match the model's expected shapes, or if validationSplit is not in [0, 1]. |
Usage Examples
Example 1: Compile with Low Learning Rate and Train with Early Stopping
```javascript
import * as tf from '@tensorflow/tfjs';

// Assumes transferModel, trainXs and trainYs are already defined.
// Compile with a low learning rate for fine-tuning
transferModel.compile({
  optimizer: tf.train.adam(0.0001), // Low LR for fine-tuning
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

// Train with early stopping
const history = await transferModel.fit(trainXs, trainYs, {
  epochs: 50,
  batchSize: 32,
  validationSplit: 0.2,
  callbacks: tf.callbacks.earlyStopping({
    monitor: 'val_loss',
    patience: 5
    // restoreBestWeights: true is not supported in TensorFlow.js (it throws)
  })
});
```
Example 2: Two-Phase Training
```javascript
// Phase 1: Train head only (base layers frozen)
baseModel.layers.forEach(layer => { layer.trainable = false; });
transferModel.compile({
  optimizer: tf.train.adam(0.001),
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
const phase1History = await transferModel.fit(trainXs, trainYs, {
  epochs: 20,
  batchSize: 32,
  validationSplit: 0.2,
  callbacks: tf.callbacks.earlyStopping({
    monitor: 'val_acc', // 'accuracy' is logged as acc / val_acc
    patience: 3,
    mode: 'max'
  })
});
console.log('Phase 1 final val_acc:',
  phase1History.history.val_acc.slice(-1)[0]);

// Phase 2: Unfreeze the last few base layers, fine-tune with a lower LR
const numToUnfreeze = 10;
for (let i = Math.max(0, baseModel.layers.length - numToUnfreeze);
     i < baseModel.layers.length; i++) {
  baseModel.layers[i].trainable = true;
}

// IMPORTANT: Recompile after changing trainable flags
transferModel.compile({
  optimizer: tf.train.adam(0.00001), // Much lower LR
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
const phase2History = await transferModel.fit(trainXs, trainYs, {
  epochs: 30,
  batchSize: 32,
  validationSplit: 0.2,
  callbacks: tf.callbacks.earlyStopping({
    monitor: 'val_loss',
    patience: 5
  })
});
```
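Phase 2 unfreezes layers by counting back from the end of baseModel.layers. The hypothetical helper below (unfreezeLast) lets you check that index arithmetic in isolation. It works on a minimal { name, trainable } layer shape assumed for illustration. It clamps the start index, so asking for more layers than exist simply unfreezes all of them. After changing flags on real layers, compile() still has to be called again.

```typescript
// Hypothetical minimal layer shape; real TensorFlow.js layers also expose
// `name` and a settable `trainable` property.
interface LayerLike {
  name: string;
  trainable: boolean;
}

// Freezes every layer, then unfreezes the last `count` layers (clamped to the
// number available). Returns the names of the unfrozen layers.
function unfreezeLast(layers: LayerLike[], count: number): string[] {
  const start = Math.max(0, layers.length - Math.max(0, count));
  layers.forEach((layer, i) => {
    layer.trainable = i >= start;
  });
  return layers.slice(start).map((layer) => layer.name);
}

const demoLayers: LayerLike[] = ['conv1', 'conv2', 'conv3', 'dense'].map((name) => ({ name, trainable: false }));
console.log(unfreezeLast(demoLayers, 10)); // → ['conv1', 'conv2', 'conv3', 'dense']
console.log(unfreezeLast(demoLayers, 2));  // → ['conv3', 'dense']
```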
Example 3: Training with Explicit Validation Data
```javascript
// Use explicit validation data instead of validationSplit
transferModel.compile({
  optimizer: tf.train.adam(0.0001),
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
const history = await transferModel.fit(trainXs, trainYs, {
  epochs: 50,
  batchSize: 16,
  validationData: [valXs, valYs],
  callbacks: tf.callbacks.earlyStopping({
    monitor: 'val_loss',
    patience: 7,
    minDelta: 0.001
  })
});

// Inspect training history ('accuracy' is logged as acc / val_acc)
console.log('Training loss:', history.history.loss);
console.log('Validation loss:', history.history.val_loss);
console.log('Training accuracy:', history.history.acc);
console.log('Validation accuracy:', history.history.val_acc);
```
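restoreBestWeights cannot be enabled in TensorFlow.js, so the arrays in history.history are the main way to find the best epoch. The hypothetical helper below (bestEpoch, not a library function) scans a History-like record. It assumes the arrays hold plain numbers. The sample values are made up for illustration.

```typescript
// Hypothetical helper: finds the epoch with the best value of a metric in a
// History-like record. Assumes each entry is an array of plain numbers.
type HistoryRecord = Record<string, number[]>;

function bestEpoch(history: HistoryRecord, key: string, maximize: boolean): { epoch: number; value: number } {
  const values = history[key];
  if (values === undefined || values.length === 0) {
    throw new Error(`No values recorded for "${key}". Available: ${Object.keys(history).join(', ')}`);
  }
  let epoch = 0;
  for (let i = 1; i < values.length; i++) {
    const better = maximize ? values[i] > values[epoch] : values[i] < values[epoch];
    if (better) {
      epoch = i;
    }
  }
  return { epoch, value: values[epoch] };
}

// Made-up numbers shaped like history.history.
const record: HistoryRecord = { val_loss: [0.8, 0.6, 0.62, 0.65], val_acc: [0.7, 0.78, 0.8, 0.79] };
console.log(bestEpoch(record, 'val_loss', false)); // → { epoch: 1, value: 0.6 }
console.log(bestEpoch(record, 'val_acc', true));   // → { epoch: 2, value: 0.8 }
```

The model returned from fit() holds the weights of the last completed epoch. Those can be worse than the best epoch this helper reports.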
Example 4: Binary Classification Fine-Tuning
```javascript
// Fine-tuning for binary classification
transferModel.compile({
  optimizer: tf.train.adam(0.0001),
  loss: 'binaryCrossentropy',
  metrics: ['accuracy']
});
const history = await transferModel.fit(trainXs, trainYs, {
  epochs: 30,
  batchSize: 32,
  validationSplit: 0.15,
  callbacks: [
    tf.callbacks.earlyStopping({
      monitor: 'val_loss',
      patience: 5
    })
  ]
});
```
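For binary classification the output is typically a single sigmoid unit. binaryCrossentropy compares each predicted probability p with a 0/1 label y as -(y·log p + (1 - y)·log(1 - p)), averaged over the batch. The plain-TypeScript sketch below computes that mean. Clipping p to [ε, 1 - ε] with ε = 1e-7 is an assumed numerical guard, not a statement of TensorFlow.js internals.

```typescript
// Mean binary cross-entropy over a batch of predicted probabilities.
// Clipping with epsilon avoids log(0); the value 1e-7 is an assumption.
function binaryCrossentropy(labels: number[], probs: number[], epsilon = 1e-7): number {
  if (labels.length !== probs.length || labels.length === 0) {
    throw new Error('labels and probs must be non-empty and the same length');
  }
  let total = 0;
  for (let i = 0; i < labels.length; i++) {
    const p = Math.min(Math.max(probs[i], epsilon), 1 - epsilon);
    const y = labels[i];
    total += -(y * Math.log(p) + (1 - y) * Math.log(1 - p));
  }
  return total / labels.length;
}

console.log(binaryCrossentropy([1, 0], [0.9, 0.2])); // ≈ 0.1643
```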
Usage
In the transfer learning workflow, compile and fit are used after the transfer model has been constructed and base layers have been frozen:
- Compile the model with a low learning rate and appropriate loss/metrics.
- Fit the model on the target dataset with early stopping enabled.
- Optionally unfreeze base layers, recompile with an even lower learning rate, and fit again (two-phase training).
- Inspect the History object to understand training dynamics.
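Critical: Always call compile() after changing the trainable flag on any layer. compile() collects the model's current list of trainable weights, and fit() updates only the weights collected at the most recent compile(). Layers unfrozen afterwards are therefore not trained until you recompile.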
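The need to recompile can be shown with a deliberately simplified, hypothetical ToyModel (not TensorFlow.js code). In this toy, compile() records the layers that are trainable at that moment, and fit() updates only those layers. Unfreezing a layer between the two calls has no effect until compile() runs again.

```typescript
// Hypothetical toy model: compile() records which layers are trainable,
// and fit() only updates the layers in that recorded list.
interface ToyLayer {
  name: string;
  trainable: boolean;
  weight: number;
}

class ToyModel {
  readonly layers: ToyLayer[];
  private collected: ToyLayer[] = [];

  constructor(layers: ToyLayer[]) {
    this.layers = layers;
  }

  compile(): void {
    // Snapshot the layers that are trainable right now.
    this.collected = this.layers.filter((layer) => layer.trainable);
  }

  fit(step: number): void {
    // Only layers captured at the last compile() are updated.
    for (const layer of this.collected) {
      layer.weight -= step;
    }
  }
}

const model = new ToyModel([
  { name: 'base', trainable: false, weight: 1 },
  { name: 'head', trainable: true, weight: 1 },
]);
model.compile();
model.layers[0].trainable = true; // unfreeze the base...
model.fit(0.5);                   // ...but it is not updated without a recompile
console.log(model.layers.map((l) => l.weight)); // → [1, 0.5]
model.compile();                  // recompiling picks up the new flag
model.fit(0.5);
console.log(model.layers.map((l) => l.weight)); // → [0.5, 0]
```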
Related Pages
- Principle:Tensorflow_Tfjs_Fine_Tuning -- The principle this implementation realizes
- Implementation:Tensorflow_Tfjs_Layer_Trainable_Setter -- Freezing/unfreezing layers before compilation
- Implementation:Tensorflow_Tfjs_Tf_Model_Functional -- Building the transfer model to be trained
- Implementation:Tensorflow_Tfjs_LayersModel_Evaluate_And_Save_For_Transfer -- Evaluating and saving the fine-tuned model
Environments
- Environment:Tensorflow_Tfjs_Browser_Runtime -- Browser runtime (WebGL / WebGPU / WASM / CPU backends)