Implementation:Tensorflow Tfjs Training Weighted Test
| Knowledge Sources | |
|---|---|
| Domains | Testing, Layers_API |
| Last Updated | 2026-02-10 06:00 GMT |
Overview
This test suite validates training with class weights and sample weights in TensorFlow.js Layers. Class weighting lets users assign different importance to different classes during training, which is critical for handling imbalanced datasets. The tests cover `LayersModel.fit()` and `LayersModel.fitDataset()` with class weights for multi-class classification (one-hot and integer label encodings), binary classification, and multi-output models. They also check that weighted training produces loss values that differ from unweighted training. Reference Python Keras code is provided for numerical verification.
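The core mechanism can be sketched in plain TypeScript, independent of the TensorFlow.js API. Each training sample receives the weight of its class, and the class is read from the label. The sketch assumes the three label layouts the tests exercise: one-hot rows, integer class indices, and single-column binary 0/1 labels. The helpers `classIndexOf` and `sampleWeightsFromClassWeight` are hypothetical and are not TensorFlow.js functions. Throwing on a class that is missing from the map is a choice made for this sketch.

```typescript
// Hypothetical helpers for illustration only; these are not TensorFlow.js APIs.

/** A class-weight map such as {0: 1, 1: 10, 2: 1}. */
type ClassWeightMap = {[classIndex: number]: number};

/**
 * Reads the class index from one label row:
 * - one-hot rows (length > 1): index of the largest entry (argmax);
 * - single-column rows: the value itself (integer class or binary 0/1 label).
 */
function classIndexOf(label: number[]): number {
  if (label.length > 1) {
    let best = 0;
    for (let i = 1; i < label.length; i++) {
      if (label[i] > label[best]) {
        best = i;
      }
    }
    return best;
  }
  return Math.round(label[0]);
}

/** Maps every sample to its class weight; throws if a class has no entry in the map. */
function sampleWeightsFromClassWeight(
    labels: number[][], classWeight: ClassWeightMap): number[] {
  return labels.map(label => {
    const cls = classIndexOf(label);
    if (!(cls in classWeight)) {
      throw new Error(`No class weight given for class ${cls}`);
    }
    return classWeight[cls];
  });
}

// One-hot and integer encodings of the same labels yield the same sample weights.
const oneHot = [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]];
const integer = [[0], [0], [1], [1], [2], [2]];
const weightsOneHot = sampleWeightsFromClassWeight(oneHot, {0: 1, 1: 10, 2: 1});
const weightsInteger = sampleWeightsFromClassWeight(integer, {0: 1, 1: 10, 2: 1});
console.log(weightsOneHot);   // -> [1, 1, 10, 10, 1, 1]
console.log(weightsInteger);  // -> [1, 1, 10, 10, 1, 1]
```

A binary label column such as `[[0], [1], [1]]` follows the single-column branch, so a map like `{0: 1, 1: 5}` yields `[1, 5, 5]`.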
Code Reference
Source Location: tfjs-layers/src/engine/training_weighted_test.ts (856 lines)
Repository: GitHub
Test Describe Blocks
- `LayersModel.fit() with classWeight`: class weighting during tensor-based training
  - One output, multi-class, one-hot encoding
  - One output, multi-class, one-hot encoding, custom batch size
  - One output, multi-class, integer label encoding
  - One output, binary classification
  - Two outputs with different class weights
  - Memory leak verification
- `LayersModel.fitDataset() with classWeight`: class weighting during dataset-based training
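In the two-output case, each output's labels are weighted by that output's own class-weight map, and the per-output losses are then combined into one training loss. The sketch below makes two assumptions about the combination rule. First, each output's weighted loss is a plain mean of loss × weight. Second, the outputs are summed with every per-output loss weight equal to 1. These are stated assumptions, not a transcription of the TensorFlow.js implementation, and the helper names are hypothetical.

```typescript
// Hypothetical sketch: combining per-output losses when each output has its own class weights.

type ClassWeightMap = {[classIndex: number]: number};

/** Weighted loss for one output: mean over samples of (per-sample loss * class weight). */
function weightedOutputLoss(
    perSampleLoss: number[], classIndices: number[], classWeight: ClassWeightMap): number {
  let sum = 0;
  for (let n = 0; n < perSampleLoss.length; n++) {
    sum += perSampleLoss[n] * classWeight[classIndices[n]];
  }
  return sum / perSampleLoss.length;
}

/** Total loss: sum of each output's weighted loss (every per-output loss weight assumed to be 1). */
function totalLoss(
    perSampleLosses: number[][], classIndices: number[][],
    classWeights: ClassWeightMap[]): number {
  let total = 0;
  for (let o = 0; o < perSampleLosses.length; o++) {
    total += weightedOutputLoss(perSampleLosses[o], classIndices[o], classWeights[o]);
  }
  return total;
}

// Two outputs over the same two samples, each output with its own class-weight map.
const combined = totalLoss(
    [[1, 1], [2, 2]],  // per-sample losses for output 0 and output 1
    [[0, 1], [0, 1]],  // class of each sample, per output
    [{0: 1, 1: 3}, {0: 2, 1: 1}]);
console.log(combined);  // 5: output 0 gives (1*1 + 1*3)/2 = 2, output 1 gives (2*2 + 2*1)/2 = 3
```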
I/O Contract
Inputs to tests:
- Classification models with softmax or sigmoid outputs
- Class weight maps (e.g., `{0: 1, 1: 10, 2: 1}`) that emphasize certain classes
- Multi-class input data with one-hot or integer labels
- `FakeNumericDataset` for dataset-based tests

Expected outputs/assertions:
- Loss values reflect the class weights and differ from unweighted training (e.g., a first-epoch loss of `4.3944` in the one-hot test, versus ln 3 ≈ `1.0986` for the same zero-initialized model trained without class weights)
- Accuracy values match the expected progression over epochs
- Memory leak checks: `memory().numTensors` is unchanged before and after training
- Multi-output models apply a separate class weight map to each output
Usage Example
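```typescript
// Imports assumed for a test file in tfjs-layers/src/engine/ (relative paths are an assumption).
import {tensor2d, train} from '@tensorflow/tfjs-core';

import * as tfl from '../index';
import {describeMathCPUAndWebGL2} from '../utils/test_utils';

describeMathCPUAndWebGL2('LayersModel.fit() with classWeight', () => {
  it('One output, multi-class, one-hot encoding', async () => {
    // Zero kernel (and default zero bias) + softmax: every sample starts at [1/3, 1/3, 1/3].
    const model = tfl.sequential();
    model.add(tfl.layers.dense({
      units: 3, inputShape: [2],
      kernelInitializer: 'zeros', activation: 'softmax'
    }));
    model.compile({
      loss: 'categoricalCrossentropy', metrics: ['acc'],
      optimizer: train.sgd(1)
    });
    const xs = tensor2d([[0, 1], [0, 2], [1, 10], [1, 20], [2, -10], [2, -20]]);
    const ys = tensor2d(
        [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]]);
    // Class 1 is weighted 10x relative to classes 0 and 2.
    const history = await model.fit(
        xs, ys, {epochs: 2, classWeight: [{0: 1, 1: 10, 2: 1}]});
    // First-epoch loss: mean of (class weight * ln 3) over six samples = 4 * ln 3.
    expect(history.history.loss[0]).toBeCloseTo(4.3944);
  });
});
```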
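The expected value `4.3944` can be reproduced by hand. With a zero kernel and the dense layer's default zero bias, every logit is 0. The softmax therefore predicts 1/3 for each class, and each sample's categorical crossentropy is ln 3 ≈ 1.0986. The class weights give the six samples the weights [1, 1, 10, 10, 1, 1], and the mean of weight × loss is 4 · ln 3 ≈ 4.3944. The sketch below recomputes this in plain TypeScript. It assumes the first-epoch loss is the weighted mean over a single batch of all six samples, evaluated before the first SGD update.

```typescript
// Plain-TypeScript recomputation of the first-epoch loss; no TensorFlow.js involved.

function softmax(logits: number[]): number[] {
  const maxLogit = Math.max(...logits);
  const exps = logits.map(z => Math.exp(z - maxLogit));  // subtract max for numerical stability
  const total = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / total);
}

function categoricalCrossentropy(yTrue: number[], yPred: number[]): number {
  return -yTrue.reduce((acc, t, i) => acc + t * Math.log(yPred[i]), 0);
}

const xs = [[0, 1], [0, 2], [1, 10], [1, 20], [2, -10], [2, -20]];
const ys = [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]];
const classWeight: {[classIndex: number]: number} = {0: 1, 1: 10, 2: 1};

// Dense layer state before any update: zero kernel (2x3) and zero bias (3).
const kernel = [[0, 0, 0], [0, 0, 0]];
const bias = [0, 0, 0];

function dense(x: number[]): number[] {
  return bias.map((b, j) => b + x.reduce((acc, xi, i) => acc + xi * kernel[i][j], 0));
}

let weightedSum = 0;
let unweightedSum = 0;
xs.forEach((x, n) => {
  const sampleLoss = categoricalCrossentropy(ys[n], softmax(dense(x)));
  const cls = ys[n].indexOf(1);  // one-hot label -> class index
  weightedSum += classWeight[cls] * sampleLoss;
  unweightedSum += sampleLoss;
});

const weightedLoss = weightedSum / xs.length;      // 4 * ln(3)
const unweightedLoss = unweightedSum / xs.length;  // ln(3)
console.log(weightedLoss.toFixed(4), unweightedLoss.toFixed(4));  // 4.3944 1.0986
```

Because every prediction is identical at initialization, the class weights are the only source of the difference between the two values. This is what makes the zero-initialized model a convenient fixture for checking the weighting.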
Test Coverage Summary
| Category | Count | Details |
|---|---|---|
| fit() classWeight | 6+ | Multi-class, binary, multi-output, batch sizes |
| fitDataset() classWeight | 3+ | Dataset-based class weighting |
| Memory Leak Tests | 2+ | Tensor count verification |
| Label Encodings | 2 | One-hot and integer label formats |
| Test Environment | CPU and WebGL2 | `describeMathCPUAndWebGL2` |