Implementation:Tensorflow Tfjs Training Utils
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Layers_API, Training |
| Last Updated | 2026-02-10 06:00 GMT |
Overview
This module provides utility functions for handling class weights and sample weights during model training in TensorFlow.js Layers. It standardizes weight objects across single-output and multi-output models, converts class weights into per-sample weight tensors, and applies per-sample weighting to loss values. These utilities help with imbalanced datasets: assigning larger weights to under-represented classes increases their contribution to the loss.
Code Reference
Source Location
tfjs-layers/src/engine/training_utils.ts (GitHub)
Key Imports
```typescript
import {argMax, clone, dispose, mul, reshape, Tensor, Tensor1D, tensor1d, tidy}
    from '@tensorflow/tfjs-core';
```
Types
```typescript
export type ClassWeight = { [classIndex: number]: number };
export type ClassWeightMap = { [outputName: string]: ClassWeight };
```
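To make the two shapes concrete, the sketch below builds one value of each type. The output names are hypothetical, and inverseFrequencyWeights is a hypothetical helper that is not part of training_utils.ts. It implements the common inverse-frequency heuristic (the same formula scikit-learn uses for class_weight='balanced'): weight = numSamples / (numClasses * countOfClass).

```typescript
// Local copies of the two type aliases so this sketch compiles on its own.
type ClassWeight = { [classIndex: number]: number };
type ClassWeightMap = { [outputName: string]: ClassWeight };

// Single-output model: class 1 counts five times as much as class 0.
const binaryWeights: ClassWeight = {0: 1.0, 1: 5.0};

// Multi-output model: one ClassWeight per output name.
// The output names 'species' and 'habitat' are hypothetical.
const perOutput: ClassWeightMap = {
  species: {0: 1.0, 1: 2.0, 2: 4.0},
  habitat: {0: 1.0, 1: 1.5},
};

// Hypothetical helper (not in training_utils.ts): inverse-frequency weights,
// weight = numSamples / (numClasses * countOfClass).
function inverseFrequencyWeights(labels: number[]): ClassWeight {
  const counts = new Map<number, number>();
  for (const label of labels) {
    counts.set(label, (counts.get(label) ?? 0) + 1);
  }
  const weights: ClassWeight = {};
  counts.forEach((count, label) => {
    weights[label] = labels.length / (counts.size * count);
  });
  return weights;
}

// Three samples of class 0, one of class 1: the rare class gets weight 2.
const balanced = inverseFrequencyWeights([0, 0, 0, 1]);
```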
Functions
standardizeClassWeights
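Normalizes class weight input into an array with one entry per model output, in the order given by outputNames. A null or empty input yields one null entry per output. For a single-output model, a bare ClassWeight (or a map keyed by the output name) is wrapped in a one-element array. For a multi-output model, an array must contain exactly one entry per output, and a map keyed by output name is reordered to match outputNames, with null for any output the map omits; any other input throws.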
```typescript
export function standardizeClassWeights(
    classWeight: ClassWeight | ClassWeight[] | ClassWeightMap,
    outputNames: string[]): ClassWeight[]
```
standardizeSampleWeights
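Applies the same normalization as standardizeClassWeights to sample weights; the two differ only in the weight type named in error messages. Note that the parameter is still named classWeight and typed as ClassWeight.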
```typescript
export function standardizeSampleWeights(
    classWeight: ClassWeight | ClassWeight[] | ClassWeightMap,
    outputNames: string[]): ClassWeight[]
```
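Both functions delegate to one shared normalization routine. The following is a simplified plain-TypeScript sketch of that routine's behavior, written without any tfjs dependency; normalizeWeights is a hypothetical name, the error messages are abbreviated, and some edge cases (for example, a multi-element array passed to a single-output model) may not match the source exactly.

```typescript
type ClassWeight = { [classIndex: number]: number };
type ClassWeightMap = { [outputName: string]: ClassWeight };

// Hypothetical, simplified sketch of the normalization shared by
// standardizeClassWeights and standardizeSampleWeights.
function normalizeWeights(
    weights: ClassWeight | ClassWeight[] | ClassWeightMap | null | undefined,
    outputNames: string[]): Array<ClassWeight | null> {
  const numOutputs = outputNames.length;
  // No weights at all: one null entry per output.
  if (weights == null || (Array.isArray(weights) && weights.length === 0)) {
    return outputNames.map(() => null);
  }
  // Arrays are matched to outputs by position and must have one entry each.
  if (Array.isArray(weights)) {
    if (weights.length !== numOutputs) {
      throw new Error(
          `Got ${weights.length} weight object(s) for ${numOutputs} output(s).`);
    }
    return weights;
  }
  if (numOutputs === 1) {
    // Single output: accept {outputName: ClassWeight} or a bare ClassWeight.
    const name = outputNames[0];
    return name in weights ? [(weights as ClassWeightMap)[name]]
                           : [weights as ClassWeight];
  }
  // Multiple outputs: require a map whose values are objects (one per output).
  const keys = Object.keys(weights);
  const map = weights as ClassWeightMap;
  if (keys.length === 0 || typeof map[keys[0]] !== 'object') {
    throw new Error(`Expected an array or a map keyed by ${outputNames.join(', ')}.`);
  }
  // Reorder to match outputNames; outputs missing from the map get null.
  return outputNames.map(name => (name in map ? map[name] : null));
}

const single = normalizeWeights({0: 1, 1: 5}, ['output']);
const multi = normalizeWeights({b: {0: 2}}, ['a', 'b']);
```

In this sketch, passing a bare ClassWeight such as {0: 1.0, 1: 5.0} to a two-output model throws, because its values are numbers rather than per-output objects.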
standardizeWeights
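Converts class weights into a per-sample weight tensor (a float32 Tensor1D with one weight per target). Targets may be class indices (a 1D tensor, or a 2D tensor of shape [batch, 1]) or one-hot encoded rows (a 2D tensor with more than one column, reduced to indices with argMax). Throws if a class that occurs in y is missing from classWeight, and throws if sampleWeight or sampleWeightMode is provided, since sample-weight support is not yet implemented. Resolves to null when classWeight is omitted.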
```typescript
export async function standardizeWeights(
    y: Tensor,
    sampleWeight?: Tensor,
    classWeight?: ClassWeight,
    sampleWeightMode?: 'temporal'): Promise<Tensor>
```
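The class-weight path of standardizeWeights can be sketched on plain arrays instead of tensors. classIndicesOf and perSampleWeights are hypothetical helpers; the real function extracts indices with tfjs-core ops (clone, reshape, argMax) inside tidy and returns a float32 Tensor1D, while this sketch returns a number[].

```typescript
type ClassWeight = { [classIndex: number]: number };

// Hypothetical helper: turn targets into class indices, mirroring the three
// target layouts that standardizeWeights accepts.
function classIndicesOf(y: number[] | number[][]): number[] {
  return (y as Array<number | number[]>).map(row => {
    // 1D targets: each entry is already a class index.
    if (typeof row === 'number') return row;
    // Shape [batch, 1]: a single column holding the class index.
    if (row.length === 1) return row[0];
    if (row.length === 0) throw new Error('Target rows must have at least one column.');
    // One-hot rows: position of the largest entry, as argMax along axis 1 gives.
    let best = 0;
    for (let i = 1; i < row.length; i++) {
      if (row[i] > row[best]) best = i;
    }
    return best;
  });
}

// Hypothetical helper: look up one weight per sample, failing on unknown classes.
function perSampleWeights(y: number[] | number[][], classWeight: ClassWeight): number[] {
  return classIndicesOf(y).map(classIndex => {
    const weight = classWeight[classIndex];
    if (weight == null) {
      throw new Error(`Class ${classIndex} occurs in y but not in classWeight.`);
    }
    return weight;
  });
}

const classWeight: ClassWeight = {0: 1.0, 1: 5.0};
const fromIndices = perSampleWeights([0, 1, 1, 0], classWeight);                    // [1, 5, 5, 1]
const fromOneHot = perSampleWeights([[1, 0], [0, 1], [0, 1], [1, 0]], classWeight); // [1, 5, 5, 1]
```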
computeWeightedLoss
Multiplies losses by sampleWeights element-wise using mul, so the usual tfjs-core broadcasting rules apply.
```typescript
export function computeWeightedLoss(losses: Tensor, sampleWeights: Tensor): Tensor
```
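computeWeightedLoss is a thin wrapper around mul. A plain-array sketch of the same idea, restricted to equal-length 1D inputs, shows the effect of weights like those produced by standardizeWeights; weightedLoss and the loss values are hypothetical.

```typescript
// Hypothetical plain-array analogue of computeWeightedLoss. It only handles
// equal-length 1D inputs; mul itself also broadcasts compatible shapes.
function weightedLoss(losses: number[], sampleWeights: number[]): number[] {
  if (losses.length !== sampleWeights.length) {
    throw new Error(
        `Length mismatch: ${losses.length} losses vs ${sampleWeights.length} weights.`);
  }
  return losses.map((loss, i) => loss * sampleWeights[i]);
}

// Hypothetical per-sample losses for a batch of four, weighted by the
// class-derived weights [1, 5, 5, 1]: the two class-1 samples now contribute
// five times their raw loss.
const losses = [0.2, 0.9, 0.4, 0.1];
const weighted = weightedLoss(losses, [1, 5, 5, 1]);
```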
I/O Contract
| Function | Input | Output |
|---|---|---|
| standardizeClassWeights | ClassWeight (single, array, or map), output names | ClassWeight[] aligned with outputs |
| standardizeWeights | Target tensor, optional class/sample weights | Promise<Tensor> of per-sample weights |
| computeWeightedLoss | Loss tensor, sample weight tensor | Weighted loss Tensor |
Usage Example
```typescript
import * as tf from '@tensorflow/tfjs';
import {ClassWeight, standardizeClassWeights, standardizeWeights} from './engine/training_utils';

async function main() {
  // Single-output model with class weights: class 1 counts five times as much as class 0.
  const classWeights: ClassWeight = {0: 1.0, 1: 5.0};
  const normalized = standardizeClassWeights(classWeights, ['output']);
  // normalized = [{0: 1.0, 1: 5.0}]

  // Convert to per-sample weights for a batch of class-index targets.
  const y = tf.tensor1d([0, 1, 1, 0]);
  const weights = await standardizeWeights(y, undefined, classWeights);
  // weights = Tensor1D [1.0, 5.0, 5.0, 1.0]
  weights.print();
  tf.dispose([y, weights]);
}

main();
```
Related Pages
- Tensorflow_Tfjs_Regularizers - Weight regularizers applied during training
- Tensorflow_Tfjs_Layer_Utils - Model summary printing utilities