
Implementation:Tensorflow Tfjs Training Utils

From Leeroopedia


Knowledge Sources
Domains Deep_Learning, Layers_API, Training
Last Updated 2026-02-10 06:00 GMT

Overview

This module provides utility functions for handling class weights and sample weights during model training in TensorFlow.js Layers. It normalizes weight arguments so that both single-output and multi-output models receive one entry per output. It also converts class weights into per-sample weight tensors and multiplies per-sample losses by those weights. Together, these utilities support training on imbalanced datasets: giving an under-represented class a larger weight increases the contribution of its examples to the loss.

Code Reference

Source Location

tfjs-layers/src/engine/training_utils.ts (GitHub)

Key Imports

import {argMax, clone, dispose, mul, reshape, Tensor, Tensor1D, tensor1d, tidy}
    from '@tensorflow/tfjs-core';

Types

export type ClassWeight = { [classIndex: number]: number };
export type ClassWeightMap = { [outputName: string]: ClassWeight };
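ClassWeight maps a class index to a multiplier on that class's contribution to the loss. The module itself does not choose these values. As an illustration only, the sketch below fills a ClassWeight using the "balanced" heuristic from scikit-learn's compute_class_weight, w_c = n / (k · n_c). The helper balancedClassWeights is hypothetical and is not part of training_utils.ts.

```typescript
// Same shape as ClassWeight above, redeclared so this snippet stands alone.
type ClassWeight = { [classIndex: number]: number };

// Hypothetical helper (not in training_utils.ts): "balanced" weights
// w_c = n / (k * n_c), with n samples, k distinct classes, and n_c
// samples of class c. Rarer classes receive proportionally larger weights.
function balancedClassWeights(labels: number[]): ClassWeight {
  const counts = new Map<number, number>();
  for (const label of labels) {
    counts.set(label, (counts.get(label) ?? 0) + 1);
  }
  const n = labels.length;
  const k = counts.size;
  const weights: ClassWeight = {};
  counts.forEach((count, label) => {
    weights[label] = n / (k * count);
  });
  return weights;
}

// Three examples of class 0 and one of class 1.
const weights = balancedClassWeights([0, 0, 0, 1]);
// weights ≈ {0: 0.667, 1: 2}
```

With these weights, the single class-1 example counts three times as much as each class-0 example. On balanced data, every weight comes out as 1.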

Functions

standardizeClassWeights

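Normalizes the classWeight argument into an array with one entry per model output, in the order of outputNames. A missing argument (null or an empty array) yields an array of nulls. For a single-output model, the function accepts a bare ClassWeight, a one-element array, or a map keyed by the output name. For a multi-output model, the argument must be either an array with one ClassWeight per output or a ClassWeightMap keyed by output name, and outputs missing from the map receive null. Any other form, or an array whose length does not match the number of outputs, throws an error.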

export function standardizeClassWeights(
    classWeight: ClassWeight | ClassWeight[] | ClassWeightMap,
    outputNames: string[]): ClassWeight[]
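The normalization rules above can be sketched in plain TypeScript, without tensors. This is a reading of the behavior described above, not the library code. The name normalizeWeightsSketch is hypothetical and the error messages are simplified. The weightType parameter shows how one set of rules can serve both class and sample weights.

```typescript
type ClassWeight = { [classIndex: number]: number };
type ClassWeightMap = { [outputName: string]: ClassWeight };

// Hypothetical sketch of the normalization shared by standardizeClassWeights
// and standardizeSampleWeights. Returns one entry per output name, with null
// for outputs that receive no weights.
function normalizeWeightsSketch(
    xWeight: ClassWeight | ClassWeight[] | ClassWeightMap | null | undefined,
    outputNames: string[],
    weightType: 'classWeight' | 'sampleWeight' = 'classWeight'):
    Array<ClassWeight | null> {
  const numOutputs = outputNames.length;
  // Nothing provided: no weights for any output.
  if (xWeight == null || (Array.isArray(xWeight) && xWeight.length === 0)) {
    return outputNames.map(() => null);
  }
  if (numOutputs === 1) {
    // A one-element array is used as-is.
    if (Array.isArray(xWeight) && xWeight.length === 1) {
      return xWeight as ClassWeight[];
    }
    // A map keyed by the single output's name is unwrapped.
    if (!Array.isArray(xWeight) && outputNames[0] in xWeight) {
      return [(xWeight as ClassWeightMap)[outputNames[0]]];
    }
    // Otherwise the argument itself is the output's ClassWeight.
    return [xWeight as ClassWeight];
  }
  if (Array.isArray(xWeight)) {
    // Multi-output array form: exactly one entry per output, in order.
    if (xWeight.length !== numOutputs) {
      throw new Error(
          `Provided ${weightType} has ${xWeight.length} element(s), ` +
          `but the model has ${numOutputs} outputs.`);
    }
    return xWeight as ClassWeight[];
  }
  // Multi-output map form: keys are output names, values are ClassWeights.
  const map = xWeight as ClassWeightMap;
  const keys = Object.keys(map);
  if (keys.length > 0 && typeof map[keys[0]] === 'object') {
    return outputNames.map(name => (name in map ? map[name] : null));
  }
  throw new Error(
      `${weightType} must be an array of ${numOutputs} elements ` +
      `or an object keyed by output name.`);
}
```

For example, normalizeWeightsSketch({a: {0: 3}}, ['a', 'b']) yields [{0: 3}, null], because output b has no entry in the map.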

standardizeSampleWeights

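Applies the same normalization rules as standardizeClassWeights, but to sample weights. In the current source, both functions delegate to a shared internal helper and differ only in the weight name used in error messages. The parameter is still named classWeight and typed as ClassWeight.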

export function standardizeSampleWeights(
    classWeight: ClassWeight | ClassWeight[] | ClassWeightMap,
    outputNames: string[]): ClassWeight[]

standardizeWeights

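Converts a class-weight object into a per-sample weight tensor (float32, one entry per example). The class of each example is read from y as follows:

- A rank-1 tensor is treated as class indices.
- A rank-2 tensor of shape [batch, 1] is reshaped to class indices.
- A rank-2 tensor with more than one column is treated as one-hot encoded and reduced with argMax along axis 1.

Targets of any other rank throw an error, as does any class index that appears in y but not in classWeight. If classWeight is omitted, the function resolves to null. Passing sampleWeight or sampleWeightMode throws, because sample weighting is not implemented yet.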

export async function standardizeWeights(
    y: Tensor,
    sampleWeight?: Tensor,
    classWeight?: ClassWeight,
    sampleWeightMode?: 'temporal'): Promise<Tensor>
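The class-weight branch of standardizeWeights comes down to two steps: recover one class index per example, then look up that class's weight. The sketch below does this on plain arrays instead of tensors, using nested number[] rows to stand in for a rank-2 target tensor. The name classWeightsToSampleWeights is hypothetical.

```typescript
type ClassWeight = { [classIndex: number]: number };

// Hypothetical sketch of the class-weight branch of standardizeWeights,
// on plain arrays. y holds either class indices (rank 1) or rows of a
// rank-2 target: single-element [classIndex] rows or one-hot rows.
function classWeightsToSampleWeights(
    y: number[] | number[][], classWeight: ClassWeight): number[] {
  const classIndices = (y as Array<number | number[]>).map(row => {
    if (typeof row === 'number') {
      return row;  // rank 1: the target already is a class index
    }
    if (row.length === 1) {
      return row[0];  // shape [batch, 1]: one class index per row
    }
    if (row.length > 1) {
      // One-hot row: the class is the position of the largest entry
      // (this sketch keeps the first such position on ties).
      let best = 0;
      for (let i = 1; i < row.length; ++i) {
        if (row[i] > row[best]) {
          best = i;
        }
      }
      return best;
    }
    throw new Error('The last dimension of y is expected to be >= 1.');
  });
  // Look up each example's weight; a class missing from classWeight is an error.
  return classIndices.map(c => {
    const weight = classWeight[c];
    if (weight == null) {
      throw new Error(`Class ${c} exists in the data but not in classWeight.`);
    }
    return weight;
  });
}

const classWeight: ClassWeight = {0: 1.0, 1: 5.0};
const fromIndices = classWeightsToSampleWeights([0, 1, 1, 0], classWeight);
// fromIndices = [1, 5, 5, 1]
const fromOneHot = classWeightsToSampleWeights([[1, 0], [0, 1]], classWeight);
// fromOneHot = [1, 5]
```

The lookup uses classWeight[c] directly, so every class that occurs in the batch must have an entry. A missing entry is treated as an error, not as an implicit weight of 1.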

computeWeightedLoss

Multiplies the per-sample losses element-wise by the sample weights, using mul from tfjs-core, which broadcasts inputs of compatible shapes.

export function computeWeightedLoss(losses: Tensor, sampleWeights: Tensor): Tensor
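The worked example below shows the weighting step as a plain-array sketch. weightedLossSketch is a hypothetical name and, unlike mul, it does not broadcast. The final averaging is an assumption added for illustration. It shows how the up-weighted class 1 dominates the batch loss.

```typescript
// Hypothetical plain-array sketch of computeWeightedLoss: element-wise
// product of per-sample losses and per-sample weights. Unlike tf's mul,
// this version does not broadcast and requires equal lengths.
function weightedLossSketch(losses: number[], sampleWeights: number[]): number[] {
  if (losses.length !== sampleWeights.length) {
    throw new Error('losses and sampleWeights must have the same length.');
  }
  return losses.map((loss, i) => loss * sampleWeights[i]);
}

// Per-sample losses for targets [0, 1, 1, 0] with classWeight {0: 1, 1: 5}.
const weighted = weightedLossSketch([0.25, 0.5, 0.75, 0.125], [1, 5, 5, 1]);
// weighted = [0.25, 2.5, 3.75, 0.125]

// Illustrative mean reduction (an assumption of this sketch).
const meanLoss = weighted.reduce((a, b) => a + b, 0) / weighted.length;
// meanLoss = 1.65625 (the unweighted mean would be 0.40625)
```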

I/O Contract

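Function | Input | Output
standardizeClassWeights | ClassWeight (single, array, or map); output names | ClassWeight[] aligned with outputs (null for outputs without weights)
standardizeSampleWeights | Same as standardizeClassWeights | Same as standardizeClassWeights
standardizeWeights | Target tensor y; optional sample weight, class weight, and sample weight mode | Promise<Tensor> of per-sample weights, or null when no class weight is given
computeWeightedLoss | Loss tensor; sample weight tensor | Weighted loss Tensor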

Usage Example

// These helpers are internal to tfjs-layers, not part of its documented public
// API. The relative import path assumes code living in tfjs-layers/src.
import {tensor1d} from '@tensorflow/tfjs-core';
import {standardizeClassWeights, standardizeWeights} from './engine/training_utils';

async function classWeightExample() {
  // Single-output model with class weights
  const classWeights = {0: 1.0, 1: 5.0};  // weight class 1 more heavily
  const normalized = standardizeClassWeights(classWeights, ['output']);
  // normalized = [{0: 1.0, 1: 5.0}]

  // Convert to per-sample weights for a batch of class-index targets
  const y = tensor1d([0, 1, 1, 0]);
  const weights = await standardizeWeights(y, undefined, classWeights);
  // weights = Tensor1D [1.0, 5.0, 5.0, 1.0]
  weights.print();
}

classWeightExample();
