Implementation:Tensorflow Tfjs LayersModel Compile
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Optimization |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Concrete API for configuring a TensorFlow.js LayersModel (or Sequential) for training by specifying the optimizer, loss function, and optional metrics via the compile() method.
Description
The LayersModel.compile() method binds training configuration to a model instance. It accepts a ModelCompileArgs object and mutates the model in-place, setting internal properties that the training loop (fit()) requires.
When compile() is called, the following internal operations occur:
- Optimizer resolution: if the optimizer is specified as a string (e.g., 'adam'), it is resolved to a concrete Optimizer instance via getOptimizer(), and the resolved optimizer is stored on this.optimizer_. If an Optimizer instance is passed directly, it is used as-is.
- Loss function resolution: if the loss is specified as a string (e.g., 'categoricalCrossentropy'), it is resolved to a concrete loss function via the loss function registry. The resolved loss functions are stored in this.lossFunctions (one per output).
- Metrics resolution: string metric names are resolved to metric functions. Metrics are stored in this.metricsTensors and are evaluated during training and validation, but they do not contribute to gradient computation.
- isCompiled flag: the this.isCompiled flag is set to true, allowing fit() to proceed.
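The four steps above can be modeled in a few lines of dependency-free TypeScript. This is a simplified sketch, not the tfjs-layers source. MiniModel, OPTIMIZER_REGISTRY, LOSS_REGISTRY, and metricsNames are hypothetical names. Losses operate on plain number arrays instead of tensors, and each registry holds a single entry. The sketch shows the same bookkeeping the real method performs:

- a string optimizer is looked up, while an instance is passed through;
- the loss is expanded to one function per output;
- metrics are recorded separately from the loss;
- an isCompiled flag is checked by fit().

```typescript
// Simplified model of compile() bookkeeping (hypothetical names; NOT the
// tfjs-layers implementation). Losses work on plain number arrays here.

type LossFn = (yTrue: number[], yPred: number[]) => number;

interface OptimizerLike {
  name: string;
  learningRate: number;
}

interface CompileArgsSketch {
  optimizer: string | OptimizerLike;
  loss: string | LossFn;
  metrics?: string[];
}

// Hypothetical registries standing in for the tfjs string lookups.
const OPTIMIZER_REGISTRY = new Map<string, () => OptimizerLike>([
  ["adam", () => ({ name: "adam", learningRate: 0.001 })],
]);

const LOSS_REGISTRY = new Map<string, LossFn>([
  [
    "meanSquaredError",
    (yTrue, yPred) =>
      yTrue.reduce((sum, t, i) => sum + (t - yPred[i]) ** 2, 0) / yTrue.length,
  ],
]);

class MiniModel {
  readonly numOutputs: number;
  optimizer_: OptimizerLike | null = null;
  lossFunctions: LossFn[] = [];
  metricsNames: string[] = [];
  isCompiled = false;

  constructor(numOutputs: number) {
    this.numOutputs = numOutputs;
  }

  compile(args: CompileArgsSketch): void {
    // Optimizer resolution: strings are looked up, instances are used as-is.
    if (typeof args.optimizer === "string") {
      const factory = OPTIMIZER_REGISTRY.get(args.optimizer);
      if (factory === undefined) throw new Error(`Unknown optimizer: ${args.optimizer}`);
      this.optimizer_ = factory();
    } else {
      this.optimizer_ = args.optimizer;
    }

    // Loss resolution: one loss function per model output.
    const lossFn = typeof args.loss === "string" ? LOSS_REGISTRY.get(args.loss) : args.loss;
    if (lossFn === undefined) throw new Error(`Unknown loss: ${String(args.loss)}`);
    this.lossFunctions = Array.from({ length: this.numOutputs }, () => lossFn);

    // Metrics are recorded for reporting only; they never enter the loss.
    this.metricsNames = args.metrics ?? [];

    // The flag that fit() checks before training.
    this.isCompiled = true;
  }

  fit(): string {
    if (!this.isCompiled || this.optimizer_ === null) {
      throw new Error("Model must be compiled before calling fit().");
    }
    return `training with ${this.optimizer_.name}`;
  }
}
```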
String-to-optimizer mapping:
| String | Optimizer Factory | Default Learning Rate |
|---|---|---|
| 'sgd' | tf.train.sgd | 0.01 |
| 'adam' | tf.train.adam | 0.001 |
| 'rmsprop' | tf.train.rmsprop | 0.001 |
| 'adagrad' | tf.train.adagrad | 0.01 |
| 'adadelta' | tf.train.adadelta | 0.001 |
| 'adamax' | tf.train.adamax | 0.002 |
For fine-grained control over optimizer hyperparameters (learning rate, beta values, epsilon), pass an Optimizer instance directly rather than a string.
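The mapping and the note above reduce to a simple rule:

- a string name always yields an optimizer with default hyperparameters;
- an instance passes through untouched.

The sketch below expresses that rule in plain TypeScript and is illustrative only. OptimizerConfig, resolveOptimizer(), and isOptimizerName() are hypothetical. The learning rates are copied from the mapping table rather than checked against the tfjs source. tfjs itself performs this step internally through getOptimizer() and returns real tf.train.* optimizer objects, not config records.

```typescript
// Hypothetical sketch of string-vs-instance optimizer resolution.
// Not a tfjs API; real tfjs returns tf.train.* optimizer objects.

type OptimizerName = "sgd" | "adam" | "rmsprop" | "adagrad" | "adadelta" | "adamax";

interface OptimizerConfig {
  kind: OptimizerName;
  learningRate: number;
}

// Default learning rates as listed in the string-to-optimizer table.
const DEFAULT_LEARNING_RATES: Record<OptimizerName, number> = {
  sgd: 0.01,
  adam: 0.001,
  rmsprop: 0.001,
  adagrad: 0.01,
  adadelta: 0.001,
  adamax: 0.002,
};

function isOptimizerName(s: string): s is OptimizerName {
  return Object.prototype.hasOwnProperty.call(DEFAULT_LEARNING_RATES, s);
}

function resolveOptimizer(spec: string | OptimizerConfig): OptimizerConfig {
  if (typeof spec !== "string") {
    // An explicit instance keeps whatever hyperparameters the caller chose.
    return spec;
  }
  if (!isOptimizerName(spec)) {
    throw new Error(`Unknown optimizer name: ${spec}`);
  }
  // The string form cannot carry a custom learning rate: defaults only.
  return { kind: spec, learningRate: DEFAULT_LEARNING_RATES[spec] };
}
```

In real tfjs code, the instance form is a tf.train.* call such as tf.train.adam(0.0001), as in the usage examples on this page.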
Code Reference
Source
Repository: https://github.com/tensorflow/tfjs
| File | Key Locations |
|---|---|
| tfjs-layers/src/engine/training.ts | compile() method at L583–657 |
| tfjs-layers/src/engine/training.ts | getOptimizer() helper for string-to-optimizer resolution |
| tfjs-layers/src/losses.ts | Loss function implementations and registry |
| tfjs-layers/src/metrics.ts | Metric function implementations and registry |
Signature
```typescript
compile(args: ModelCompileArgs): void

// ModelCompileArgs interface:
// {
//   optimizer: string | Optimizer
//     String name (e.g. 'adam', 'sgd', 'rmsprop') or a tf.train.* instance
//
//   loss: string | string[] | LossOrMetricFn | LossOrMetricFn[]
//         | {[outputName: string]: string}
//     Loss function name (e.g. 'categoricalCrossentropy', 'meanSquaredError'),
//     a custom function (y_true, y_pred) => Scalar,
//     or an array/map for multi-output models
//
//   metrics?: string | LossOrMetricFn
//            | Array<string | LossOrMetricFn>
//            | {[outputName: string]: string | LossOrMetricFn}
//     Optional metric names (e.g. 'accuracy') or custom metric functions
// }
```
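For multi-output models, the union-typed loss argument has to be expanded into one entry per output before training. The sketch below is a hypothetical illustration of that expansion under the signature above. normalizeLoss, LossSpec, ResolvedLoss, and the output names are assumptions. Losses are kept as unresolved names or plain functions. The array-length and missing-key checks are reasonable validation for the sketch, not verified tfjs error behavior.

```typescript
// Hypothetical normalization of the `loss` union into one entry per output.
// Not a tfjs API; tfjs-layers performs its own expansion inside compile().

type LossOrMetricFn = (yTrue: number[], yPred: number[]) => number;
type LossSpec =
  | string
  | string[]
  | LossOrMetricFn
  | LossOrMetricFn[]
  | { [outputName: string]: string };
type ResolvedLoss = string | LossOrMetricFn;

function normalizeLoss(loss: LossSpec, outputNames: string[]): ResolvedLoss[] {
  if (typeof loss === "string" || typeof loss === "function") {
    // A single loss is shared by every output.
    const shared: ResolvedLoss = loss;
    return outputNames.map(() => shared);
  }
  if (Array.isArray(loss)) {
    // Array form: positional, one loss per output.
    if (loss.length !== outputNames.length) {
      throw new Error(`Expected ${outputNames.length} losses, got ${loss.length}`);
    }
    return [...loss];
  }
  // Object form: look each output up by name.
  const byName = loss;
  return outputNames.map((name) => {
    const entry: string | undefined = byName[name];
    if (entry === undefined) throw new Error(`No loss given for output "${name}"`);
    return entry;
  });
}
```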
Import
```typescript
import * as tf from '@tensorflow/tfjs';
// Then call:
// model.compile({optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy']});
```
External Dependencies
- @tensorflow/tfjs-core: provides the Optimizer base class and concrete optimizer implementations (train.sgd, train.adam, train.rmsprop, etc.).
- @tensorflow/tfjs-layers: contains the LayersModel class, the loss function registry, and the metric function registry.
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| optimizer | string \| Optimizer | Yes | Optimization algorithm name or instance |
| loss | string \| string[] \| LossOrMetricFn \| LossOrMetricFn[] \| Object | Yes | Loss function name, custom function, or per-output mapping |
| metrics | string \| LossOrMetricFn \| Array \| Object | No | Metric names or functions to evaluate during training |
Outputs
| Name | Type | Description |
|---|---|---|
| return | void | The method mutates the model in-place; no return value |
| internal: optimizer_ | Optimizer | Resolved optimizer instance stored on the model |
| internal: lossFunctions | LossOrMetricFn[] | Resolved loss functions stored on the model |
| internal: metricsTensors | Array | Resolved metrics stored on the model |
| internal: isCompiled | boolean | Set to true, enabling fit() to proceed |
Usage Examples
Basic Compilation with String Arguments
```typescript
import * as tf from '@tensorflow/tfjs';

const model = tf.sequential();
model.add(tf.layers.dense({units: 128, activation: 'relu', inputShape: [784]}));
model.add(tf.layers.dense({units: 10, activation: 'softmax'}));

model.compile({
  optimizer: 'adam',
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
```
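To make the strings in this example concrete, the following sketch computes categorical cross-entropy and argmax-match accuracy on plain arrays for a two-sample, three-class batch. It illustrates the math only. tfjs evaluates these on tensors with its own implementation, and the function names and the EPS clipping constant here are assumptions. For a softmax output trained with categorical cross-entropy, 'accuracy' is typically interpreted as this argmax-match (categorical) accuracy.

```typescript
// Illustrative math only; hypothetical function names, not tfjs internals.

const EPS = 1e-7; // assumed clipping constant to avoid log(0)

// Mean over the batch of -sum_j yTrue[j] * log(yPred[j]).
function categoricalCrossentropy(yTrue: number[][], yPred: number[][]): number {
  let total = 0;
  yTrue.forEach((row, i) => {
    row.forEach((t, j) => {
      const p = Math.min(Math.max(yPred[i][j], EPS), 1 - EPS);
      total += -t * Math.log(p);
    });
  });
  return total / yTrue.length;
}

function argmax(v: number[]): number {
  return v.reduce((best, x, i) => (x > v[best] ? i : best), 0);
}

// Fraction of samples whose predicted class (argmax) matches the one-hot label.
function categoricalAccuracy(yTrue: number[][], yPred: number[][]): number {
  const hits = yTrue.filter((row, i) => argmax(row) === argmax(yPred[i])).length;
  return hits / yTrue.length;
}

const yTrue = [[0, 1, 0], [1, 0, 0]];
const yPred = [[0.1, 0.8, 0.1], [0.3, 0.6, 0.1]];
console.log(categoricalCrossentropy(yTrue, yPred)); // (-ln 0.8 - ln 0.3) / 2 ≈ 0.7136
console.log(categoricalAccuracy(yTrue, yPred)); // 0.5
```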
Compilation with Custom Optimizer Instance
```typescript
import * as tf from '@tensorflow/tfjs';

const model = tf.sequential();
model.add(tf.layers.dense({units: 64, activation: 'relu', inputShape: [20]}));
model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));

// Create optimizer with custom hyperparameters
const optimizer = tf.train.adam(0.0001); // learning rate = 0.0001

model.compile({
  optimizer: optimizer,
  loss: 'binaryCrossentropy',
  metrics: ['accuracy']
});
```
Regression Model Compilation
```typescript
import * as tf from '@tensorflow/tfjs';

const model = tf.sequential();
model.add(tf.layers.dense({units: 64, activation: 'relu', inputShape: [13]}));
model.add(tf.layers.dense({units: 32, activation: 'relu'}));
model.add(tf.layers.dense({units: 1})); // linear activation for regression

model.compile({
  optimizer: tf.train.rmsprop(0.001),
  loss: 'meanSquaredError',
  metrics: ['mse']
});
```
Recompiling to Change Optimizer
```typescript
import * as tf from '@tensorflow/tfjs';

const model = tf.sequential();
model.add(tf.layers.dense({units: 128, activation: 'relu', inputShape: [784]}));
model.add(tf.layers.dense({units: 10, activation: 'softmax'}));

// Phase 1: Train with Adam
model.compile({
  optimizer: tf.train.adam(0.001),
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
// ... train for some epochs ...

// Phase 2: Fine-tune with SGD at lower learning rate
// Existing weights are preserved; only the optimizer changes
model.compile({
  optimizer: tf.train.sgd(0.0001),
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
// ... continue training ...
```
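The comment in this example says that recompiling keeps the weights and only swaps the optimizer. The dependency-free sketch below illustrates what that implies for optimizer state. SketchModel, SGD, and MomentumSGD are hypothetical stand-ins, not tfjs classes, and the weights are plain numbers. In this sketch, per-weight state such as a momentum buffer lives inside the optimizer object. The new optimizer therefore starts from empty state, while the weights carry over unchanged.

```typescript
// Hypothetical stand-ins: compile() only replaces the optimizer reference.
interface SketchOptimizer {
  step(weights: number[], grads: number[]): void;
}

class SGD implements SketchOptimizer {
  private readonly lr: number;
  constructor(lr: number) {
    this.lr = lr;
  }
  step(weights: number[], grads: number[]): void {
    for (let i = 0; i < weights.length; i++) weights[i] -= this.lr * grads[i];
  }
}

class MomentumSGD implements SketchOptimizer {
  private readonly lr: number;
  private readonly momentum: number;
  private velocity: number[] = []; // per-weight optimizer state
  constructor(lr: number, momentum: number) {
    this.lr = lr;
    this.momentum = momentum;
  }
  step(weights: number[], grads: number[]): void {
    if (this.velocity.length === 0) this.velocity = weights.map(() => 0);
    for (let i = 0; i < weights.length; i++) {
      this.velocity[i] = this.momentum * this.velocity[i] - this.lr * grads[i];
      weights[i] += this.velocity[i];
    }
  }
  stateSize(): number {
    return this.velocity.length;
  }
}

class SketchModel {
  readonly weights: number[];
  private optimizer: SketchOptimizer | null = null;
  constructor(initialWeights: number[]) {
    this.weights = initialWeights.slice();
  }
  compile(optimizer: SketchOptimizer): void {
    this.optimizer = optimizer; // weights are not touched
  }
  trainStep(grads: number[]): void {
    if (this.optimizer === null) throw new Error("compile() must be called first");
    this.optimizer.step(this.weights, grads);
  }
}

// Phase 1: the momentum optimizer accumulates state over two steps.
const model = new SketchModel([1, 1]);
const phase1 = new MomentumSGD(0.1, 0.9);
model.compile(phase1);
model.trainStep([1, 1]);
model.trainStep([1, 1]);
const afterPhase1 = model.weights.slice(); // approximately [0.71, 0.71]

// Phase 2: recompile with plain SGD; weights carry over, momentum does not.
model.compile(new SGD(0.01));
model.trainStep([1, 1]); // each weight decreases by lr * grad = 0.01
```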
Related Pages
Implements Principle
Environments
- Environment:Tensorflow_Tfjs_Browser_Runtime -- Browser runtime (WebGL / WebGPU / WASM / CPU backends)
Heuristics
- Heuristic:Tensorflow_Tfjs_Backend_Selection_Strategy -- Choose the optimal compute backend for your environment