

Implementation:Microsoft Onnxruntime OrtTrainingSession JNI

From Leeroopedia


Knowledge Sources
Source File: java/src/main/native/ai_onnxruntime_OrtTrainingSession.c
Repository: Microsoft/onnxruntime

Domains

  • JNI Native Bridge
  • Training Session Management
  • Model Training Operations

Overview

ai_onnxruntime_OrtTrainingSession.c is the JNI native implementation file for the OrtTrainingSession Java class. It provides native methods for creating training sessions, executing training and evaluation steps, running optimizer steps, managing learning rates and schedulers, querying train/eval input/output names, exporting trained models for inference, and releasing training session resources.

Description

Key function groups:

  • Session creation: createTrainingSession creates a training session from a checkpoint, a training model path, an optional eval model path, and an optional optimizer model path. Path strings are handled per platform: on Windows they are converted to wide-character (wchar_t*) strings by the copyAndPad helper, which copies the Java string's characters into a new buffer and appends the null terminator that JNI's GetStringChars does not provide; on other platforms they are passed as UTF-8 (char*) strings.
  • Name queries: getTrainInputNames, getTrainOutputNames, getEvalInputNames, getEvalOutputNames return String[] arrays of input/output names for the training and evaluation models via the OrtTrainingApi.
  • Training step: trainStep executes a forward+backward pass. It marshals input names and OrtValue handles into C arrays, calls trainApi->TrainStep, converts output OrtValues to Java OnnxValue objects, and returns a boolean array indicating ORT memory ownership.
  • Evaluation step: evalStep executes a forward-only evaluation pass. The structure mirrors trainStep but calls trainApi->EvalStep.
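  • Gradient reset: lazyResetGrad calls trainApi->LazyResetGrad, which does not clear the accumulated gradients immediately; it schedules them to be reset to zero just before the next TrainStep computes new gradients.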
  • Optimizer: optimizerStep applies a gradient update using trainApi->OptimizerStep with optional run options.
  • Learning rate: setLearningRate and getLearningRate control the optimizer learning rate.
  • LR scheduling: registerLinearLRScheduler registers a linear warmup+decay scheduler with warmup steps, total steps, and initial learning rate. schedulerStep advances the scheduler by one step.
  • Random seed: setSeed sets ORT's global random seed for reproducibility via trainApi->SetSeed. It is a static native method, so it receives a jclass rather than a session object.
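  • Model export: exportModelForInference calls trainApi->ExportModelForInferencing, which derives an inference model from the eval model, pruned to the specified output names, and writes it to the given file path. This requires that the session was created with an eval model. The output path uses the same platform-specific string handling as session creation.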
  • Session cleanup: closeSession releases the native OrtTrainingSession via trainApi->ReleaseTrainingSession.
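The copyAndPad step described under session creation can be sketched in portable C. This is a minimal sketch, not the file's actual helper: the name copy_and_pad and the utf16_unit type (a stand-in for JNI's jchar, and for the 16-bit wchar_t on Windows) are hypothetical, and the input is assumed to be the length-delimited, unterminated buffer that a GetStringChars-style call returns.

```c
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical stand-in for JNI's jchar: one UTF-16 code unit. */
typedef uint16_t utf16_unit;

/*
 * Copy `len` code units from an unterminated buffer into a new buffer that is
 * one element longer, so the result is null-terminated.
 * Returns NULL if allocation fails; the caller must free() the result.
 */
static utf16_unit *copy_and_pad(const utf16_unit *chars, size_t len) {
    /* calloc zero-fills, so the extra slot already holds the terminator. */
    utf16_unit *out = calloc(len + 1, sizeof(utf16_unit));
    if (out == NULL) {
        return NULL;
    }
    if (len > 0) {
        memcpy(out, chars, len * sizeof(utf16_unit));
    }
    out[len] = 0; /* explicit terminator, for readability */
    return out;
}
```

On Windows, where wchar_t is 16 bits wide, a buffer built this way can be passed to wide-character file APIs; the caller still owns it and must free it once the native call returns.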
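The linear scheduler registered by registerLinearLRScheduler can be illustrated by a small function that maps a step count to a learning rate. This sketch assumes the common warmup-then-decay shape: ramp linearly from 0 to the initial rate over warmupSteps, then decay linearly to 0 at totalSteps. The exact formula lives in ORT's training library and is not shown on this page, and the function names below are hypothetical.

```c
#include <stdint.h>

/*
 * Multiplicative factor for a linear warmup + linear decay schedule
 * (assumed shape, for illustration only):
 *   step < warmup_steps : factor = step / warmup_steps
 *   otherwise           : factor = (total_steps - step) / (total_steps - warmup_steps),
 *                         clamped at 0 once total_steps is reached
 * Denominators are clamped to at least 1 to avoid division by zero.
 */
static float linear_lr_factor(int64_t step, int64_t warmup_steps, int64_t total_steps) {
    if (step < warmup_steps) {
        int64_t denom = warmup_steps > 1 ? warmup_steps : 1;
        return (float)step / (float)denom;
    }
    int64_t decay_span = total_steps - warmup_steps;
    if (decay_span < 1) {
        decay_span = 1;
    }
    float factor = (float)(total_steps - step) / (float)decay_span;
    return factor > 0.0f ? factor : 0.0f;
}

/* Effective learning rate at a given step. */
static float linear_lr(float initial_lr, int64_t step, int64_t warmup_steps, int64_t total_steps) {
    return initial_lr * linear_lr_factor(step, warmup_steps, total_steps);
}
```

With initial_lr = 0.1, warmup_steps = 10 and total_steps = 110, this gives 0.05 at step 5, the full 0.1 at step 10, and 0.05 again at step 60. Under this model, each schedulerStep call would advance the step by one.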

Both trainStep and evalStep use goto-based cleanup patterns to ensure all allocated arrays (input names, output names, Java strings, value pointers) are properly freed on both success and failure paths.
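A minimal sketch of that goto-based cleanup pattern, combined with the handle marshalling described for trainStep. Everything here is hypothetical: FakeValue stands in for OrtValue, fake_step for TrainStep/EvalStep, and plain C arrays for the JNI arrays. The real code additionally releases Java strings and checks OrtStatus results.

```c
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

/* Hypothetical stand-in for an opaque native value (an OrtValue in the real code). */
typedef struct FakeValue {
    int id;
} FakeValue;

/* Hypothetical stand-in for the native step call (TrainStep/EvalStep in the
 * real code). Returns 0 on success and nonzero on failure. */
static int fake_step(const char *const *names, const FakeValue *const *values,
                     size_t count, int *out_sum) {
    int sum = 0;
    for (size_t i = 0; i < count; i++) {
        if (names[i] == NULL || values[i] == NULL) {
            return 1; /* reject a missing name or a null handle */
        }
        sum += values[i]->id;
    }
    *out_sum = sum;
    return 0;
}

/* Copy caller-supplied names and 64-bit handles (as a Java long[] of native
 * pointers would arrive) into C arrays, run the step, and release every
 * temporary through one cleanup label on both success and failure paths. */
static int run_step(const char *const *names_in, const int64_t *handles,
                    size_t count, int *out_sum) {
    int status = -1; /* pessimistic default: any early exit reports failure */
    const char **names = NULL;
    const FakeValue **values = NULL;
    size_t slots = count > 0 ? count : 1; /* avoid calloc(0, ...) */

    names = calloc(slots, sizeof(*names));
    if (names == NULL) {
        goto cleanup;
    }
    values = calloc(slots, sizeof(*values));
    if (values == NULL) {
        goto cleanup;
    }
    for (size_t i = 0; i < count; i++) {
        names[i] = names_in[i];
        values[i] = (const FakeValue *)(intptr_t)handles[i];
    }
    status = fake_step(names, values, count, out_sum);

cleanup:
    /* free(NULL) is a no-op, so every path can share this block. */
    free(values);
    free(names);
    return status;
}
```

Because every exit funnels through one label, adding another temporary (for example a Java string that must be released) means adding one line at the cleanup label instead of a free on each error branch.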

Code Reference

Source Location

// File: java/src/main/native/ai_onnxruntime_OrtTrainingSession.c

Signature

// Session lifecycle
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtTrainingSession_createTrainingSession
    (JNIEnv*, jclass, jlong apiHandle, jlong trainApiHandle,
     jlong envHandle, jlong optionsHandle, jlong checkpointHandle,
     jstring trainPath, jstring evalPath, jstring optimizerPath);
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_closeSession
    (JNIEnv*, jobject, jlong trainHandle, jlong nativeHandle);

// Name queries
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_getTrainInputNames
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jlong allocatorHandle);
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_getTrainOutputNames
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jlong allocatorHandle);
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_getEvalInputNames
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jlong allocatorHandle);
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_getEvalOutputNames
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jlong allocatorHandle);

// Training and evaluation
JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle,
     jlong nativeHandle, jlong allocatorHandle,
     jobjectArray inputNamesArr, jlongArray inputHandles, jlong numInputs,
     jobjectArray outputNamesArr, jlong numOutputs,
     jobjectArray outputValuesArr, jlongArray outputHandlesArr, jlong runOptionsHandle);
JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle,
     jlong nativeHandle, jlong allocatorHandle,
     jobjectArray inputNamesArr, jlongArray inputHandles, jlong numInputs,
     jobjectArray outputNamesArr, jlong numOutputs,
     jobjectArray outputValuesArr, jlongArray outputHandlesArr, jlong runOptionsHandle);

// Gradient and optimizer
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_lazyResetGrad
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle);
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_optimizerStep
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jlong runOptionsHandle);

// Learning rate
JNIEXPORT void   JNICALL Java_ai_onnxruntime_OrtTrainingSession_setLearningRate
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jfloat learningRate);
JNIEXPORT jfloat JNICALL Java_ai_onnxruntime_OrtTrainingSession_getLearningRate
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle);

// LR scheduling
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_registerLinearLRScheduler
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle,
     jlong warmupSteps, jlong totalSteps, jfloat initialLR);
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_schedulerStep
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle);

// Seed and export
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_setSeed
    (JNIEnv*, jclass, jlong apiHandle, jlong trainApiHandle, jlong seed);
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_exportModelForInference
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle,
     jstring outputPath, jlong numOutputs, jobjectArray outputNamesArr);

Import

#include "OrtJniUtil.h"
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "onnxruntime_training_c_api.h"
#include "ai_onnxruntime_OrtTrainingSession.h"

I/O Contract

Inputs

Name              Type          Description
apiHandle         jlong         Pointer to the ORT C API struct (OrtApi)
trainApiHandle    jlong         Pointer to the ORT Training API struct (OrtTrainingApi)
envHandle         jlong         Pointer to the OrtEnv
optionsHandle     jlong         Pointer to OrtSessionOptions
checkpointHandle  jlong         Pointer to OrtCheckpointState
nativeHandle      jlong         Pointer to the native OrtTrainingSession
allocatorHandle   jlong         Pointer to the ORT allocator
runOptionsHandle  jlong         Pointer to OrtRunOptions (optional)
trainPath         jstring       File path to the training model
evalPath          jstring       File path to the evaluation model (nullable)
optimizerPath     jstring       File path to the optimizer model (nullable)
inputNamesArr     jobjectArray  String array of input names
inputHandles      jlongArray    Array of native OrtValue handles (inputs)
numInputs         jlong         Number of inputs
outputNamesArr    jobjectArray  String array of requested output names
numOutputs        jlong         Number of entries in outputNamesArr
learningRate      jfloat        Learning rate value
initialLR         jfloat        Initial learning rate for the linear LR scheduler
warmupSteps       jlong         Number of warmup steps for the linear LR scheduler
totalSteps        jlong         Total number of training steps for the linear LR scheduler
seed              jlong         Random seed for reproducibility
outputPath        jstring       File path for the exported inference model

Outputs

Name           Type           Description
sessionHandle  jlong          Pointer to the newly created OrtTrainingSession (createTrainingSession)
String[]       jobjectArray   Train/eval input/output name arrays (name queries)
boolean[]      jbooleanArray  Per-output flag indicating ORT memory ownership (trainStep, evalStep)
learningRate   jfloat         Current learning rate value (getLearningRate)

Usage Examples

// Illustrative only: these JNI entry points are normally invoked by the JVM through
// the native methods of ai.onnxruntime.OrtTrainingSession, not called directly from C.
// Errors are reported by throwing a Java exception (OrtException), so a C caller
// would need to check (*jniEnv)->ExceptionCheck(jniEnv) after every call.

// Create a training session (evalModelPath and optimizerModelPath may be NULL)
jlong trainSession = Java_ai_onnxruntime_OrtTrainingSession_createTrainingSession(
    jniEnv, cls, apiHandle, trainApiHandle,
    envHandle, optsHandle, checkpointHandle,
    trainModelPath, evalModelPath, optimizerModelPath);

// Execute a training step (forward and backward pass)
jbooleanArray ownership = Java_ai_onnxruntime_OrtTrainingSession_trainStep(
    jniEnv, sessionObj, apiHandle, trainApiHandle,
    trainSession, allocHandle,
    inputNames, inputHandles, numInputs,
    outputNames, numOutputs, outputValues, outputHandles, runOptsHandle);

// Apply the optimizer update, then lazily reset the gradients
// (the reset takes effect before the next trainStep computes gradients)
Java_ai_onnxruntime_OrtTrainingSession_optimizerStep(
    jniEnv, sessionObj, apiHandle, trainApiHandle, trainSession, runOptsHandle);
Java_ai_onnxruntime_OrtTrainingSession_lazyResetGrad(
    jniEnv, sessionObj, apiHandle, trainApiHandle, trainSession);

// Export an inference model (requires that the session was created with an eval model)
Java_ai_onnxruntime_OrtTrainingSession_exportModelForInference(
    jniEnv, sessionObj, apiHandle, trainApiHandle, trainSession,
    outputPathStr, numOutputs, outputNamesArr);

// Close the training session
Java_ai_onnxruntime_OrtTrainingSession_closeSession(
    jniEnv, sessionObj, trainApiHandle, trainSession);
