
Implementation:Tensorflow Tfjs Training Weighted Test

From Leeroopedia
Revision as of 16:52, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Tensorflow_Tfjs_Training_Weighted_Test.md)


Knowledge Sources
Domains: Testing, Layers_API
Last Updated: 2026-02-10 06:00 GMT

Overview

This test suite validates training with class weights and sample weights in TensorFlow.js Layers. Class weighting lets users assign different importance to different classes during training, which matters most when the dataset is imbalanced. The tests cover LayersModel.fit() and LayersModel.fitDataset() with class weights for multi-class classification (one-hot and integer label encodings), binary classification, and multi-output models, and they verify that weighted training produces different loss values than unweighted training. Reference Python Keras code accompanies the tests for numerical verification.
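To make the idea of class weighting concrete, the sketch below shows one way a class-weight map could be turned into one weight per training sample, for both one-hot and integer labels. The helper name `sampleWeightsFromClassWeight` and the `ClassWeight` alias are hypothetical and exist only for illustration; this is not the TensorFlow.js implementation, and throwing on a missing class is an assumption of the sketch.

```typescript
// Illustrative only: a hypothetical helper, not part of the TensorFlow.js API.
type ClassWeight = {[classIndex: number]: number};

function sampleWeightsFromClassWeight(
    labels: Array<number|number[]>, classWeight: ClassWeight): number[] {
  return labels.map((label) => {
    // One-hot rows are reduced to a class index via argmax;
    // integer labels are used as the class index directly.
    const classIndex =
        Array.isArray(label) ? label.indexOf(Math.max(...label)) : label;
    const weight = classWeight[classIndex];
    if (weight === undefined) {
      // Design choice of this sketch: fail loudly on an unlisted class.
      throw new Error(`No weight provided for class ${classIndex}`);
    }
    return weight;
  });
}

const classWeight: ClassWeight = {0: 1, 1: 10, 2: 1};
const oneHotLabels =
    [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]];
const integerLabels = [0, 0, 1, 1, 2, 2];

// Both encodings describe the same classes, so they yield the same weights.
const fromOneHot = sampleWeightsFromClassWeight(oneHotLabels, classWeight);
const fromIntegers = sampleWeightsFromClassWeight(integerLabels, classWeight);
console.log(fromOneHot);    // [ 1, 1, 10, 10, 1, 1 ]
console.log(fromIntegers);  // [ 1, 1, 10, 10, 1, 1 ]
```

Samples of class 1 end up with ten times the weight of the others, which is how a minority class can be made to contribute more to the loss.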

Code Reference

Source Location: tfjs-layers/src/engine/training_weighted_test.ts (856 lines)

Repository: GitHub

Test Describe Blocks

  • LayersModel.fit() with classWeight - Class weighting during tensor-based training:
    • One output, multi-class, one-hot encoding
    • One output, multi-class, one-hot encoding, custom batch size
    • One output, multi-class, integer label encoding
    • One output, binary classification
    • Two outputs with different class weights
    • Memory leak verification
  • LayersModel.fitDataset() with classWeight - Class weighting during dataset-based training
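For the two-output case listed above, each output gets its own class-weight map. The following is a minimal sketch of that bookkeeping, assuming integer class labels per output; the function name `perOutputSampleWeights` and the weight values for the second output are made up for illustration and are not taken from the test file.

```typescript
// Illustrative only: hypothetical names, not the TensorFlow.js implementation.
type ClassWeight = {[classIndex: number]: number};

function perOutputSampleWeights(
    labelsPerOutput: number[][], classWeights: ClassWeight[]): number[][] {
  if (labelsPerOutput.length !== classWeights.length) {
    throw new Error('Expected exactly one class-weight map per output');
  }
  // Output i is weighted only by map i; the maps never mix.
  return labelsPerOutput.map((labels, outputIndex) => {
    const classWeight = classWeights[outputIndex];
    return labels.map((classIndex) => {
      const weight = classWeight[classIndex];
      if (weight === undefined) {
        throw new Error(
            `Output ${outputIndex}: no weight for class ${classIndex}`);
      }
      return weight;
    });
  });
}

const labelsPerOutput = [
  [0, 0, 1, 1, 2, 2],  // first output: three classes
  [0, 1, 0, 1, 0, 1],  // second output: two classes
];
const classWeights: ClassWeight[] = [
  {0: 1, 1: 10, 2: 1},  // map for the first output
  {0: 1, 1: 5},         // map for the second output (made-up values)
];

const weightsPerOutput = perOutputSampleWeights(labelsPerOutput, classWeights);
console.log(weightsPerOutput);
// [ [ 1, 1, 10, 10, 1, 1 ], [ 1, 5, 1, 5, 1, 5 ] ]
```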

I/O Contract

Inputs to tests:

  • Classification models with softmax/sigmoid output
  • Class weight maps (e.g., {0: 1, 1: 10, 2: 1}) that emphasize certain classes
  • Multi-class input data with one-hot or integer labels
  • FakeNumericDataset for dataset-based tests

Expected outputs/assertions:

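  • Loss values differ from unweighted training (e.g., a first-epoch loss of 4.3944, against ln 3 ≈ 1.0986 for unweighted categorical crossentropy on the same zero-initialized three-class model)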
  • Accuracy values match expected progression over epochs
  • Memory leak checks: memory().numTensors stable before/after training
  • Multi-output models correctly apply separate class weight maps per output

Usage Example

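// Runs the suite on both the CPU and WebGL2 backends.
describeMathCPUAndWebGL2('LayersModel.fit() with classWeight', () => {
  it('One output, multi-class, one-hot encoding', async () => {
    // Single softmax layer with a zero-initialized kernel, so the
    // first-epoch loss is deterministic.
    const model = tfl.sequential();
    model.add(tfl.layers.dense({
      units: 3, inputShape: [2],
      kernelInitializer: 'zeros', activation: 'softmax'
    }));
    model.compile({
      loss: 'categoricalCrossentropy', metrics: ['acc'],
      optimizer: train.sgd(1)  // SGD with learning rate 1
    });

    // Six samples, two per class, with one-hot labels.
    const xs = tensor2d([[0, 1], [0, 2], [1, 10], [1, 20], [2, -10], [2, -20]]);
    const ys = tensor2d(
        [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]]);
    // Class 1 is weighted ten times as heavily as classes 0 and 2.
    const history = await model.fit(
        xs, ys, {epochs: 2, classWeight: [{0: 1, 1: 10, 2: 1}]});
    // First-epoch loss reflects the class weights.
    expect(history.history.loss[0]).toBeCloseTo(4.3944);
  });
});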

Test Coverage Summary

| Category                 | Count          | Details                                        |
|--------------------------|----------------|------------------------------------------------|
| fit() classWeight        | 6+             | Multi-class, binary, multi-output, batch sizes |
| fitDataset() classWeight | 3+             | Dataset-based class weighting                  |
| Memory Leak Tests        | 2+             | Tensor count verification                      |
| Label Encodings          | 2              | One-hot and integer label formats              |
| Test Environment         | CPU and WebGL2 | describeMathCPUAndWebGL2                       |
