
Implementation:Tensorflow Tfjs LayersModel Compile And Fit For Transfer

From Leeroopedia


Metadata

Field Value
Implementation Name Tensorflow Tfjs LayersModel Compile And Fit For Transfer
Library TensorFlow.js
Domains Transfer_Learning, Optimization
Type API Doc (transfer learning with EarlyStopping)
Implements Principle:Tensorflow_Tfjs_Fine_Tuning
Source TensorFlow.js
Last Updated 2026-02-10 00:00 GMT

Environment:Tensorflow_Tfjs_Browser_Runtime

Overview

LayersModel.compile and LayersModel.fit are the TensorFlow.js APIs for configuring and executing model training. In the context of transfer learning, these methods are used with specific configurations: low learning rates, early stopping callbacks, and validation monitoring to fine-tune the transfer model without destroying pretrained representations. The tf.callbacks.earlyStopping factory provides a built-in callback that halts training when a monitored metric stops improving.

Description

The compile method configures the model for training by specifying the optimizer, loss function, and metrics. The fit method executes the training loop, iterating over the data for the specified number of epochs. For transfer learning, these methods are used with careful configuration:

  • compile is called with a low learning rate optimizer (e.g., tf.train.adam(0.0001)) and an appropriate loss function.
  • fit is called with early stopping to prevent overfitting on small target datasets.
  • After optional unfreezing of base layers, compile must be called again to reconfigure the optimizer with the updated set of trainable parameters.

Code Reference

Source files:

  • compile: tfjs-layers/src/training.ts (Lines 583-657)
  • fit: tfjs-layers/src/training.ts (Lines 1464-1667)
  • EarlyStopping: tfjs-layers/src/callbacks.ts (Lines 101-206)
  • tf.callbacks.earlyStopping: tfjs-layers/src/callbacks.ts (Lines 251-253)

API Signatures

// Compile the model for training
compile(args: ModelCompileArgs): void

// Train the model
async fit(
  x: Tensor | Tensor[] | {[inputName: string]: Tensor},
  y: Tensor | Tensor[] | {[inputName: string]: Tensor},
  args?: ModelFitArgs
): Promise<History>

// Early stopping callback factory
tf.callbacks.earlyStopping(args?: EarlyStoppingCallbackArgs): EarlyStopping

interface EarlyStoppingCallbackArgs {
  monitor?: string;            // default 'val_loss'
  minDelta?: number;           // default 0
  patience?: number;           // default 0
  verbose?: number;
  mode?: 'auto' | 'min' | 'max';
  baseline?: number;
  restoreBestWeights?: boolean;
}

Parameters

compile

Parameter Type Required Description
optimizer Optimizer Yes The optimizer instance or name. For transfer learning, use tf.train.adam(0.0001) or similar with a low learning rate.
loss string | string[] | LossOrMetricFn Yes Loss function. Common choices: categoricalCrossentropy (multi-class), binaryCrossentropy (binary), meanSquaredError (regression).
metrics MetricsSpec No Metrics to monitor during training. Common: ['accuracy'].

fit

Parameter Type Required Description
x Tensor | Tensor[] | {[inputName: string]: Tensor} Yes Training input data.
y Tensor | Tensor[] | {[inputName: string]: Tensor} Yes Training target data.
args.epochs number No Maximum number of training epochs. Default: 1. For transfer learning, set to 50+ and rely on early stopping.
args.batchSize number No Number of samples per gradient update. Default: 32.
args.validationSplit number No Fraction of training data to use for validation (0 to 1). Enables early stopping monitoring.
args.callbacks Callback[] No Callbacks for training events. Use tf.callbacks.earlyStopping() for transfer learning.
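The validationSplit parameter above can be sketched in plain JavaScript. One assumption to note (this matches Keras semantics; verify against your TensorFlow.js version): the validation slice is taken from the end of the un-shuffled training arrays, so the ordering of your data determines which samples are held out.

```javascript
// Sketch of how a validationSplit fraction partitions N samples.
// Assumption: the validation slice comes from the END of the data.
function splitCounts(numSamples, validationSplit) {
  if (validationSplit < 0 || validationSplit >= 1) {
    throw new Error('validationSplit must be in [0, 1)');
  }
  // Train on the leading fraction, validate on the trailing fraction.
  const splitAt = Math.floor(numSamples * (1 - validationSplit));
  return {numTrain: splitAt, numVal: numSamples - splitAt};
}

console.log(splitCounts(100, 0.25));  // → { numTrain: 75, numVal: 25 }
```

If your data is ordered (for example, by class), shuffle it before calling fit with validationSplit, or the validation set will not be representative.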

tf.callbacks.earlyStopping

Parameter Type Required Description
monitor string No Metric to monitor. Default: val_loss. With metrics: ['accuracy'], use val_acc for classification tasks (TensorFlow.js records accuracy under the abbreviated key acc).
minDelta number No Minimum change to qualify as an improvement. Default: 0.
patience number No Number of epochs with no improvement before stopping. Default: 0. For transfer learning, use 3-10.
verbose number No Verbosity level (0 = silent, 1 = messages).
mode 'auto' | 'min' | 'max' No Whether the monitored metric should be minimized (min) or maximized (max). Default: auto (inferred from metric name).
baseline number No Baseline value for the monitored metric. Training stops if the metric does not improve beyond this baseline.
restoreBestWeights boolean No Whether to restore model weights from the epoch with the best monitored value. Highly recommended for transfer learning; note, however, that some TensorFlow.js versions throw a NotImplementedError when this flag is set, so verify support in your installed version.
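How monitor, minDelta, and patience interact can be illustrated with a plain-JavaScript sketch of the stopping rule. This is not the library source, just the core decision logic: training stops once the monitored value fails to improve by more than minDelta for patience consecutive epochs.

```javascript
// Illustrative sketch of the early-stopping decision rule.
// mode 'min' (e.g. val_loss): lower is better; mode 'max' (e.g. val_acc):
// higher is better.
function makeEarlyStopper({minDelta = 0, patience = 0, mode = 'min'} = {}) {
  let best = mode === 'min' ? Infinity : -Infinity;
  let wait = 0;  // epochs since last improvement
  const improved = (v) =>
      mode === 'min' ? v < best - minDelta : v > best + minDelta;
  return function shouldStop(monitoredValue) {
    if (improved(monitoredValue)) {
      best = monitoredValue;
      wait = 0;
      return false;
    }
    wait += 1;
    return wait >= patience;
  };
}

// val_loss improves twice, then plateaus; with patience = 2, training
// stops after the second non-improving epoch.
const stop = makeEarlyStopper({patience: 2, mode: 'min'});
console.log([0.9, 0.7, 0.7, 0.7].map(stop));
// → [ false, false, false, true ]
```

With the default patience of 0, the very first non-improving epoch stops training, which is usually too aggressive for noisy validation curves; hence the 3-10 recommendation above.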

Return Values

Method Return Type Description
compile void Configures the model in-place. No return value.
fit Promise<History> A History object containing training and validation loss/metric values for each epoch.
tf.callbacks.earlyStopping EarlyStopping A callback instance to be passed to fit().
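The resolved History object's history field maps each metric name to a per-epoch array. The sketch below shows its shape when the model is compiled with metrics: ['accuracy'] (TensorFlow.js records accuracy under the abbreviated keys acc / val_acc); the numbers are made up for illustration.

```javascript
// Shape of history.history after a 4-epoch run with validation enabled.
const history = {
  loss:     [1.20, 0.80, 0.55, 0.47],
  acc:      [0.55, 0.70, 0.81, 0.84],
  val_loss: [1.10, 0.85, 0.72, 0.78],
  val_acc:  [0.58, 0.68, 0.75, 0.73],
};

// A common post-training step: find the epoch (0-indexed) with the
// lowest validation loss, e.g. to report or to reload a checkpoint.
const bestEpoch = history.val_loss.indexOf(Math.min(...history.val_loss));
console.log(bestEpoch);  // → 2
```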

I/O Contract

Direction Description
Inputs A compiled transfer model, task-specific training data (tensors), low learning rate configuration, and early stopping parameters.
Outputs A Promise<History> containing per-epoch training and validation metrics (loss, accuracy, etc.). The model's trainable weights are updated in-place.
Side Effects Modifies the model's trainable weights through gradient descent. GPU memory is allocated for intermediate tensors during training. Early stopping may terminate training before the specified number of epochs.
Errors Throws if the model has not been compiled, if input/output tensor shapes do not match the model's expected shapes, or if validationSplit is not in [0, 1].
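The error conditions above can be mirrored by a small pre-flight check. The helper below is hypothetical (it is not part of TensorFlow.js, and the compiled flag on the model object is an assumption for illustration); it only shows what kind of validation fit() performs before training starts.

```javascript
// Hypothetical pre-flight check mirroring fit()'s documented errors.
function checkFitArgs(model, args = {}) {
  const errors = [];
  if (!model.compiled) {
    errors.push('model must be compiled before calling fit()');
  }
  const vs = args.validationSplit;
  if (vs !== undefined && (vs < 0 || vs > 1)) {
    errors.push('validationSplit must be in [0, 1]');
  }
  return errors;
}

console.log(checkFitArgs({compiled: false}, {validationSplit: 1.5}));
// → [ 'model must be compiled before calling fit()',
//     'validationSplit must be in [0, 1]' ]
```

In real code, shape mismatches between your tensors and the model surface as thrown exceptions from fit(), so wrap training in try/catch when input shapes come from user data.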

Usage Examples

Example 1: Compile with Low Learning Rate and Train with Early Stopping

// Compile with low learning rate for fine-tuning
transferModel.compile({
  optimizer: tf.train.adam(0.0001),  // Low LR for fine-tuning
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

// Train with early stopping
const history = await transferModel.fit(trainXs, trainYs, {
  epochs: 50,
  batchSize: 32,
  validationSplit: 0.2,
  callbacks: tf.callbacks.earlyStopping({
    monitor: 'val_loss',
    patience: 5,
    restoreBestWeights: true
  })
});

Example 2: Two-Phase Training

// Phase 1: Train head only (base layers frozen)
baseModel.layers.forEach(layer => { layer.trainable = false; });

transferModel.compile({
  optimizer: tf.train.adam(0.001),
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

const phase1History = await transferModel.fit(trainXs, trainYs, {
  epochs: 20,
  batchSize: 32,
  validationSplit: 0.2,
  callbacks: tf.callbacks.earlyStopping({
    monitor: 'val_acc',
    patience: 3,
    mode: 'max',
    restoreBestWeights: true
  })
});

console.log('Phase 1 final val_acc:',
  phase1History.history.val_acc.slice(-1)[0]);

// Phase 2: Unfreeze last few base layers, fine-tune with lower LR
for (let i = baseModel.layers.length - 10; i < baseModel.layers.length; i++) {
  baseModel.layers[i].trainable = true;
}

// IMPORTANT: Recompile after changing trainable flags
transferModel.compile({
  optimizer: tf.train.adam(0.00001),  // Much lower LR
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

const phase2History = await transferModel.fit(trainXs, trainYs, {
  epochs: 30,
  batchSize: 32,
  validationSplit: 0.2,
  callbacks: tf.callbacks.earlyStopping({
    monitor: 'val_loss',
    patience: 5,
    restoreBestWeights: true
  })
});

Example 3: Training with Explicit Validation Data

// Use explicit validation data instead of validationSplit
transferModel.compile({
  optimizer: tf.train.adam(0.0001),
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

const history = await transferModel.fit(trainXs, trainYs, {
  epochs: 50,
  batchSize: 16,
  validationData: [valXs, valYs],
  callbacks: tf.callbacks.earlyStopping({
    monitor: 'val_loss',
    patience: 7,
    minDelta: 0.001,
    restoreBestWeights: true
  })
});

// Inspect training history
console.log('Training loss:', history.history.loss);
console.log('Validation loss:', history.history.val_loss);
console.log('Training accuracy:', history.history.acc);
console.log('Validation accuracy:', history.history.val_acc);

Example 4: Binary Classification Fine-Tuning

// Fine-tuning for binary classification
transferModel.compile({
  optimizer: tf.train.adam(0.0001),
  loss: 'binaryCrossentropy',
  metrics: ['accuracy']
});

const history = await transferModel.fit(trainXs, trainYs, {
  epochs: 30,
  batchSize: 32,
  validationSplit: 0.15,
  callbacks: [
    tf.callbacks.earlyStopping({
      monitor: 'val_loss',
      patience: 5,
      restoreBestWeights: true
    })
  ]
});

Usage

In the transfer learning workflow, compile and fit are used after the transfer model has been constructed and base layers have been frozen:

  1. Compile the model with a low learning rate and appropriate loss/metrics.
  2. Fit the model on the target dataset with early stopping enabled.
  3. Optionally unfreeze base layers, recompile with an even lower learning rate, and fit again (two-phase training).
  4. Inspect the History object to understand training dynamics.

Critical: Always call compile() after changing trainable flags on any layer. The compile step rebuilds the optimizer's internal state to reflect the current set of trainable parameters.

Related Pages

Environments

  • Environment:Tensorflow_Tfjs_Browser_Runtime
