Implementation:Tensorflow Tfjs LayersModel Fit

From Leeroopedia


Knowledge Sources
Domains Deep_Learning, Optimization
Last Updated 2026-02-10 00:00 GMT

Environment:Tensorflow_Tfjs_Browser_Runtime Environment:Tensorflow_Tfjs_Node_Native_Runtime Heuristic:Tensorflow_Tfjs_Memory_Management_With_Tidy

Overview

Concrete API for training a TensorFlow.js model by iteratively updating weights to minimize the loss function, via LayersModel.fit() for in-memory tensor data and LayersModel.fitDataset() for streaming dataset pipelines.

Description

TensorFlow.js provides two methods for executing the training loop:

LayersModel.fit(x, y, args?) trains the model on in-memory tensors. It accepts feature tensors x and label tensors y along with a ModelFitArgs configuration object. Internally, fit() divides the data into batches, runs the forward pass, computes the loss, performs backpropagation, and updates weights via the compiled optimizer. This process repeats for each batch in each epoch. The method returns a Promise<History> containing per-epoch loss and metric values.
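The batching step described above can be pictured with a small, dependency-free sketch. The helper name batchRanges is hypothetical and is not part of the TensorFlow.js API. It assumes the final batch of an epoch is simply shorter when the sample count is not a multiple of batchSize.

```typescript
// Hypothetical helper, not a TensorFlow.js API: computes the [start, end)
// sample ranges that one epoch would cover for a given batch size.
function batchRanges(numSamples: number, batchSize: number): Array<[number, number]> {
  if (!Number.isInteger(numSamples) || numSamples < 0) {
    throw new Error('numSamples must be a non-negative integer');
  }
  if (!Number.isInteger(batchSize) || batchSize < 1) {
    throw new Error('batchSize must be a positive integer');
  }
  const ranges: Array<[number, number]> = [];
  for (let start = 0; start < numSamples; start += batchSize) {
    // Assumption: the last batch is truncated rather than dropped or padded.
    ranges.push([start, Math.min(start + batchSize, numSamples)]);
  }
  return ranges;
}

// 100 samples with the default batch size of 32:
// three full batches plus one final batch of 4 samples.
console.log(batchRanges(100, 32).length); // 4
```

Under this assumption, each epoch performs one weight update per range, so a smaller batchSize means more updates per epoch.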

LayersModel.fitDataset(dataset, args) trains the model on a tf.data.Dataset pipeline. It suits datasets that are too large to fit in memory or that are generated on the fly. The dataset must yield batched objects with xs and ys properties. Because the dataset is consumed lazily, batchesPerEpoch must be specified unless the dataset is finite.

Both methods are asynchronous and return a Promise. In browser environments, the yieldEvery parameter controls how often the training loop yields the main thread back to the browser's event loop, which keeps the page responsive during long training runs.

Key behaviors:

  • Validation split: If validationSplit is specified, that fraction of the samples is taken from the end of the training data and held out for validation. The data is not shuffled before splitting, so the caller should pre-shuffle ordered data.
  • Validation data: Alternatively, explicit validation tensors can be provided via validationData.
  • Shuffling: When shuffle: true (default), the training data is shuffled at the beginning of each epoch.
  • Callbacks: An array of callback objects or callback configurations can be provided to hook into training events.
  • Early stopping: The built-in tf.callbacks.earlyStopping() monitors a metric and stops training when it stops improving.
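The validation split behavior in the list above can be illustrated with plain index arithmetic. This is a sketch only: the helper name splitRanges is hypothetical, and the split point floor(numSamples * (1 - validationSplit)) is an assumption about how the boundary is rounded.

```typescript
// Index ranges produced by a tail split, as [start, end) pairs.
interface SplitRanges {
  train: [number, number];      // samples used for weight updates
  validation: [number, number]; // samples held out from the end of the data
}

// Hypothetical helper, not a TensorFlow.js API.
function splitRanges(numSamples: number, validationSplit: number): SplitRanges {
  if (!(validationSplit > 0 && validationSplit < 1)) {
    throw new Error('validationSplit must be strictly between 0 and 1');
  }
  // Assumption: the boundary is rounded down, so validation takes the remainder.
  const splitAt = Math.floor(numSamples * (1 - validationSplit));
  return {
    train: [0, splitAt],
    validation: [splitAt, numSamples],
  };
}

const ranges = splitRanges(1000, 0.2);
console.log(ranges.train);      // [0, 800]
console.log(ranges.validation); // [800, 1000]
```

Because the held-out samples always come from the tail, data sorted by class or by time would yield a validation set that is not representative, which is why pre-shuffling matters.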

Code Reference

Source

Repository: https://github.com/tensorflow/tfjs

File Key Locations
tfjs-layers/src/engine/training.ts fit() method at L1464–1667
tfjs-layers/src/engine/training_dataset.ts fitDataset() method at L301–500
tfjs-layers/src/base_callbacks.ts History class and CustomCallbackArgs interface
tfjs-layers/src/callbacks.ts Built-in callbacks including earlyStopping

Signature

// In-memory training
async fit(
  x: Tensor | Tensor[] | {[inputName: string]: Tensor},
  y: Tensor | Tensor[] | {[inputName: string]: Tensor},
  args?: ModelFitArgs
): Promise<History>

// ModelFitArgs: {
//   batchSize?: number,          // default 32
//   epochs?: number,             // number of training epochs
//   verbose?: ModelLoggingVerbosity | 2,
//   callbacks?: BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[],
//   validationSplit?: number,    // 0-1, fraction of data for validation
//   validationData?: [Tensor|Tensor[], Tensor|Tensor[]]
//                  | [Tensor|Tensor[], Tensor|Tensor[], Tensor|Tensor[]],
//   shuffle?: boolean,           // default true
//   classWeight?: ClassWeight | ClassWeight[] | ClassWeightMap,
//   sampleWeight?: Tensor,
//   initialEpoch?: number,       // epoch at which to start (for resuming)
//   stepsPerEpoch?: number,      // batches per epoch (overrides dataset size)
//   validationSteps?: number,    // validation batches per epoch
//   yieldEvery?: 'auto' | 'batch' | 'epoch' | number | 'never'
// }

// Dataset-based training
async fitDataset<T>(
  dataset: Dataset<T>,
  args: ModelFitDatasetArgs<T>
): Promise<History>

// ModelFitDatasetArgs<T>: {
//   epochs?: number,
//   verbose?: ModelLoggingVerbosity | 2,
//   callbacks?: BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[],
//   validationData?: [Tensor|Tensor[], Tensor|Tensor[]] | Dataset<T>,
//   validationBatches?: number,
//   validationBatchSize?: number,
//   classWeight?: ClassWeight | ClassWeight[] | ClassWeightMap,
//   initialEpoch?: number,
//   batchesPerEpoch?: number,    // required if dataset is infinite
//   yieldEvery?: 'auto' | 'batch' | 'epoch' | number | 'never'
// }

Import

import * as tf from '@tensorflow/tfjs';

// Then call:
// const history = await model.fit(xs, ys, {epochs: 100, ...});
// const history = await model.fitDataset(dataset, {epochs: 100, ...});

External Dependencies

  • @tensorflow/tfjs-core: Provides tensor operations, gradient computation (tf.variableGrads, which the optimizers' minimize() relies on), and the optimizer implementations used during the training loop.
  • @tensorflow/tfjs-data: Provides the Dataset class consumed by fitDataset().
  • @tensorflow/tfjs-layers: Contains the LayersModel class, training loop logic, callback system, and History object.

I/O Contract

Inputs

Name Type Required Description
x Tensor | Tensor[] | Object Yes (fit) Feature tensor(s) with shape [numSamples, ...featureShape]
y Tensor | Tensor[] | Object Yes (fit) Label tensor(s) with shape [numSamples, ...labelShape]
dataset Dataset Yes (fitDataset) Batched dataset yielding {xs, ys} objects
epochs number No (default 1) Number of complete passes through the training data
batchSize number No (default 32) Number of samples per gradient update (fit only)
validationSplit number No Fraction (0–1) of training data to use as validation (fit only)
validationData Array | Dataset No Explicit validation data: a tensor tuple [xs, ys] for fit, or a tensor tuple or Dataset for fitDataset (overrides validationSplit)
callbacks Object | Array No Callback instances or configurations for training hooks

Outputs

Name Type Description
return Promise<History> Resolves to a History object containing per-epoch training results
history.history.loss number[] Training loss value for each epoch
history.history.acc number[] Training accuracy for each epoch (if 'accuracy' metric was compiled)
history.history.val_loss number[] Validation loss for each epoch (if validation data was provided)
history.history.val_acc number[] Validation accuracy for each epoch (if validation data and 'accuracy' metric)
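As an illustration of consuming these outputs, the sketch below finds the epoch with the lowest value of a recorded metric. The HistoryLike type and the bestEpoch helper are hypothetical. They assume the per-epoch values are plain numbers, as listed in the table above, and the sample values are made up for the demonstration.

```typescript
// Assumed shape: a plain object of numeric arrays, mirroring the
// history.history fields listed in the Outputs table.
type HistoryLike = { [metric: string]: number[] };

// Hypothetical helper: returns the zero-based epoch with the smallest value
// of the given metric (suitable for losses, where lower is better).
function bestEpoch(history: HistoryLike, metric: string): { epoch: number; value: number } {
  const values = history[metric];
  if (values === undefined || values.length === 0) {
    throw new Error(`No values recorded for "${metric}"`);
  }
  let epoch = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] < values[epoch]) {
      epoch = i;
    }
  }
  return { epoch, value: values[epoch] };
}

// Made-up values standing in for history.history after three epochs.
const sample: HistoryLike = {
  loss: [0.9, 0.5, 0.3],
  val_loss: [0.8, 0.4, 0.6],
};
console.log(bestEpoch(sample, 'val_loss')); // { epoch: 1, value: 0.4 }
```

A training loss that keeps falling while val_loss bottoms out earlier, as in the sample values, is the usual sign of overfitting.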

Usage Examples

Basic Training with Validation Split

import * as tf from '@tensorflow/tfjs';

// Assume model is already defined and compiled
const history = await model.fit(xs, ys, {
  epochs: 100,
  batchSize: 32,
  validationSplit: 0.2,
  callbacks: tf.callbacks.earlyStopping({monitor: 'val_loss', patience: 5})
});

console.log('Final loss:', history.history.loss[history.history.loss.length - 1]);
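The patience setting used above can be understood through a simplified sketch of patience-based stopping for a metric that should decrease, such as val_loss. This is not the library's implementation. The function name stoppingEpoch and its minDelta handling are illustrative assumptions.

```typescript
// Simplified sketch of patience-based early stopping (lower is better).
// Returns the zero-based epoch at which training would stop, or null if the
// recorded values never trigger a stop.
function stoppingEpoch(valLosses: number[], patience: number, minDelta = 0): number | null {
  let best = Infinity; // best value seen so far
  let wait = 0;        // consecutive epochs without improvement
  for (let epoch = 0; epoch < valLosses.length; epoch++) {
    if (valLosses[epoch] < best - minDelta) {
      // Improvement: remember it and reset the counter.
      best = valLosses[epoch];
      wait = 0;
    } else {
      wait++;
      if (wait >= patience) {
        return epoch;
      }
    }
  }
  return null;
}

// Made-up validation losses: the best value is at epoch 2, followed by
// three epochs without improvement, so patience 3 stops at epoch 5.
console.log(stoppingEpoch([1.0, 0.8, 0.7, 0.71, 0.72, 0.73], 3)); // 5
```

A larger patience tolerates noisy validation curves at the cost of running more epochs after the best one.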

Training with Explicit Validation Data

import * as tf from '@tensorflow/tfjs';

const trainXs = tf.tensor2d([[0, 0], [0, 1], [1, 0], [1, 1]]);
const trainYs = tf.tensor2d([[0], [1], [1], [0]]);
const valXs = tf.tensor2d([[0.1, 0.1], [0.9, 0.9]]);
const valYs = tf.tensor2d([[0], [0]]);

const history = await model.fit(trainXs, trainYs, {
  epochs: 200,
  batchSize: 4,
  validationData: [valXs, valYs],
  shuffle: true
});

Training with Custom Callbacks

import * as tf from '@tensorflow/tfjs';

const history = await model.fit(xs, ys, {
  epochs: 50,
  batchSize: 64,
  validationSplit: 0.15,
  callbacks: {
    onEpochEnd: (epoch, logs) => {
      console.log(`Epoch ${epoch}: loss = ${logs.loss.toFixed(4)}, ` +
                  `val_loss = ${logs.val_loss.toFixed(4)}`);
    },
    onTrainEnd: () => {
      console.log('Training complete.');
    }
  }
});

Training with Dataset (Streaming)

import * as tf from '@tensorflow/tfjs';

// Create a batched dataset. A generator function yields each element
// directly as {xs, ys}; it must not be wrapped in {value: ...}.
// Assumes a compiled model with input shape [10] and 2 output units.
const trainDataset = tf.data.generator(function* () {
  for (let i = 0; i < 1000; i++) {
    const features = Array.from({length: 10}, () => Math.random());
    const label = features[0] + features[1] > 1 ? [1, 0] : [0, 1];
    yield {xs: features, ys: label};
  }
}).shuffle(200).batch(32);

const valDataset = tf.data.generator(function* () {
  for (let i = 0; i < 200; i++) {
    const features = Array.from({length: 10}, () => Math.random());
    const label = features[0] + features[1] > 1 ? [1, 0] : [0, 1];
    yield {xs: features, ys: label};
  }
}).batch(32);

const history = await model.fitDataset(trainDataset, {
  epochs: 20,
  validationData: valDataset,
  callbacks: tf.callbacks.earlyStopping({monitor: 'val_loss', patience: 3})
});

Resuming Training from a Checkpoint

import * as tf from '@tensorflow/tfjs';

// First training phase: epochs 0-49
const history1 = await model.fit(xs, ys, {
  epochs: 50,
  batchSize: 32,
  validationSplit: 0.2
});

// Resume training: epochs 50-99
// initialEpoch ensures correct epoch numbering in logs
const history2 = await model.fit(xs, ys, {
  epochs: 100,
  initialEpoch: 50,
  batchSize: 32,
  validationSplit: 0.2
});
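The epoch numbering in the resume example can be checked with a small sketch. It assumes, consistent with the comments above, that epochs acts as the exclusive upper bound of the epoch index rather than as a count of additional epochs. The helper name epochsToRun is hypothetical.

```typescript
// Hypothetical helper, not a TensorFlow.js API: lists the zero-based epoch
// indices a run would cover when it starts at initialEpoch and stops
// before reaching epochs.
function epochsToRun(epochs: number, initialEpoch = 0): number[] {
  const indices: number[] = [];
  for (let epoch = initialEpoch; epoch < epochs; epoch++) {
    indices.push(epoch);
  }
  return indices;
}

const resumed = epochsToRun(100, 50);
console.log(resumed.length);              // 50
console.log(resumed[0]);                  // 50
console.log(resumed[resumed.length - 1]); // 99
```

Under this assumption, passing epochs: 50 together with initialEpoch: 50 would run no epochs at all, which is an easy mistake to make when resuming.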

Browser-Friendly Training with UI Yielding

import * as tf from '@tensorflow/tfjs';

// yieldEvery controls how often the training loop yields to the browser event loop
const history = await model.fit(xs, ys, {
  epochs: 100,
  batchSize: 32,
  yieldEvery: 'epoch',  // yield control back to browser after each epoch
  callbacks: {
    onEpochEnd: (epoch, logs) => {
      // Update UI with progress (the browser can repaint because the loop yields after each epoch)
      document.getElementById('status').textContent =
        `Epoch ${epoch + 1}/100 - loss: ${logs.loss.toFixed(4)}`;
    }
  }
});

Related Pages

Implements Principle

Environments

Heuristics
