

Implementation:Tensorflow Tfjs CPU SparseSegmentReduction Impl

From Leeroopedia


Knowledge Sources
Domains: Deep_Learning, CPU_Backend
Last Updated: 2026-02-10 06:00 GMT

Overview

The sparseSegmentReductionImpl function performs sparse segment reduction (sum or mean) on a tensor. It uses the provided indices to gather rows along the first dimension of the input, then reduces (sums or averages) the gathered rows that share a segment ID. It is the shared CPU implementation behind both the SparseSegmentSum and SparseSegmentMean kernels, which choose the reduction through the isMean flag.
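The following minimal reference sketch makes the gather-then-reduce semantics concrete on plain nested arrays. It is not the tfjs code. The function name gatherThenSegmentReduce is hypothetical, rows are number[][] rather than a flat TypedArray, and segmentIds is assumed to be sorted so that its last entry determines the number of output rows.

```typescript
// Hypothetical reference sketch (not the tfjs implementation): gather rows of a
// 2-D input by index, then sum (or average) the gathered rows per segment ID.
// Assumes segmentIds is sorted and has one entry per element of indices.
function gatherThenSegmentReduce(
    rows: number[][], indices: number[], segmentIds: number[],
    isMean = false, defaultValue = 0): number[][] {
  // Gather: pick the selected rows, in the order given by indices.
  const gathered = indices.map((i) => rows[i]);
  // One output row per segment ID from 0 up to the last (largest) segment ID.
  const numSegments =
      segmentIds.length > 0 ? segmentIds[segmentIds.length - 1] + 1 : 0;
  const numCol = rows.length > 0 ? rows[0].length : 0;
  const sums = Array.from(
      {length: numSegments}, () => new Array<number>(numCol).fill(0));
  const counts = new Array<number>(numSegments).fill(0);
  // Reduce: add each gathered row into the output row of its segment.
  gathered.forEach((row, k) => {
    const seg = segmentIds[k];
    counts[seg] += 1;
    row.forEach((value, j) => {
      sums[seg][j] += value;
    });
  });
  // Segments that received no rows take defaultValue; the mean divides by
  // the number of rows gathered into the segment.
  return sums.map(
      (row, seg) => counts[seg] === 0 ? row.map(() => defaultValue) :
          isMean                      ? row.map((v) => v / counts[seg]) :
                                        row);
}
```

With rows [[1, 2], [3, 4], [5, 6], [7, 8]], indices [0, 1, 2, 3] and segmentIds [0, 0, 1, 1], this returns [[4, 6], [12, 14]]. These are the same values that the flat usage example on this page reports.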

Code Reference

Source Location

tfjs-backend-cpu/src/kernels/SparseSegmentReduction_impl.ts (GitHub)

Signature

export function sparseSegmentReductionImpl(
    input: TypedArray, inputShape: number[], inputDType: DataType,
    indices: TypedArray, segmentIds: TypedArray, isMean = false,
    defaultValue = 0): [TypedArray, number[]]

Imports

import {backend_util, DataType, TypedArray, util} from '@tensorflow/tfjs-core';

I/O Contract

Inputs

Name          Type        Description
input         TypedArray  Input tensor data (flattened)
inputShape    number[]    Shape of the input tensor
inputDType    DataType    Data type of the input; the output array is allocated with the same type
indices       TypedArray  Row indices into the first dimension of input to gather; each must lie in [0, inputShape[0])
segmentIds    TypedArray  Segment ID for each entry of indices, sorted in non-decreasing order
isMean        boolean     If true, computes the mean instead of the sum (default: false)
defaultValue  number      Fill value for output rows whose segment ID does not appear in segmentIds (default: 0)
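The contract above can be expressed as a small argument check. This is a hypothetical helper (checkSparseSegmentArgs is not a tfjs function) that lists the problems instead of throwing. It also assumes that indices and segmentIds have the same length, which follows from each index having a corresponding segment ID.

```typescript
// Hypothetical contract check (not part of tfjs): returns a list of problems
// with the arguments, or an empty list if they satisfy the I/O contract.
function checkSparseSegmentArgs(
    inputShape: number[], indices: ArrayLike<number>,
    segmentIds: ArrayLike<number>): string[] {
  const problems: string[] = [];
  // Assumption: one segment ID per gathered index.
  if (indices.length !== segmentIds.length) {
    problems.push('indices and segmentIds must have the same length');
  }
  // Every index must select an existing row of the input.
  for (let i = 0; i < indices.length; i++) {
    if (indices[i] < 0 || indices[i] >= inputShape[0]) {
      problems.push(
          `indices[${i}] = ${indices[i]} is out of range [0, ${inputShape[0]})`);
    }
  }
  // Segment IDs must be non-negative and sorted in non-decreasing order.
  for (let i = 0; i < segmentIds.length; i++) {
    if (segmentIds[i] < 0) {
      problems.push(`segmentIds[${i}] is negative`);
    }
    if (i > 0 && segmentIds[i] < segmentIds[i - 1]) {
      problems.push(`segmentIds is not sorted at position ${i}`);
    }
  }
  return problems;
}
```

For example, with inputShape [4, 2], the index 4 is reported as out of range, and segmentIds [1, 0] is reported as unsorted.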

Output

Returns a tuple [output, outputShape]:

  • output (TypedArray): Reduced values with shape [outputRows, ...inputShape.slice(1)].
  • outputShape (number[]): Shape of the output. It equals inputShape except that outputShape[0] = lastSegmentId + 1 (or 0 when indices is empty).
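The shape rule can be written as a short helper. The name sparseSegmentOutputShape is hypothetical. The helper assumes that segmentIds is sorted, so its last entry is the largest segment ID.

```typescript
// Hypothetical helper: compute outputShape as described above.
function sparseSegmentOutputShape(
    inputShape: number[], segmentIds: ArrayLike<number>): number[] {
  const n = segmentIds.length;
  // segmentIds is sorted, so its last entry is the largest segment ID.
  const outputRows = n > 0 ? segmentIds[n - 1] + 1 : 0;
  // Only the first dimension changes; trailing dimensions are kept as-is.
  return [outputRows, ...inputShape.slice(1)];
}
```

For example, inputShape [5, 3, 2] with segmentIds [0, 3] gives [4, 3, 2], which is 24 output elements. Output rows 1 and 2 receive no gathered rows, so they hold defaultValue.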

Algorithm

  1. Flattens the input to 2D: [inputShape[0], numCol] where numCol = input.length / inputShape[0].
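  2. Sets the number of output rows to lastSegmentId + 1, which assumes segmentIds is sorted. If indices is empty, it returns immediately with outputShape[0] = 0.
  3. Iterates through indices, grouping consecutive entries with the same segment ID.
  4. For each segment group:
    1. Validates that segment IDs are non-decreasing (equal within a group, strictly increasing from one group to the next) and that each segment ID lies in [0, outputRows).
    2. Fills any skipped output rows, from the first row not yet written up to the current segment ID, with the defaultValue. This includes rows before the first segment ID.
    3. Accumulates, column by column, the gathered input rows of the segment, throwing if an index falls outside [0, inputShape[0]).
    4. If isMean is true, divides each accumulated column by the segment size (the number of indices in the group).
  5. Fills any remaining output rows after the last written segment with the defaultValue.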
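The steps above can be sketched as a self-contained TypeScript function. This is a minimal sketch, not the tfjs source. The name sparseSegmentReduceSketch is hypothetical, the output is always a Float32Array (tfjs allocates the output according to inputDType), and the error messages are placeholders rather than the backend_util messages.

```typescript
// Hypothetical sketch of the algorithm steps (float32 output only).
function sparseSegmentReduceSketch(
    input: Float32Array, inputShape: number[], indices: Int32Array,
    segmentIds: Int32Array, isMean = false,
    defaultValue = 0): [Float32Array, number[]] {
  const numIndices = indices.length;
  // Step 1: view the input as [numRows, numCol].
  const numRows = inputShape[0];
  const numCol = input.length / numRows;
  // Step 2: output rows = lastSegmentId + 1 (segmentIds assumed sorted).
  const outputRows = numIndices > 0 ? segmentIds[numIndices - 1] + 1 : 0;
  const outputShape = [outputRows, ...inputShape.slice(1)];
  const output = new Float32Array(outputRows * numCol);  // zero-initialized
  if (numIndices === 0) {
    return [output, outputShape];
  }
  let uninitialized = 0;  // first output row not yet written
  let start = 0;
  // Step 3: walk groups of consecutive equal segment IDs.
  while (start < numIndices) {
    const seg = segmentIds[start];
    let end = start + 1;
    while (end < numIndices && segmentIds[end] === seg) {
      end++;
    }
    // Step 4.1: IDs must be in range and must increase between groups.
    if (seg < 0 || seg >= outputRows) {
      throw new Error(`segment id ${seg} is out of range`);
    }
    if (end < numIndices && segmentIds[end] < seg) {
      throw new Error('segment ids are not sorted');
    }
    // Step 4.2: rows skipped since the last written row get defaultValue.
    output.fill(defaultValue, uninitialized * numCol, seg * numCol);
    // Step 4.3: sum the gathered input rows into this segment's output row.
    for (let i = start; i < end; i++) {
      const row = indices[i];
      if (row < 0 || row >= numRows) {
        throw new Error(`index ${row} is out of range`);
      }
      for (let j = 0; j < numCol; j++) {
        output[seg * numCol + j] += input[row * numCol + j];
      }
    }
    // Step 4.4: for the mean, divide by the number of rows in the segment.
    if (isMean) {
      for (let j = 0; j < numCol; j++) {
        output[seg * numCol + j] /= end - start;
      }
    }
    uninitialized = seg + 1;
    start = end;
  }
  // Step 5: fill rows after the last segment. With sorted IDs this is a
  // no-op, because outputRows = lastSegmentId + 1.
  output.fill(defaultValue, uninitialized * numCol, outputRows * numCol);
  return [output, outputShape];
}
```

With indices [0, 1], segmentIds [0, 2] and defaultValue -1 on the 4x2 input from the usage example, the sketch returns [1, 2, -1, -1, 3, 4] with shape [3, 2]. Row 1 receives no gathered rows, so step 4.2 fills it with defaultValue.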

Usage Example

import {sparseSegmentReductionImpl} from './SparseSegmentReduction_impl';

const input = new Float32Array([1, 2, 3, 4, 5, 6, 7, 8]);
const inputShape = [4, 2];
const indices = new Int32Array([0, 1, 2, 3]);
const segmentIds = new Int32Array([0, 0, 1, 1]);

// Sparse segment sum
const [output, outShape] = sparseSegmentReductionImpl(
    input, inputShape, 'float32', indices, segmentIds);
// output: [4, 6, 12, 14], outShape: [2, 2]

// Sparse segment mean
const [meanOutput, meanShape] = sparseSegmentReductionImpl(
    input, inputShape, 'float32', indices, segmentIds, true);
// meanOutput: [2, 3, 6, 7], meanShape: [2, 2]
