Implementation:Tensorflow Tfjs LayersModel Fit
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Optimization |
| Last Updated | 2026-02-10 00:00 GMT |
Environment:Tensorflow_Tfjs_Browser_Runtime Environment:Tensorflow_Tfjs_Node_Native_Runtime Heuristic:Tensorflow_Tfjs_Memory_Management_With_Tidy
Overview
Concrete API for training a TensorFlow.js model by iteratively updating weights to minimize the loss function, via LayersModel.fit() for in-memory tensor data and LayersModel.fitDataset() for streaming dataset pipelines.
Description
TensorFlow.js provides two methods for executing the training loop:
LayersModel.fit(x, y, args?) trains the model on in-memory tensors. It accepts feature tensors x and label tensors y along with a ModelFitArgs configuration object. Internally, fit() divides the data into batches, runs the forward pass, computes the loss, performs backpropagation, and updates weights via the compiled optimizer. This process repeats for each batch in each epoch. The method returns a Promise<History> containing per-epoch loss and metric values.
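The per-epoch update count implied by this batching can be sketched in plain JavaScript (no tfjs dependency; the helper names are illustrative, not tfjs API):

```javascript
// Gradient updates fit() performs per epoch: one per batch, where the
// final batch may hold fewer than batchSize samples.
function stepsPerEpoch(numSamples, batchSize) {
  return Math.ceil(numSamples / batchSize);
}

// Total weight updates across a whole training run.
function totalUpdates(numSamples, batchSize, epochs) {
  return stepsPerEpoch(numSamples, batchSize) * epochs;
}

console.log(stepsPerEpoch(1000, 32));    // 32 (31 full batches + one of 8)
console.log(totalUpdates(1000, 32, 10)); // 320
```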
LayersModel.fitDataset(dataset, args) trains the model on a tf.data.Dataset pipeline. This is used for large datasets that do not fit in memory or for datasets that require dynamic generation. The dataset must yield batched objects with xs and ys properties. Since the dataset is consumed lazily, batchesPerEpoch must be specified (or the dataset must be finite).
Both methods are asynchronous and return a Promise. In browser environments, the yieldEvery parameter controls how often the training loop yields to the UI thread, preventing the browser from becoming unresponsive during long training runs.
Key behaviors:
- Validation split: If `validationSplit` is specified, the last N% of the training data is held out for validation (data is not shuffled before splitting, so the caller should pre-shuffle if needed).
- Validation data: Alternatively, explicit validation tensors can be provided via `validationData`.
- Shuffling: When `shuffle: true` (the default), the training data is shuffled at the beginning of each epoch.
- Callbacks: An array of callback objects or callback configurations can be provided to hook into training events.
- Early stopping: The built-in `tf.callbacks.earlyStopping()` monitors a metric and stops training when it stops improving.
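The split point for `validationSplit` is plain index arithmetic over the sample axis; a minimal sketch (the helper name is illustrative):

```javascript
// With validationSplit, the LAST fraction of samples is held out;
// nothing is shuffled before the split.
function trainValBoundary(numSamples, validationSplit) {
  return Math.floor(numSamples * (1 - validationSplit));
}

const boundary = trainValBoundary(100, 0.2);
console.log(boundary); // 80: samples [0, 80) train, [80, 100) validate
```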
Code Reference
Source
Repository: https://github.com/tensorflow/tfjs
| File | Key Locations |
|---|---|
| tfjs-layers/src/engine/training.ts | `fit()` method at L1464–1667 |
| tfjs-layers/src/engine/training_dataset.ts | `fitDataset()` method at L301–500 |
| tfjs-layers/src/engine/training.ts | `History` class and `CustomCallbackArgs` interface |
| tfjs-layers/src/callbacks.ts | Built-in callbacks including `earlyStopping` |
Signature
// In-memory training
async fit(
x: Tensor | Tensor[] | {[inputName: string]: Tensor},
y: Tensor | Tensor[] | {[inputName: string]: Tensor},
args?: ModelFitArgs
): Promise<History>
// ModelFitArgs: {
// batchSize?: number, // default 32
// epochs?: number, // number of training epochs
// verbose?: ModelLoggingVerbosity | 2,
// callbacks?: BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[],
// validationSplit?: number, // 0-1, fraction of data for validation
// validationData?: [Tensor|Tensor[], Tensor|Tensor[]]
// | [Tensor|Tensor[], Tensor|Tensor[], Tensor|Tensor[]],
// shuffle?: boolean, // default true
// classWeight?: ClassWeight | ClassWeight[] | ClassWeightMap,
// sampleWeight?: Tensor,
// initialEpoch?: number, // epoch at which to start (for resuming)
// stepsPerEpoch?: number, // batches per epoch (overrides dataset size)
// validationSteps?: number, // validation batches per epoch
// yieldEvery?: 'auto' | 'batch' | 'epoch' | number | 'never'
// }
// Dataset-based training
async fitDataset<T>(
dataset: Dataset<T>,
args: ModelFitDatasetArgs<T>
): Promise<History>
// ModelFitDatasetArgs<T>: {
// epochs?: number,
// verbose?: ModelLoggingVerbosity | 2,
// callbacks?: BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[],
// validationData?: [Tensor|Tensor[], Tensor|Tensor[]] | Dataset<T>,
// validationBatches?: number,
// validationBatchSize?: number,
// classWeight?: ClassWeight | ClassWeight[] | ClassWeightMap,
// initialEpoch?: number,
// batchesPerEpoch?: number, // required if dataset is infinite
// yieldEvery?: 'auto' | 'batch' | 'epoch' | number | 'never'
// }
Import
import * as tf from '@tensorflow/tfjs';
// Then call:
// const history = await model.fit(xs, ys, {epochs: 100, ...});
// const history = await model.fitDataset(dataset, {epochs: 100, ...});
External Dependencies
- `@tensorflow/tfjs-core` — Provides tensor operations, gradient computation (`tf.grads`), and optimizer implementations used during the training loop.
- `@tensorflow/tfjs-data` — Provides the `Dataset` class consumed by `fitDataset()`.
- `@tensorflow/tfjs-layers` — Contains the `LayersModel` class, training loop logic, callback system, and `History` object.
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | Tensor \| Tensor[] \| Object | Yes (fit) | Feature tensor(s) with shape [numSamples, ...featureShape] |
| y | Tensor \| Tensor[] \| Object | Yes (fit) | Label tensor(s) with shape [numSamples, ...labelShape] |
| dataset | Dataset | Yes (fitDataset) | Batched dataset yielding {xs, ys} objects |
| epochs | number | No (default 1) | Number of complete passes through the training data |
| batchSize | number | No (default 32) | Number of samples per gradient update |
| validationSplit | number | No | Fraction (0–1) of training data to use as validation |
| validationData | Tensors \| Dataset | No | Explicit validation data (overrides validationSplit) |
| callbacks | Object | No | Callback instances or configurations for training hooks |
Outputs
| Name | Type | Description |
|---|---|---|
| return | Promise&lt;History&gt; | Resolves to a History object containing per-epoch training results |
| history.history.loss | number[] | Training loss value for each epoch |
| history.history.acc | number[] | Training accuracy for each epoch (if the 'accuracy' metric was compiled) |
| history.history.val_loss | number[] | Validation loss for each epoch (if validation data was provided) |
| history.history.val_acc | number[] | Validation accuracy for each epoch (if validation data and the 'accuracy' metric) |
Usage Examples
Basic Training with Validation Split
import * as tf from '@tensorflow/tfjs';
// Assume model is already defined and compiled
const history = await model.fit(xs, ys, {
epochs: 100,
batchSize: 32,
validationSplit: 0.2,
callbacks: tf.callbacks.earlyStopping({monitor: 'val_loss', patience: 5})
});
console.log('Final loss:', history.history.loss[history.history.loss.length - 1]);
Training with Explicit Validation Data
import * as tf from '@tensorflow/tfjs';
const trainXs = tf.tensor2d([[0, 0], [0, 1], [1, 0], [1, 1]]);
const trainYs = tf.tensor2d([[0], [1], [1], [0]]);
const valXs = tf.tensor2d([[0.1, 0.1], [0.9, 0.9]]);
const valYs = tf.tensor2d([[0], [0]]);
const history = await model.fit(trainXs, trainYs, {
epochs: 200,
batchSize: 4,
validationData: [valXs, valYs],
shuffle: true
});
Training with Custom Callbacks
import * as tf from '@tensorflow/tfjs';
const history = await model.fit(xs, ys, {
epochs: 50,
batchSize: 64,
validationSplit: 0.15,
callbacks: {
onEpochEnd: (epoch, logs) => {
console.log(`Epoch ${epoch}: loss = ${logs.loss.toFixed(4)}, ` +
`val_loss = ${logs.val_loss.toFixed(4)}`);
},
onTrainEnd: () => {
console.log('Training complete.');
}
}
});
Training with Dataset (Streaming)
import * as tf from '@tensorflow/tfjs';
// Create a batched dataset
// Create a batched dataset; the generator yields {xs, ys} elements directly
const trainDataset = tf.data.generator(function* () {
  for (let i = 0; i < 1000; i++) {
    const features = Array.from({length: 10}, () => Math.random());
    const label = features[0] + features[1] > 1 ? [1, 0] : [0, 1];
    yield {xs: features, ys: label};
  }
}).shuffle(200).batch(32);
const valDataset = tf.data.generator(function* () {
  for (let i = 0; i < 200; i++) {
    const features = Array.from({length: 10}, () => Math.random());
    const label = features[0] + features[1] > 1 ? [1, 0] : [0, 1];
    yield {xs: features, ys: label};
  }
}).batch(32);
const history = await model.fitDataset(trainDataset, {
epochs: 20,
validationData: valDataset,
callbacks: tf.callbacks.earlyStopping({monitor: 'val_loss', patience: 3})
});
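When the dataset is infinite (for example, an endlessly repeating generator), fitDataset() cannot infer where an epoch ends; in tfjs the per-epoch cap is the `batchesPerEpoch` option of `ModelFitDatasetArgs`. A sketch of that rule, using an illustrative helper:

```javascript
// Resolve how many batches make up one epoch. An infinite stream has no
// natural epoch boundary, so an explicit cap is mandatory there.
function resolveBatchesPerEpoch(datasetSize, batchesPerEpoch) {
  if (batchesPerEpoch != null) return batchesPerEpoch;
  if (!Number.isFinite(datasetSize)) {
    throw new Error('batchesPerEpoch is required for an infinite dataset');
  }
  return datasetSize; // one full pass over a finite dataset
}

console.log(resolveBatchesPerEpoch(Infinity, 100)); // 100
console.log(resolveBatchesPerEpoch(31, null));      // 31
```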
Resuming Training from a Checkpoint
import * as tf from '@tensorflow/tfjs';
// First training phase: epochs 0-49
const history1 = await model.fit(xs, ys, {
epochs: 50,
batchSize: 32,
validationSplit: 0.2
});
// Resume training: epochs 50-99
// initialEpoch ensures correct epoch numbering in logs
const history2 = await model.fit(xs, ys, {
epochs: 100,
initialEpoch: 50,
batchSize: 32,
validationSplit: 0.2
});
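Note that `epochs` is the index at which training stops, not a count of additional epochs; with `initialEpoch` set, the number of epochs actually run is the difference (helper name is illustrative):

```javascript
// Epochs executed by a fit() call that resumes at initialEpoch.
function epochsRun(epochs, initialEpoch = 0) {
  return epochs - initialEpoch;
}

console.log(epochsRun(100, 50)); // 50 epochs, logged as epochs 50..99
console.log(epochsRun(50));      // 50 epochs, logged as epochs 0..49
```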
Browser-Friendly Training with UI Yielding
import * as tf from '@tensorflow/tfjs';
// yieldEvery controls how often the training loop yields to the browser event loop
const history = await model.fit(xs, ys, {
epochs: 100,
batchSize: 32,
yieldEvery: 'epoch', // yield control back to browser after each epoch
callbacks: {
onEpochEnd: (epoch, logs) => {
// Update UI with progress (this runs because we yielded)
document.getElementById('status').textContent =
`Epoch ${epoch + 1}/100 - loss: ${logs.loss.toFixed(4)}`;
}
}
});
Related Pages
Implements Principle
Environments
- Environment:Tensorflow_Tfjs_Browser_Runtime -- Browser runtime (WebGL / WebGPU / WASM / CPU backends)
- Environment:Tensorflow_Tfjs_Node_Native_Runtime -- Node.js native runtime (TensorFlow C binding)
Heuristics
- Heuristic:Tensorflow_Tfjs_Memory_Management_With_Tidy -- Wrap predictions in tf.tidy() to prevent memory leaks