Implementation:Microsoft Onnxruntime OrtTrainingSession JNI
| Knowledge Sources | Description |
|---|---|
| Source File | java/src/main/native/ai_onnxruntime_OrtTrainingSession.c |
| Repository | Microsoft/onnxruntime |
Domains
- JNI Native Bridge
- Training Session Management
- Model Training Operations
Overview
ai_onnxruntime_OrtTrainingSession.c is the JNI native implementation file for the OrtTrainingSession Java class. It provides native methods for creating training sessions, executing training and evaluation steps, running optimizer steps, managing learning rates and schedulers, querying train/eval input/output names, exporting trained models for inference, and releasing training session resources.
Description
Key function groups:
- Session creation: `createTrainingSession` creates a training session from a checkpoint, a training model path, an optional evaluation model path, and an optional optimizer model path. Paths are handled per platform: wide-character strings (`wchar_t*`) produced by `copyAndPad` on Win32, UTF-8 strings on Unix. The `copyAndPad` helper copies a Java string's characters into a new null-terminated buffer for the Windows APIs, since the characters JNI hands out are not guaranteed to be null-terminated.
- Name queries: `getTrainInputNames`, `getTrainOutputNames`, `getEvalInputNames`, and `getEvalOutputNames` return `String[]` arrays holding the input and output names of the training and evaluation models, obtained through the OrtTrainingApi.
- Training step: `trainStep` executes a forward and backward pass. It marshals the input names and OrtValue handles into C arrays, calls `trainApi->TrainStep`, converts the output OrtValues to Java `OnnxValue` objects, and returns a boolean array recording ORT memory ownership for each output.
- Evaluation step: `evalStep` executes a forward-only evaluation pass. Its structure mirrors `trainStep`, but it calls `trainApi->EvalStep`.
- Gradient reset: `lazyResetGrad` lazily resets the accumulated gradients via `trainApi->LazyResetGrad`.
- Optimizer: `optimizerStep` applies a gradient update via `trainApi->OptimizerStep`, with optional run options.
- Learning rate: `setLearningRate` and `getLearningRate` set and read the optimizer learning rate.
- LR scheduling: `registerLinearLRScheduler` registers a linear warmup-then-decay scheduler from the number of warmup steps, the total number of steps, and the initial learning rate. `schedulerStep` advances the scheduler by one step.
- Random seed: `setSeed` sets the global random seed for reproducibility. It is a static native method, so it receives the class handle (`jclass`) rather than an instance.
- Model export: `exportModelForInference` exports the training model as an inference model with the specified output names, writing it to a file path. The output path uses the same platform-specific string handling as session creation.
- Session cleanup: `closeSession` releases the native OrtTrainingSession via `trainApi->ReleaseTrainingSession`.
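On Win32 the model paths must reach ORT as null-terminated wide strings. The sketch below isolates the copy-and-terminate step, assuming the Java string's characters have already been fetched as UTF-16 code units. `copy_and_pad` and the `jchar_t` typedef are hypothetical stand-ins, not the file's actual helper, and the JNI calls that obtain and release the characters are omitted.

```c
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <wchar.h>

/* Hypothetical stand-in for JNI's jchar: one unsigned 16-bit UTF-16 code unit. */
typedef uint16_t jchar_t;

/*
 * Copies `len` code units into a new wchar_t buffer with one extra slot for
 * the terminator, because the characters JNI hands out are not guaranteed to
 * be null-terminated. On Windows wchar_t is also a UTF-16 code unit, so this
 * is a plain copy; on other platforms it only widens each code unit.
 * Returns NULL on allocation failure; the caller frees the result.
 */
static wchar_t *copy_and_pad(const jchar_t *chars, size_t len) {
    assert(chars != NULL || len == 0);
    wchar_t *out = malloc((len + 1) * sizeof(wchar_t));
    if (out == NULL) {
        return NULL;
    }
    for (size_t i = 0; i < len; i++) {
        out[i] = (wchar_t) chars[i];
    }
    out[len] = L'\0';  /* the "pad": an explicit terminator */
    return out;
}
```

In the JNI code such a buffer would be passed as the path argument and freed once the ORT call returns.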
Both trainStep and evalStep use goto-based cleanup patterns to ensure all allocated arrays (input names, output names, Java strings, value pointers) are properly freed on both success and failure paths.
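The marshal, call, and clean-up shape of trainStep and evalStep can be sketched without JNI or ORT. Everything named here is hypothetical: `FakeOrtValue` stands in for the opaque OrtValue, `fake_train_step` for `trainApi->TrainStep`, and `dup_str` for acquiring a Java string's characters. The sketch uses a single-label variant of the goto pattern: handles arrive as 64-bit integers (as jlong values do) and are cast back to pointers, and the `cleanup` label releases only what was acquired, on success and on every failure path.

```c
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical stand-in for the opaque OrtValue type. */
typedef struct { float data; } FakeOrtValue;

/* Hypothetical placeholder for trainApi->TrainStep: sums the inputs into *loss. */
static int fake_train_step(const char *const *names, const FakeOrtValue *const *values,
                           size_t n, float *loss) {
    float total = 0.0f;
    for (size_t i = 0; i < n; i++) {
        if (names[i] == NULL || values[i] == NULL) {
            return 1;  /* nonzero plays the role of an error OrtStatus */
        }
        total += values[i]->data;
    }
    *loss = total;
    return 0;
}

/* Copies a C string; stands in for acquiring a Java string's characters. */
static char *dup_str(const char *s) {
    size_t n = strlen(s) + 1;
    char *copy = malloc(n);
    if (copy != NULL) {
        memcpy(copy, s, n);
    }
    return copy;
}

/* Marshal inputs, run the step, and free everything through one cleanup label. */
static int run_step(const char *const *javaNames, const int64_t *handles,
                    size_t numInputs, float *loss) {
    assert(numInputs > 0);
    int rc = -1;                          /* -1 marks an allocation failure */
    size_t copied = 0;                    /* names duplicated so far */
    const FakeOrtValue **values = NULL;
    char **names = calloc(numInputs, sizeof(char *));
    if (names == NULL) {
        goto cleanup;
    }
    values = calloc(numInputs, sizeof(*values));
    if (values == NULL) {
        goto cleanup;
    }
    for (; copied < numInputs; copied++) {
        names[copied] = dup_str(javaNames[copied]);
        if (names[copied] == NULL) {
            goto cleanup;
        }
        /* Handles travel as 64-bit integers (jlong); cast back to pointers. */
        values[copied] = (const FakeOrtValue *) (intptr_t) handles[copied];
    }
    rc = fake_train_step((const char *const *) names, values, numInputs, loss);

cleanup:
    /* Reached on success and on every failure path. */
    if (names != NULL) {
        for (size_t i = 0; i < copied; i++) {
            free(names[i]);
        }
        free(names);
    }
    free(values);
    return rc;
}
```

Because every pointer starts out NULL and `copied` tracks progress, a failure at any point jumps to the same label without leaking or double-freeing.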
Code Reference
Source Location
```c
// File: java/src/main/native/ai_onnxruntime_OrtTrainingSession.c
```
Signature
```c
// Session lifecycle
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtTrainingSession_createTrainingSession
    (JNIEnv*, jclass, jlong apiHandle, jlong trainApiHandle,
     jlong envHandle, jlong optionsHandle, jlong checkpointHandle,
     jstring trainPath, jstring evalPath, jstring optimizerPath);
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_closeSession
    (JNIEnv*, jobject, jlong trainHandle, jlong nativeHandle);

// Name queries
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_getTrainInputNames(JNIEnv*, jobject, jlong, jlong, jlong, jlong);
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_getTrainOutputNames(JNIEnv*, jobject, jlong, jlong, jlong, jlong);
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_getEvalInputNames(JNIEnv*, jobject, jlong, jlong, jlong, jlong);
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_getEvalOutputNames(JNIEnv*, jobject, jlong, jlong, jlong, jlong);

// Training and evaluation
JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle,
     jlong nativeHandle, jlong allocatorHandle,
     jobjectArray inputNamesArr, jlongArray inputHandles, jlong numInputs,
     jobjectArray outputNamesArr, jlong numOutputs,
     jobjectArray outputValuesArr, jlongArray outputHandlesArr, jlong runOptionsHandle);
JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep
    (JNIEnv*, jobject, jlong apiHandle, jlong trainApiHandle,
     jlong nativeHandle, jlong allocatorHandle,
     jobjectArray inputNamesArr, jlongArray inputHandles, jlong numInputs,
     jobjectArray outputNamesArr, jlong numOutputs,
     jobjectArray outputValuesArr, jlongArray outputHandlesArr, jlong runOptionsHandle);

// Gradient and optimizer
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_lazyResetGrad(JNIEnv*, jobject, jlong, jlong, jlong);
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_optimizerStep(JNIEnv*, jobject, jlong, jlong, jlong, jlong);

// Learning rate
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_setLearningRate(JNIEnv*, jobject, jlong, jlong, jlong, jfloat);
JNIEXPORT jfloat JNICALL Java_ai_onnxruntime_OrtTrainingSession_getLearningRate(JNIEnv*, jobject, jlong, jlong, jlong);

// LR scheduling
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_registerLinearLRScheduler
    (JNIEnv*, jobject, jlong, jlong, jlong, jlong warmupSteps, jlong totalSteps, jfloat initialLR);
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_schedulerStep(JNIEnv*, jobject, jlong, jlong, jlong);

// Seed and export
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_setSeed(JNIEnv*, jclass, jlong, jlong, jlong);
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_exportModelForInference
    (JNIEnv*, jobject, jlong, jlong, jlong, jstring outputPath, jlong numOutputs, jobjectArray outputNamesArr);
```
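The four name-query natives each take an allocator handle, which suggests the usual count-then-index flow: one count query, then one name query per index, with each name allocated through that allocator, converted to a Java string, and freed. The sketch below reproduces only that ownership flow in plain C. The getters `fake_input_count` and `fake_input_name` are hypothetical mocks of the training API, and a heap copy stands in for creating the Java string.

```c
#include <assert.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical input names standing in for a loaded training model. */
static const char *const kFakeInputNames[] = { "input", "labels" };

/* Hypothetical count getter; returns 0 as a success status. */
static int fake_input_count(size_t *out) {
    *out = sizeof kFakeInputNames / sizeof kFakeInputNames[0];
    return 0;
}

/* Hypothetical by-index getter: hands back a heap copy the caller must free,
 * as a name allocated through an allocator would be. */
static int fake_input_name(size_t index, char **out) {
    assert(index < sizeof kFakeInputNames / sizeof kFakeInputNames[0]);
    size_t n = strlen(kFakeInputNames[index]) + 1;
    *out = malloc(n);
    if (*out == NULL) {
        return 1;
    }
    memcpy(*out, kFakeInputNames[index], n);
    return 0;
}

/* Count-then-index loop: each allocator-owned name is copied into the result
 * (where JNI would create a Java string) and released right away.
 * Returns NULL on failure; otherwise the caller frees each entry and the array. */
static char **collect_input_names(size_t *count) {
    if (fake_input_count(count) != 0) {
        return NULL;
    }
    char **result = calloc(*count ? *count : 1, sizeof(char *));
    if (result == NULL) {
        return NULL;
    }
    for (size_t i = 0; i < *count; i++) {
        char *allocated = NULL;
        if (fake_input_name(i, &allocated) != 0) {
            goto fail;
        }
        size_t n = strlen(allocated) + 1;
        result[i] = malloc(n);              /* the "Java string" copy */
        if (result[i] != NULL) {
            memcpy(result[i], allocated, n);
        }
        free(allocated);                    /* release the allocator's copy */
        if (result[i] == NULL) {
            goto fail;
        }
    }
    return result;

fail:
    /* calloc zeroed the unused slots, so freeing every slot is safe. */
    for (size_t i = 0; i < *count; i++) {
        free(result[i]);
    }
    free(result);
    return NULL;
}
```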
Import
```c
#include "OrtJniUtil.h"
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "onnxruntime_training_c_api.h"
#include "ai_onnxruntime_OrtTrainingSession.h"
```
I/O Contract
Inputs
| Name | Type | Description |
|---|---|---|
| apiHandle | jlong | Pointer to the ORT C API struct |
| trainApiHandle | jlong | Pointer to the ORT Training API struct |
| envHandle | jlong | Pointer to the OrtEnv |
| optionsHandle | jlong | Pointer to OrtSessionOptions |
| checkpointHandle | jlong | Pointer to OrtCheckpointState |
| nativeHandle | jlong | Pointer to the native OrtTrainingSession |
| allocatorHandle | jlong | Pointer to the ORT allocator |
| trainPath | jstring | File path to the training model |
| evalPath | jstring | File path to the evaluation model (nullable) |
| optimizerPath | jstring | File path to the optimizer model (nullable) |
| inputNamesArr | jobjectArray | String array of input names |
| inputHandles | jlongArray | Array of native OrtValue handles (inputs) |
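| outputNamesArr | jobjectArray | String array of requested output names |
| numInputs | jlong | Number of entries in inputNamesArr and inputHandles |
| numOutputs | jlong | Number of requested outputs (for exportModelForInference, the number of output names) |
| outputValuesArr | jobjectArray | OnnxValue array that receives the step outputs |
| outputHandlesArr | jlongArray | Array of native OrtValue handles for the outputs |
| runOptionsHandle | jlong | Pointer to OrtRunOptions (optional) |
| learningRate | jfloat | Learning rate value |
| warmupSteps | jlong | Number of warmup steps for linear LR scheduler |
| totalSteps | jlong | Total number of training steps for linear LR scheduler |
| initialLR | jfloat | Initial learning rate for linear LR scheduler |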
| seed | jlong | Random seed for reproducibility |
| outputPath | jstring | File path for exported inference model |
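A linear warmup-then-decay schedule scales initialLR by a multiplier that ramps from 0 to 1 over warmupSteps and then falls linearly to 0 at totalSteps. The sketch below assumes this common form of the schedule; the function names are hypothetical, and ORT's exact formula (for example, its behaviour at the boundaries) should be checked against its own sources.

```c
#include <assert.h>
#include <stdint.h>

/* Multiplier for a linear warmup-then-decay schedule (assumed common form). */
static float linear_lr_multiplier(int64_t step, int64_t warmupSteps, int64_t totalSteps) {
    assert(step >= 0 && warmupSteps >= 0 && totalSteps >= warmupSteps);
    if (step < warmupSteps) {
        /* Warmup: here warmupSteps >= 1, so the division is safe. */
        return (float) step / (float) warmupSteps;
    }
    /* Decay: guard against a zero-length decay phase, and never go negative. */
    int64_t decaySpan = totalSteps - warmupSteps;
    float m = (float) (totalSteps - step) / (float) (decaySpan > 1 ? decaySpan : 1);
    return m > 0.0f ? m : 0.0f;
}

/* Learning rate after `step` scheduler steps, given the registered initial LR. */
static float linear_lr(float initialLR, int64_t step, int64_t warmupSteps, int64_t totalSteps) {
    return initialLR * linear_lr_multiplier(step, warmupSteps, totalSteps);
}
```

With warmupSteps = 10 and totalSteps = 110, the multiplier is 0.5 at step 5, peaks at 1.0 at step 10, returns to 0.5 at step 60, and stays at 0 from step 110 onward.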
Outputs
| Name | Type | Description |
|---|---|---|
| sessionHandle | jlong | Pointer to the newly created OrtTrainingSession |
| String[] | jobjectArray | Train/eval input/output name arrays |
| boolean[] | jbooleanArray | Per-output flag indicating ORT memory ownership |
| learningRate | jfloat | Current learning rate value |
Usage Examples
```c
// These natives are normally invoked by the JVM from OrtTrainingSession.java;
// the direct calls below only illustrate argument order.

// Create a training session
jlong trainSession = Java_ai_onnxruntime_OrtTrainingSession_createTrainingSession(
    jniEnv, cls, apiHandle, trainApiHandle,
    envHandle, optsHandle, checkpointHandle,
    trainModelPath, evalModelPath, optimizerModelPath);

// Execute a training step (forward + backward pass)
jbooleanArray ownership = Java_ai_onnxruntime_OrtTrainingSession_trainStep(
    jniEnv, sessionObj, apiHandle, trainApiHandle,
    trainSession, allocHandle,
    inputNames, inputHandles, numInputs,
    outputNames, numOutputs, outputValues, outputHandles, runOptsHandle);

// Apply the optimizer update, then schedule a lazy gradient reset
Java_ai_onnxruntime_OrtTrainingSession_optimizerStep(
    jniEnv, sessionObj, apiHandle, trainApiHandle, trainSession, runOptsHandle);
Java_ai_onnxruntime_OrtTrainingSession_lazyResetGrad(
    jniEnv, sessionObj, apiHandle, trainApiHandle, trainSession);

// Export the model for inference
Java_ai_onnxruntime_OrtTrainingSession_exportModelForInference(
    jniEnv, sessionObj, apiHandle, trainApiHandle, trainSession,
    outputPathStr, numOutputs, outputNamesArr);

// Close the training session
Java_ai_onnxruntime_OrtTrainingSession_closeSession(
    jniEnv, sessionObj, trainApiHandle, trainSession);
```
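LazyResetGrad only schedules the reset: under the reading assumed here, the gradients are zeroed at the start of the next training step. Gradients accumulated by trainStep therefore remain available to a following optimizerStep. The toy one-parameter model below illustrates that reading. `ToySession` and its functions are hypothetical and say nothing about how ORT actually stores gradients.

```c
#include <assert.h>
#include <stdbool.h>

/* Toy one-parameter "session" (hypothetical), used only to show the semantics. */
typedef struct {
    float weight;
    float grad;         /* accumulated gradient */
    bool resetPending;  /* set by lazy_reset_grad, consumed by train_step */
} ToySession;

/* Forward + backward for loss = 0.5 * (w*x - y)^2. Gradients accumulate
 * across calls unless a reset was scheduled, in which case they are zeroed first. */
static void train_step(ToySession *s, float x, float y) {
    if (s->resetPending) {
        s->grad = 0.0f;
        s->resetPending = false;
    }
    s->grad += (s->weight * x - y) * x;  /* dLoss/dw */
}

/* Plain SGD update using the accumulated gradient. */
static void optimizer_step(ToySession *s, float lr) {
    s->weight -= lr * s->grad;
}

/* Schedules, rather than performs, the gradient reset. */
static void lazy_reset_grad(ToySession *s) {
    s->resetPending = true;
}
```

In this model, calling `lazy_reset_grad` before or after `optimizer_step` within one iteration yields the same update, because the zeroing happens only at the next `train_step`.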
Related Pages
- OrtTrainingSession.java - Java-side training session class calling these natives
- ai_onnxruntime_OrtSession.c - Inference session JNI (similar run pattern)
- OrtJniUtil.c - Shared utility functions (checkOrtStatus, convertOrtValueToONNXValue, allocarray)
- ai_onnxruntime_OrtSession_SessionOptions.c - Session options passed to createTrainingSession