Principle: TensorFlow.js (tfjs) Model Training
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Optimization |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Model training is the iterative process of updating neural network weights to minimize the loss function through repeated cycles of forward propagation, loss computation, backpropagation, and gradient-based parameter updates across multiple epochs and batches.
Description
Training is the core computational phase of the machine learning pipeline. It takes a compiled model (with defined architecture, optimizer, and loss function) and training data, then iteratively adjusts the model's parameters to reduce the discrepancy between predictions and ground truth labels.
The training process is organized into two nested loops:
- Outer loop (epochs): One epoch is a complete pass through the entire training dataset. Multiple epochs allow the model to see every training example multiple times, progressively refining its parameters.
- Inner loop (batches): Within each epoch, the dataset is divided into mini-batches. Each batch triggers one forward pass, one backward pass, and one parameter update. Using mini-batches rather than the full dataset provides a balance between the noisy gradients of single-sample updates (SGD) and the expensive gradients of full-batch updates.
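The epoch/batch structure can be made concrete with a small dependency-free sketch. The `toBatches` helper and the toy dataset are illustrative assumptions, not part of any framework API:

```javascript
// Split a dataset into mini-batches of a fixed size.
// The last batch may be smaller when the dataset size is not
// an exact multiple of batchSize.
function toBatches(data, batchSize) {
  const batches = [];
  for (let i = 0; i < data.length; i += batchSize) {
    batches.push(data.slice(i, i + batchSize));
  }
  return batches;
}

// 10 samples with batch size 4 -> batches of 4, 4, and 2 samples.
const data = Array.from({ length: 10 }, (_, i) => i);
const batches = toBatches(data, 4);
console.log(batches.length);  // 3 batches per epoch
console.log(batches[2]);      // the 2 leftover samples: [8, 9]
```

One pass over all three batches is one epoch; each batch triggers one parameter update, so the inner loop runs three times per epoch here.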
Usage
Model training is performed:
- After model architecture definition and compilation.
- With properly prepared training data (correctly shaped tensors or streaming datasets).
- Optionally with validation data to monitor generalization performance.
- Optionally with callbacks to implement early stopping, learning rate scheduling, checkpointing, or custom logging.
The practitioner must decide several key hyperparameters:
- Number of epochs — Too few leads to underfitting; too many leads to overfitting.
- Batch size — Affects gradient noise, training speed, and memory usage. Typical values: 16, 32, 64, 128, 256.
- Validation split — Fraction of training data reserved for validation (typically 10-20%).
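These hyperparameters jointly determine the training budget. A hedged sketch of the arithmetic (the `trainingBudget` helper and its option names are illustrative assumptions):

```javascript
// Given dataset size and hyperparameters, compute how many
// parameter updates training will perform in total.
function trainingBudget(numSamples, { epochs, batchSize, validationSplit }) {
  const valSamples = Math.floor(numSamples * validationSplit);
  const trainSamples = numSamples - valSamples;
  const batchesPerEpoch = Math.ceil(trainSamples / batchSize);
  return {
    trainSamples,
    valSamples,
    batchesPerEpoch,
    totalUpdates: batchesPerEpoch * epochs,
  };
}

const budget = trainingBudget(1000, {
  epochs: 20,
  batchSize: 32,
  validationSplit: 0.2,
});
// 800 training samples -> ceil(800/32) = 25 batches per epoch,
// and 25 * 20 = 500 parameter updates over the whole run.
console.log(budget);
```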
Theoretical Basis
The Training Loop Step by Step
For each mini-batch within each epoch, the following operations occur in sequence:
- Forward pass: The input batch `x` is propagated through all layers to produce predictions `y_hat = f_w(x)`. Each layer applies its transformation and stores intermediate activations needed for backpropagation.
- Loss computation: The loss function computes a scalar value `L = loss(y, y_hat)` measuring the error between predictions and true labels.
- Backpropagation: Starting from the loss, gradients are computed for every trainable parameter in the network via the chain rule: `dL/dw_i = dL/da_n * da_n/da_{n-1} * ... * da_{i+1}/dw_i`, where `a_i` are layer activations.
- Parameter update: The optimizer applies its update rule to each parameter: `w = w - alpha * optimizer_transform(gradient)`.
```
for epoch = 1 to num_epochs:
    for each batch (x_batch, y_batch) in training_data:
        y_hat = model.forward(x_batch)             // Forward pass
        loss = loss_function(y_batch, y_hat)       // Compute loss
        gradients = backpropagate(loss)            // Compute gradients
        optimizer.update(model.weights, gradients) // Update weights
    // Validation phase (no gradient computation)
    for each batch (x_val, y_val) in validation_data:
        y_hat_val = model.forward(x_val)
        val_loss = loss_function(y_val, y_hat_val)
        accumulate_metrics(y_val, y_hat_val)
    report_epoch_metrics()
```
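The loop above can be run end-to-end for the smallest possible model: a single weight `w` fit by plain SGD. The model `y = w * x`, the toy data, and the learning rate are illustrative assumptions (this is the loop the framework runs internally, not the tfjs API):

```javascript
// Fit y = w * x with mean-squared-error loss and plain SGD.
const trueW = 3;
const data = Array.from({ length: 8 }, (_, i) => ({ x: i, y: trueW * i }));
const batchSize = 4;
const alpha = 0.01;  // learning rate
let w = 0;           // single trainable parameter

for (let epoch = 1; epoch <= 50; epoch++) {
  for (let i = 0; i < data.length; i += batchSize) {
    const batch = data.slice(i, i + batchSize);
    // Forward pass: y_hat = f_w(x) = w * x
    const yHat = batch.map(({ x }) => w * x);
    // Backpropagation: for MSE, dL/dw = mean(2 * (y_hat - y) * x)
    const grad =
      batch.reduce((s, { x, y }, j) => s + 2 * (yHat[j] - y) * x, 0) /
      batch.length;
    // Parameter update: w = w - alpha * gradient
    w -= alpha * grad;
  }
}
console.log(w);  // converges close to trueW = 3
```

Each of the four steps (forward pass, loss/gradient computation, update) maps directly onto one line of the pseudocode; a real framework does the same thing with tensors and many parameters.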
Epochs and Convergence
Training proceeds for a fixed number of epochs or until a convergence criterion is met. The loss typically follows a characteristic pattern:
| Phase | Training Loss | Validation Loss | Interpretation |
|---|---|---|---|
| Early training | Decreasing rapidly | Decreasing rapidly | Model is learning useful patterns |
| Mid training | Decreasing steadily | Decreasing steadily | Model continues to improve |
| Late training | Still decreasing | Plateaus or increases | Overfitting begins — stop here |
Early stopping monitors validation loss and halts training when it stops improving, preventing overfitting.
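The patience logic can be sketched in a few lines. The `earlyStopEpoch` helper and the sample loss curve are hypothetical, for illustration only:

```javascript
// Return the epoch at which training should halt: stop once
// val_loss has failed to improve for `patience` consecutive epochs.
function earlyStopEpoch(valLosses, patience) {
  let best = Infinity;
  let sinceImprovement = 0;
  for (let epoch = 0; epoch < valLosses.length; epoch++) {
    if (valLosses[epoch] < best) {
      best = valLosses[epoch];
      sinceImprovement = 0;
    } else if (++sinceImprovement >= patience) {
      return epoch;  // halt training here
    }
  }
  return valLosses.length - 1;  // never triggered
}

// Validation loss bottoms out at epoch 3, then rises.
const valLosses = [0.9, 0.6, 0.5, 0.45, 0.46, 0.48, 0.55];
console.log(earlyStopEpoch(valLosses, 2));  // 5
```

With `patience = 2`, training halts two epochs after the best validation loss; restoring the weights checkpointed at that best epoch is the usual companion step.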
Batch Size Trade-offs
| Batch Size | Gradient Quality | Memory Usage | Training Speed | Generalization |
|---|---|---|---|---|
| Small (16-32) | Noisy but regularizing | Low | Slower convergence | Often better |
| Medium (64-128) | Balanced | Moderate | Good throughput | Good |
| Large (256-512) | Smooth but may overfit | High | Fast per-epoch | May be worse |
Validation and Generalization
Validation data is never used for gradient computation or parameter updates. Its sole purpose is to provide an unbiased estimate of how the model will perform on unseen data. The gap between training metrics and validation metrics indicates the degree of overfitting:
- Training accuracy >> Validation accuracy — The model has memorized training data but does not generalize.
- Training accuracy ~ Validation accuracy — The model generalizes well.
- Training accuracy < Validation accuracy — Unusual; may indicate data leakage or a very small training set.
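The three cases above amount to a simple comparison. A sketch with an illustrative tolerance (the helper name and threshold are assumptions, not a standard diagnostic):

```javascript
// Classify the train/validation accuracy gap per the three cases above.
function diagnoseFit(trainAcc, valAcc, tolerance = 0.02) {
  if (trainAcc - valAcc > tolerance) return 'overfitting';
  if (valAcc - trainAcc > tolerance) {
    return 'unusual: check for data leakage or a very small training set';
  }
  return 'generalizing well';
}

console.log(diagnoseFit(0.99, 0.80));  // overfitting
console.log(diagnoseFit(0.91, 0.90));  // generalizing well
```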
Callbacks
Callbacks are hooks that execute at specific points during training (start/end of epoch, start/end of batch). Common callbacks:
- Early stopping — Monitors a metric (typically `val_loss`) and stops training when it stops improving for a specified number of epochs (patience).
- Learning rate scheduling — Adjusts the learning rate during training (e.g., reduce on plateau, cosine annealing).
- Model checkpointing — Saves model weights at regular intervals or when a metric improves.
- Custom logging — Records training history for visualization.
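The hook mechanism itself is simple: the training loop invokes each callback at fixed points. Hook names below mirror common conventions (`onEpochBegin`, `onEpochEnd`, ...) but the whole API is an illustrative sketch, not a framework's actual interface:

```javascript
// A minimal callback mechanism: the loop calls each callback's
// hooks (if present) at the start/end of training and of each epoch.
function runTraining(numEpochs, callbacks) {
  callbacks.forEach((cb) => cb.onTrainBegin?.());
  for (let epoch = 0; epoch < numEpochs; epoch++) {
    callbacks.forEach((cb) => cb.onEpochBegin?.(epoch));
    // ... forward/backward/update over all batches would run here ...
    const logs = { loss: 1 / (epoch + 1) };  // stand-in metric
    callbacks.forEach((cb) => cb.onEpochEnd?.(epoch, logs));
  }
  callbacks.forEach((cb) => cb.onTrainEnd?.());
}

// A custom-logging callback: record the loss reported each epoch.
const history = [];
runTraining(3, [{ onEpochEnd: (epoch, logs) => history.push(logs.loss) }]);
console.log(history);  // [1, 0.5, 0.333...]
```

Early stopping, LR scheduling, and checkpointing are all just callbacks that inspect `logs` in `onEpochEnd` and act on the model or the loop.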
The History Object
Training returns a history object containing per-epoch values for all tracked quantities (loss, metrics, validation loss, validation metrics). This data enables post-training analysis: plotting learning curves, identifying overfitting onset, and comparing training runs.
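One such analysis, finding the epoch where overfitting begins, is a one-liner over the history. The per-epoch values below are fabricated for illustration; only the shape of the object follows the description above:

```javascript
// A history object: per-epoch arrays of tracked quantities.
const history = {
  loss:     [0.90, 0.60, 0.40, 0.30, 0.25, 0.22],
  val_loss: [0.95, 0.70, 0.55, 0.50, 0.53, 0.58],
};

// Overfitting onset: the epoch with the lowest validation loss.
// After it, training loss keeps falling while val_loss rises.
function bestEpoch(valLoss) {
  return valLoss.indexOf(Math.min(...valLoss));
}

console.log(bestEpoch(history.val_loss));  // 3
```

Plotting `loss` and `val_loss` against the epoch index gives the learning curves; the divergence after the best epoch is the visual signature of overfitting described in the convergence table.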