Implementation:Microsoft Onnxruntime OnnxSparseTensor
| Knowledge Sources | Description |
|---|---|
| Source File | java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java |
| Repository | Microsoft/onnxruntime |
Domains
- Machine Learning Runtime
- Sparse Tensor Representation
- JNI Native Interop
Overview
OnnxSparseTensor is a public final class that wraps an ORT sparse tensor value. It supports COO, CSR/CSC (CSRC), and Block Sparse formats. The class provides factory methods to create sparse tensors from Java-side representations, and accessor methods to retrieve indices, inner indices, values, and their shapes. It also contains inner classes (SparseTensor, COOTensor, CSRCTensor, BlockSparseTensor) and the SparseTensorType enum.
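To picture the COO layout concretely, the plain-Java sketch below scatters COO values into a dense row-major matrix. It does not call ONNX Runtime. It assumes the 2-D COO conventions of the ORT C API: the indices are either one linear row-major offset per non-zero (indices shape {nnz}) or one (row, col) pair per non-zero (indices shape {nnz, 2}). CooDecode and toDense are hypothetical names used only for illustration.

```java
import java.util.Arrays;

/** Hypothetical helper (not part of ONNX Runtime): expands a 2-D COO matrix into a dense array. */
final class CooDecode {

  /**
   * @param indices      COO indices, either {nnz} linear offsets or {nnz, 2} (row, col) pairs
   * @param indicesShape shape of the indices: {nnz} or {nnz, 2}
   * @param values       the nnz non-zero values, in the same order as the indices
   * @param denseShape   {rows, cols} of the dense matrix
   */
  static float[][] toDense(long[] indices, long[] indicesShape, float[] values, long[] denseShape) {
    int rows = (int) denseShape[0];
    int cols = (int) denseShape[1];
    float[][] dense = new float[rows][cols]; // Java zero-initialises new arrays
    int nnz = (int) indicesShape[0];
    boolean linear = indicesShape.length == 1;
    for (int i = 0; i < nnz; i++) {
      int row;
      int col;
      if (linear) {
        // Linear form: index = row * cols + col (row-major order)
        row = (int) (indices[i] / cols);
        col = (int) (indices[i] % cols);
      } else {
        // Coordinate form: non-zero i occupies indices[2*i] (row) and indices[2*i + 1] (col)
        row = (int) indices[2 * i];
        col = (int) indices[2 * i + 1];
      }
      dense[row][col] = values[i];
    }
    return dense;
  }

  public static void main(String[] args) {
    // 3x3 matrix with 3.0 at (0,1) and 7.0 at (1,2), encoded both ways
    float[][] fromPairs = toDense(new long[] {0, 1, 1, 2}, new long[] {2, 2},
        new float[] {3.0f, 7.0f}, new long[] {3, 3});
    float[][] fromLinear = toDense(new long[] {1, 5}, new long[] {2},
        new float[] {3.0f, 7.0f}, new long[] {3, 3});
    System.out.println(Arrays.deepToString(fromPairs));
    System.out.println(Arrays.deepEquals(fromPairs, fromLinear)); // true
  }
}
```

Both encodings describe the same matrix: the linear offsets 1 and 5 are simply 0 * 3 + 1 and 1 * 3 + 2.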
Description
The OnnxSparseTensor class extends OnnxTensorLike and represents sparse tensor data in ONNX Runtime. Key features:
- Sparse formats: Supports COO (Coordinate), CSRC (Compressed Sparse Row/Column), and Block Sparse formats via the SparseTensorType enum.
- Factory method: createSparseTensor(OrtEnvironment, SparseTensor) constructs an ORT sparse tensor from a Java-side representation using direct buffers.
- Value extraction: getValue() returns a format-specific SparseTensor subclass (COOTensor, CSRCTensor, or BlockSparseTensor).
- Buffer access: Methods like getIndicesBuffer(), getInnerIndicesBuffer(), and getValuesBuffer() return copies of the underlying native data as Java NIO buffers.
- FP16/BF16 support: Values in fp16 or bf16 format are automatically upcast to fp32 and returned as FloatBuffers.
- Inner classes: SparseTensor<T> is the abstract base; COOTensor uses LongBuffer indices; CSRCTensor uses outer and inner LongBuffer indices; BlockSparseTensor uses IntBuffer indices.
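The FP16/BF16 bullet above says reduced-precision values come back as fp32 FloatBuffers. The sketch below shows what such an upcast involves by decoding raw 16-bit patterns held in a ShortBuffer. It is an illustrative re-implementation, not the conversion code ORT itself uses, and HalfUpcast and its methods are hypothetical names. bf16 is the upper 16 bits of an IEEE 754 float32, so it widens with a 16-bit shift; fp16 (IEEE 754 binary16) needs its exponent rebiased and its subnormals handled separately.

```java
import java.nio.FloatBuffer;
import java.nio.ShortBuffer;

/** Hypothetical helper (not part of ONNX Runtime): widens raw fp16 or bf16 bit patterns to fp32. */
final class HalfUpcast {

  /** bf16 keeps the top 16 bits of a float32, so shifting left by 16 restores it exactly. */
  static float bf16ToFloat(short b) {
    return Float.intBitsToFloat((b & 0xFFFF) << 16);
  }

  /** Decodes an IEEE 754 binary16 value (1 sign, 5 exponent, 10 mantissa bits). */
  static float fp16ToFloat(short h) {
    int bits = h & 0xFFFF;
    int sign = bits >>> 15;
    int exp = (bits >>> 10) & 0x1F;
    int mant = bits & 0x3FF;
    if (exp == 0x1F) {
      // All-ones exponent: infinity when the mantissa is zero, otherwise NaN
      float special = (mant == 0) ? Float.POSITIVE_INFINITY : Float.NaN;
      return sign == 0 ? special : -special;
    }
    if (exp == 0) {
      // Zero or subnormal: value = mant * 2^-24, which is exact in float32
      float small = mant * 0x1p-24f;
      return sign == 0 ? small : -small;
    }
    // Normal number: rebias the exponent (15 -> 127) and left-align the 10-bit mantissa
    return Float.intBitsToFloat((sign << 31) | ((exp - 15 + 127) << 23) | (mant << 13));
  }

  /** Copies the remaining 16-bit values into a new FloatBuffer, leaving the input's position alone. */
  static FloatBuffer upcast(ShortBuffer raw, boolean isBf16) {
    ShortBuffer src = raw.duplicate();
    FloatBuffer out = FloatBuffer.allocate(src.remaining());
    while (src.hasRemaining()) {
      short s = src.get();
      out.put(isBf16 ? bf16ToFloat(s) : fp16ToFloat(s));
    }
    out.flip(); // make the converted values readable from position 0
    return out;
  }

  public static void main(String[] args) {
    ShortBuffer fp16 = ShortBuffer.wrap(new short[] {(short) 0x3C00, (short) 0xC000, (short) 0x7C00});
    FloatBuffer f = upcast(fp16, false);
    System.out.println(f.get(0) + " " + f.get(1) + " " + f.get(2)); // 1.0 -2.0 Infinity
    ShortBuffer bf16 = ShortBuffer.wrap(new short[] {(short) 0x3F80, (short) 0x4049});
    FloatBuffer b = upcast(bf16, true);
    System.out.println(b.get(0) + " " + b.get(1)); // 1.0 3.140625
  }
}
```

Every fp16 and bf16 value is exactly representable in fp32, which is why widening to a FloatBuffer loses no information.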
Usage
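Create sparse tensors with the static factory method createSparseTensor, and read their contents back through getValue() or the buffer accessors, for example on a sparse tensor produced by a model during inference.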
Code Reference
Source Location
```java
// File: java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java
// Package: ai.onnxruntime
```
Signature
```java
public final class OnnxSparseTensor extends OnnxTensorLike {
  // Factory method
  public static <T extends Buffer> OnnxSparseTensor createSparseTensor(
      OrtEnvironment env, SparseTensor<T> tensor) throws OrtException;

  // Accessors
  public OnnxValueType getType();
  public SparseTensor<? extends Buffer> getValue() throws OrtException;
  public SparseTensorType getSparseTensorType();
  public Buffer getIndicesBuffer();
  public LongBuffer getInnerIndicesBuffer();
  public Buffer getValuesBuffer();
  public long[] getIndicesShape();
  public long[] getInnerIndicesShape();
  public long[] getValuesShape();
  public synchronized void close();

  // Enums and inner classes
  public enum SparseTensorType { UNDEFINED, COO, CSRC, BLOCK_SPARSE }
  public abstract static class SparseTensor<T extends Buffer> { ... }
  public static final class COOTensor extends SparseTensor<LongBuffer> { ... }
  public static final class CSRCTensor extends SparseTensor<LongBuffer> { ... }
  public static final class BlockSparseTensor extends SparseTensor<IntBuffer> { ... }
}
```
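The CSRCTensor class in the signature above carries two LongBuffer index arrays, outer and inner. The plain-Java sketch below builds both for the row-compressed (CSR) case, assuming the ORT C API's CSR convention: outer indices are row offsets with rows + 1 entries (row r's non-zeros occupy positions outer[r] through outer[r + 1] - 1), and inner indices hold the column of each non-zero. DenseToCsr and its Csr holder are hypothetical names, and the sketch does not call ONNX Runtime.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper (not part of ONNX Runtime): converts a dense matrix to CSR arrays. */
final class DenseToCsr {

  /** Plain holder for the three CSR arrays. */
  static final class Csr {
    final long[] outer;   // row offsets, length rows + 1
    final long[] inner;   // column index of each non-zero, length nnz
    final float[] values; // non-zero values in row-major order, length nnz

    Csr(long[] outer, long[] inner, float[] values) {
      this.outer = outer;
      this.inner = inner;
      this.values = values;
    }
  }

  static Csr fromDense(float[][] dense) {
    long[] outer = new long[dense.length + 1]; // outer[0] stays 0
    List<Long> inner = new ArrayList<>();
    List<Float> values = new ArrayList<>();
    for (int r = 0; r < dense.length; r++) {
      for (int c = 0; c < dense[r].length; c++) {
        if (dense[r][c] != 0.0f) {
          inner.add((long) c);
          values.add(dense[r][c]);
        }
      }
      // After row r, outer[r + 1] is the running count of non-zeros
      outer[r + 1] = inner.size();
    }
    long[] innerArr = new long[inner.size()];
    float[] valueArr = new float[values.size()];
    for (int i = 0; i < innerArr.length; i++) {
      innerArr[i] = inner.get(i);
      valueArr[i] = values.get(i);
    }
    return new Csr(outer, innerArr, valueArr);
  }

  public static void main(String[] args) {
    float[][] dense = {{0f, 3f, 0f}, {0f, 0f, 7f}, {0f, 0f, 0f}};
    Csr csr = fromDense(dense);
    System.out.println(Arrays.toString(csr.outer));  // [0, 1, 2, 2]
    System.out.println(Arrays.toString(csr.inner));  // [1, 2]
    System.out.println(Arrays.toString(csr.values)); // [3.0, 7.0]
  }
}
```

An empty row shows up as two equal consecutive outer offsets, as with the last row here (outer[2] == outer[3] == 2).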
Import
```java
import ai.onnxruntime.OnnxSparseTensor;
import ai.onnxruntime.OnnxSparseTensor.SparseTensorType;
import ai.onnxruntime.OnnxSparseTensor.COOTensor;
import ai.onnxruntime.OnnxSparseTensor.CSRCTensor;
import ai.onnxruntime.OnnxSparseTensor.BlockSparseTensor;
```
I/O Contract
Inputs
| Name | Type | Description |
|---|---|---|
| env | OrtEnvironment | The ONNX Runtime environment |
| tensor | SparseTensor<T> | Java-side sparse tensor with indices, values, shapes, and sparsity type |
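A COOTensor input bundles an indices buffer, an indices shape, a values buffer, the dense shape and the non-zero count. The plain-Java sketch below gathers those pieces from a dense 2-D array using the {nnz, 2} coordinate layout. DenseToCoo and its fields are hypothetical names, and the sketch only prepares Java data; it does not call ONNX Runtime.

```java
import java.nio.FloatBuffer;
import java.nio.LongBuffer;

/** Hypothetical helper (not part of ONNX Runtime): collects COO pieces from a dense matrix. */
final class DenseToCoo {
  final LongBuffer indices;  // (row, col) pairs in row-major order
  final long[] indicesShape; // {nnz, 2}
  final FloatBuffer values;  // nnz values, same order as the indices
  final long[] denseShape;   // {rows, cols}
  final long numNonZero;

  /** Assumes a rectangular array (every row has the same length). */
  DenseToCoo(float[][] dense) {
    int rows = dense.length;
    int cols = rows == 0 ? 0 : dense[0].length;
    int nnz = 0;
    for (float[] row : dense) {
      for (float v : row) {
        if (v != 0.0f) nnz++;
      }
    }
    long[] idx = new long[2 * nnz];
    float[] vals = new float[nnz];
    int k = 0;
    for (int r = 0; r < rows; r++) {
      for (int c = 0; c < cols; c++) {
        if (dense[r][c] != 0.0f) {
          idx[2 * k] = r;
          idx[2 * k + 1] = c;
          vals[k] = dense[r][c];
          k++;
        }
      }
    }
    this.indices = LongBuffer.wrap(idx);
    this.indicesShape = new long[] {nnz, 2};
    this.values = FloatBuffer.wrap(vals);
    this.denseShape = new long[] {rows, cols};
    this.numNonZero = nnz;
  }

  public static void main(String[] args) {
    DenseToCoo coo = new DenseToCoo(new float[][] {{0f, 3f, 0f}, {0f, 0f, 7f}, {0f, 0f, 0f}});
    System.out.println(java.util.Arrays.toString(coo.indices.array())); // [0, 1, 1, 2]
    System.out.println(coo.numNonZero); // 2
  }
}
```

For this matrix the fields match the arguments passed to the COOTensor constructor in the usage example (indices {0, 1, 1, 2}, indices shape {2, 2}, dense shape {3, 3}, two non-zeros).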
Outputs
| Name | Type | Description |
|---|---|---|
| OnnxSparseTensor | OnnxSparseTensor | Wraps a native ORT sparse tensor value |
| getValue() | SparseTensor<? extends Buffer> | Returns COOTensor, CSRCTensor, or BlockSparseTensor depending on format |
| getIndicesBuffer() | Buffer | Copy of the indices (LongBuffer for COO/CSRC, IntBuffer for Block Sparse) |
| getValuesBuffer() | Buffer | Copy of the data values buffer (fp16 and bf16 values are upcast and returned as a FloatBuffer) |
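Because getIndicesBuffer() is declared to return Buffer, callers usually narrow it before reading. The plain-Java sketch below copies either a LongBuffer (COO/CSRC) or an IntBuffer (Block Sparse) into a long[] using instanceof checks, and works on a duplicate so the original buffer's position is unchanged. IndexBuffers.toLongArray is a hypothetical helper; in real code its argument would be the result of getIndicesBuffer().

```java
import java.nio.Buffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;

/** Hypothetical helper (not part of ONNX Runtime): reads an index Buffer into a long[]. */
final class IndexBuffers {

  /** Copies the remaining elements of a LongBuffer or IntBuffer into a new long[]. */
  static long[] toLongArray(Buffer indices) {
    if (indices instanceof LongBuffer) {
      LongBuffer src = ((LongBuffer) indices).duplicate(); // independent position and limit
      long[] out = new long[src.remaining()];
      src.get(out);
      return out;
    } else if (indices instanceof IntBuffer) {
      IntBuffer src = ((IntBuffer) indices).duplicate();
      long[] out = new long[src.remaining()];
      for (int i = 0; i < out.length; i++) {
        out[i] = src.get(); // widen each int index to long
      }
      return out;
    }
    throw new IllegalArgumentException("Unexpected index buffer type: " + indices.getClass());
  }

  public static void main(String[] args) {
    System.out.println(Arrays.toString(toLongArray(LongBuffer.wrap(new long[] {0, 1, 1, 2})))); // [0, 1, 1, 2]
    System.out.println(Arrays.toString(toLongArray(IntBuffer.wrap(new int[] {0, 2})))); // [0, 2]
  }
}
```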
Usage Examples
```java
import ai.onnxruntime.*;
import ai.onnxruntime.OnnxSparseTensor.*;
import java.nio.*;
import java.util.Arrays;

public class SparseTensorExample {
  public static void main(String[] args) throws OrtException {
    // Create a COO sparse tensor for a 3x3 matrix with 2 non-zero elements
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    long[] denseShape = new long[] {3, 3};
    long[] indicesShape = new long[] {2, 2}; // {nnz, rank}: one (row, col) pair per non-zero
    LongBuffer indices = LongBuffer.wrap(new long[] {0, 1, 1, 2}); // (0,1) and (1,2)
    FloatBuffer values = FloatBuffer.wrap(new float[] {3.0f, 7.0f});
    COOTensor cooTensor =
        new COOTensor(indices, indicesShape, values, denseShape, OnnxJavaType.FLOAT, 2);
    // try-with-resources releases the native sparse tensor when the block exits
    try (OnnxSparseTensor sparseTensor = OnnxSparseTensor.createSparseTensor(env, cooTensor)) {
      System.out.println("Sparse type: " + sparseTensor.getSparseTensorType());
      System.out.println("Indices shape: " + Arrays.toString(sparseTensor.getIndicesShape()));
      SparseTensor<? extends Buffer> result = sparseTensor.getValue();
      System.out.println("Non-zero elements: " + result.getNumNonZeroElements());
    }
  }
}
```
Related Pages
- OnnxTensor.java - Dense tensor counterpart
- TensorInfo.java - Tensor metadata including shape and type
- ai_onnxruntime_OnnxSparseTensor.c - JNI native implementation
- OrtEnvironment.java - Environment required for tensor creation