Implementation:Tensorflow Tfjs CPU TopK Impl
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, CPU_Backend |
| Last Updated | 2026-02-10 06:00 GMT |
Overview
The topKImpl function finds the k largest values, and their indices, along the last dimension of a tensor. It is shared between the CPU and WebGL backends. When k is smaller than the row length, it uses the Floyd-Rivest selection algorithm to move the k largest elements of each row to the front without fully sorting the row. When `sorted` is true, it then sorts those k elements in descending order.
Code Reference
Source Location
tfjs-backend-cpu/src/kernels/TopK_impl.ts (GitHub)
Signature
```typescript
export function topKImpl<T extends Tensor, R extends Rank>(
    x: TypedArray, xShape: number[], xDtype: NumericDataType, k: number,
    sorted: boolean):
    [TensorBuffer<R, NumericDataType>, TensorBuffer<R, 'int32'>]
```
Imports
```typescript
import {buffer, NumericDataType, Rank, ShapeMap, Tensor, TensorBuffer, TypedArray, util} from '@tensorflow/tfjs-core';
```
I/O Contract
Inputs
| Name | Type | Description |
|---|---|---|
| `x` | `TypedArray` | Input tensor data, flattened in row-major order |
| `xShape` | `number[]` | Shape of the input tensor |
| `xDtype` | `NumericDataType` | Numeric data type of the input; also the dtype of the values output |
| `k` | `number` | Number of top elements to return along the last dimension |
| `sorted` | `boolean` | If true, the top-k values are returned sorted in descending order |
Output
Returns a tuple of two `TensorBuffer` objects, both with shape `[...xShape[:-1], k]` (the input shape with its last dimension replaced by `k`):
- Values buffer: same dtype as the input (`xDtype`), containing the top-k values.
- Indices buffer: `int32` dtype, containing the position of each top-k value along the last dimension of the input.
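The shape rule above can be stated as a small helper. `topKOutputShape` is hypothetical and not part of tfjs, and its range check on `k` is an assumption of this sketch rather than documented behavior of topKImpl.

```typescript
// Hypothetical helper, not part of tfjs: computes the shape documented for
// both output buffers, i.e. xShape with its last dimension replaced by k.
function topKOutputShape(xShape: number[], k: number): number[] {
  if (xShape.length === 0) {
    throw new Error('top-k needs an input of rank 1 or higher');
  }
  const lastDim = xShape[xShape.length - 1];
  // Assumed range check for this sketch: k must fit in the last dimension.
  if (k < 0 || k > lastDim) {
    throw new Error(`k must be in [0, ${lastDim}], got ${k}`);
  }
  const outShape = xShape.slice();
  outShape[outShape.length - 1] = k;
  return outShape;
}

// 2 x 3 = 6 rows of length 6, keeping the top 3 entries of every row.
const exampleShape = topKOutputShape([2, 3, 6], 3); // [2, 3, 3]
```

Both buffers share this shape; only their dtypes differ.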
Internal Functions
Pair Type
```typescript
type Pair = {
  value: number,
  index: number
};
```
comparePair
Comparator function used for sorting: orders by descending value, then by ascending index for ties.
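The ordering described above fits in a one-line comparator. This is a minimal TypeScript sketch that matches the stated ordering; it is not copied from the tfjs source, and the sample `pairs` data is illustrative.

```typescript
type Pair = {value: number, index: number};

// Negative when a should come first: larger values first,
// and for equal values the smaller index first.
function comparePair(a: Pair, b: Pair): number {
  const valueDiff = b.value - a.value;
  return valueDiff === 0 ? a.index - b.index : valueDiff;
}

const pairs: Pair[] = [
  {value: 3, index: 0}, {value: 9, index: 1},
  {value: 3, index: 2}, {value: 7, index: 3},
];
pairs.sort(comparePair);
const orderedIndices = pairs.map(p => p.index); // [1, 3, 0, 2]
```

Because this sketch subtracts values, two infinities of the same sign (or any NaN) produce NaN instead of falling back to the index; a variant that must handle those inputs would compare with `<` and `>` instead.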
select (Floyd-Rivest Algorithm)
```typescript
function select(array: Pair[], k: number, left = 0, right = array.length - 1)
```
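Partitions `array[left..right]` in place around its (k+1)th smallest element under `comparePair`: every element that compares smaller ends up to its left and every element that compares larger ends up to its right. Because `comparePair` ranks larger values first, calling `select(array, k)` leaves the k largest values (ties broken by lower index) in positions 0 to k-1, in no particular order. The Floyd-Rivest algorithm performs this selection in O(n) expected time.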
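For reference, here is a self-contained TypeScript sketch of Floyd-Rivest selection adapted to the `comparePair` ordering. It follows the textbook formulation of the algorithm, including the conventional sampling constants 600 and 0.5; the local `swap` helper and the sample `row` are assumptions of this sketch, not tfjs APIs, and it should not be read as a verbatim copy of the tfjs code.

```typescript
type Pair = {value: number, index: number};

// Larger value ranks first; ties go to the smaller index.
function comparePair(a: Pair, b: Pair): number {
  const valueDiff = b.value - a.value;
  return valueDiff === 0 ? a.index - b.index : valueDiff;
}

// Local helper for this sketch.
function swap<T>(arr: T[], i: number, j: number): void {
  const tmp = arr[i];
  arr[i] = arr[j];
  arr[j] = tmp;
}

// Floyd-Rivest selection: afterwards array[k] holds the element a full sort
// by comparePair would place at k, with smaller elements before it and
// larger elements after it (both sides unordered).
function select(array: Pair[], k: number, left = 0, right = array.length - 1): void {
  while (right > left) {
    // On large ranges, first select within a small sample around k so the
    // pivot used below is likely to land close to position k.
    if (right - left > 600) {
      const n = right - left + 1;
      const i = k - left + 1;
      const z = Math.log(n);
      const s = 0.5 * Math.exp((2 * z) / 3);
      const sd = 0.5 * Math.sqrt((z * s * (n - s)) / n) * Math.sign(i - n / 2);
      const newLeft = Math.max(left, Math.floor(k - (i * s) / n + sd));
      const newRight = Math.min(right, Math.floor(k + ((n - i) * s) / n + sd));
      select(array, k, newLeft, newRight);
    }
    // Partition array[left..right] around the pivot t = array[k].
    const t = array[k];
    let i = left;
    let j = right;
    swap(array, left, k);
    if (comparePair(array[right], t) > 0) {
      swap(array, left, right);
    }
    while (i < j) {
      swap(array, i, j);
      i++;
      j--;
      while (comparePair(array[i], t) < 0) i++;
      while (comparePair(array[j], t) > 0) j--;
    }
    // Move the pivot into its final slot j.
    if (comparePair(array[left], t) === 0) {
      swap(array, left, j);
    } else {
      j++;
      swap(array, j, right);
    }
    // Continue only on the side that still contains position k.
    if (j <= k) left = j + 1;
    if (k <= j) right = j - 1;
  }
}

const row: Pair[] = [5, 2, 8, 1, 9, 3].map((value, index) => ({value, index}));
select(row, 3);
// row.slice(0, 3) now holds the values 9, 8 and 5, in some order.
```

The sampling step only runs on ranges longer than 600 elements; shorter ranges behave like a quickselect that pivots on `array[k]`.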
Algorithm
- Treats the input as a 2D array `[batch, lastDim]`, where `lastDim = xShape[xShape.length - 1]` and `batch = x.length / lastDim`.
- For each batch row:
  - Creates an array of `{value, index}` pairs from the row's slice of `x`.
  - If `k < array.length`, calls the Floyd-Rivest `select` function so the top-k elements occupy the first k positions, then keeps only those k elements.
  - If `sorted` is true, sorts the top-k elements with `comparePair` (descending value, ascending index for ties); otherwise they keep the order left by the selection step.
  - Writes the top-k values and indices into the output buffers.
- Reshapes the output back to `[...xShape[:-1], k]`.
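The per-row steps above can be sketched end to end on a plain flattened array. This hypothetical `topKSketch` is not the tfjs code: for brevity it replaces Floyd-Rivest selection with a full `Array.prototype.sort` of each row, so its results are always sorted (the `sorted` flag is dropped) and each row costs O(n log n) instead of O(n) expected. It also writes values into a `Float64Array` instead of allocating by `xDtype`, and assumes `k <= lastDim`.

```typescript
type Pair = {value: number, index: number};

// Larger value ranks first; ties go to the smaller index.
function comparePair(a: Pair, b: Pair): number {
  const valueDiff = b.value - a.value;
  return valueDiff === 0 ? a.index - b.index : valueDiff;
}

// Hypothetical simplified top-k over a row-major flattened array.
// Assumes 0 <= k <= lastDim.
function topKSketch(x: ArrayLike<number>, xShape: number[], k: number):
    {values: Float64Array, indices: Int32Array, shape: number[]} {
  // View the input as [batch, lastDim].
  const lastDim = xShape[xShape.length - 1];
  const batch = x.length / lastDim;
  const values = new Float64Array(batch * k);
  const indices = new Int32Array(batch * k);

  for (let b = 0; b < batch; b++) {
    const offset = b * lastDim;
    // Build {value, index} pairs for this row.
    const row: Pair[] = [];
    for (let i = 0; i < lastDim; i++) {
      row.push({value: x[offset + i], index: i});
    }
    // Full sort stands in for select() + optional sort of the first k.
    row.sort(comparePair);
    // Copy the first k entries into this row's slot of the outputs.
    for (let i = 0; i < k; i++) {
      values[b * k + i] = row[i].value;
      indices[b * k + i] = row[i].index;
    }
  }

  // Output shape: xShape with the last dimension replaced by k.
  const shape = xShape.slice();
  shape[shape.length - 1] = k;
  return {values, indices, shape};
}

const result = topKSketch([1, 7, 3, 9, 4, 6], [2, 3], 2);
// shape [2, 2], values [7, 3, 9, 6], indices [1, 2, 0, 2]
```

Replacing the `row.sort(comparePair)` call with a selection step, followed by a sort of the first k entries only when requested, recovers the structure described in the list above.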
Usage Example
```typescript
import * as tf from '@tensorflow/tfjs-core';
import '@tensorflow/tfjs-backend-cpu';

const x = tf.tensor1d([5, 2, 8, 1, 9, 3]);
const {values, indices} = tf.topk(x, 3);
values.print(); // [9, 8, 5]
indices.print(); // [4, 2, 0]
```