Workflow:Microsoft Onnxruntime On Device Training
| Knowledge Sources | |
|---|---|
| Domains | On_Device_Training, Edge_ML, Model_Fine_Tuning |
| Last Updated | 2026-02-10 04:30 GMT |
Overview
End-to-end process for generating training artifacts from a PyTorch model, training on edge devices using ONNX Runtime's Training API, and exporting the trained model for inference.
Description
This workflow enables on-device model training and fine-tuning using ONNX Runtime's Training API. It has two phases. The offline phase (Steps 1 and 2), typically run on a development machine where PyTorch is available, exports a PyTorch model to ONNX format and generates the required training artifacts (training model with gradient graph, evaluation model, optimizer model, and initial checkpoint). The on-device phase (Steps 3 to 7) needs only ONNX Runtime with the Training API and those artifacts: the training loop uses the CheckpointState, Module, and Optimizer classes to perform forward/backward passes, update weights, and manage learning rate scheduling. The workflow supports checkpointing for training resumption and exports the final trained model for inference deployment.
Usage
Execute this workflow when you need to train or fine-tune a model directly on edge or mobile devices without requiring a full PyTorch/TensorFlow installation. This is ideal for personalization scenarios (adapting a model to user-specific data on-device), privacy-preserving training (data never leaves the device), or resource-constrained environments where the full training framework overhead is prohibitive.
Execution Steps
Step 1: Export Base Model to ONNX
Export the PyTorch model to ONNX format using torch.onnx.export. This creates the forward-only ONNX model that serves as the foundation for artifact generation. The exported model must include all operations needed for the forward pass.
Key considerations:
- Export with dynamic axes for batch dimension flexibility
- Ensure every operator in the model, including custom ones, can be represented in ONNX at the chosen opset
- The exported model should not include loss computation
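A minimal sketch of this export step, assuming a model with one input tensor and one output tensor. The tensor names `input` and `output`, the axis label `batch_size`, the opset version, and both helper functions are illustrative choices, not part of the workflow itself. `torch` is imported inside the function so the axis helper can be used on its own.

```python
def batch_dynamic_axes(input_names, output_names, axis_name="batch_size"):
    """Mark dimension 0 of every named tensor as a dynamic batch axis."""
    return {name: {0: axis_name} for name in list(input_names) + list(output_names)}


def export_forward_model(model, sample_input, onnx_path,
                         input_names=("input",), output_names=("output",)):
    """Export only the forward pass (no loss) of a PyTorch model to ONNX."""
    import torch  # lazy import: only needed on the machine that performs the export

    torch.onnx.export(
        model,
        sample_input,
        onnx_path,
        input_names=list(input_names),
        output_names=list(output_names),
        dynamic_axes=batch_dynamic_axes(input_names, output_names),
        export_params=True,          # keep the weights in the file as initializers
        do_constant_folding=False,   # assumption: avoid folding trainable weights away
        training=torch.onnx.TrainingMode.TRAIN,
        opset_version=15,            # assumption: pick an opset your ONNX Runtime build supports
    )
    return onnx_path
```

Keeping the loss out of the exported graph matters because the loss is attached later, during artifact generation.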
Step 2: Generate Training Artifacts
Use the artifacts.generate_artifacts utility to create all files required for on-device training from the forward-only ONNX model. This generates four artifacts: the training model (forward + backward graph), the evaluation model (forward only, optimized), the optimizer model (weight update graph), and the initial checkpoint (parameter initialization).
Key considerations:
- Specify which parameters require gradients and which are frozen
- Select the loss function type (CrossEntropyLoss, MSELoss, BCEWithLogitsLoss, L1Loss, or custom)
- Choose the optimizer algorithm (AdamW or SGD)
- Custom loss functions extend the onnxblock.Block interface
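The artifact generation step can be sketched as follows. This is a sketch under stated assumptions: parameters are split into trainable and frozen by a name prefix (the `split_parameters` helper and the prefix convention are hypothetical), every graph initializer is treated as a parameter, and the loss and optimizer choices are examples only.

```python
import os


def split_parameters(param_names, trainable_prefixes):
    """Partition parameter names into (requires_grad, frozen) by name prefix."""
    prefixes = tuple(trainable_prefixes)
    requires_grad = [n for n in param_names if n.startswith(prefixes)]
    frozen = [n for n in param_names if not n.startswith(prefixes)]
    return requires_grad, frozen


def build_artifacts(onnx_path, trainable_prefixes, artifact_dir="artifacts"):
    """Generate the training, eval, and optimizer models plus the initial checkpoint."""
    import onnx  # lazy imports: only needed in the offline phase
    from onnxruntime.training import artifacts

    model = onnx.load(onnx_path)
    # Assumption: every initializer in the graph is a model parameter.
    names = [init.name for init in model.graph.initializer]
    requires_grad, frozen = split_parameters(names, trainable_prefixes)

    os.makedirs(artifact_dir, exist_ok=True)
    artifacts.generate_artifacts(
        model,
        requires_grad=requires_grad,
        frozen_params=frozen,
        loss=artifacts.LossType.CrossEntropyLoss,   # example choice
        optimizer=artifacts.OptimType.AdamW,        # example choice
        artifact_directory=artifact_dir,
    )
    return requires_grad, frozen
```

Freezing everything except a final layer (for example `trainable_prefixes=["fc2."]`) is a common way to keep the on-device gradient graph and optimizer state small.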
Step 3: Load Checkpoint State
Initialize a CheckpointState from the generated checkpoint file. This state object holds all trainable and non-trainable parameters, optimizer momentum states, and any custom properties (epoch counter, best score, etc.). For training resumption, load a previously saved checkpoint instead.
Key considerations:
- Checkpoint files use the flatbuffer format for efficient serialization
- Large checkpoints can store tensor data in external data files to stay under the 2 GB size limit of a single flatbuffer
- Custom properties can store arbitrary training metadata
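A short sketch of loading state with optional resumption. The `pick_checkpoint` and `load_state` helpers and the `epoch` property name are hypothetical; `CheckpointState.load_checkpoint` and the `properties` mapping come from the `onnxruntime.training.api` Python module.

```python
import os


def pick_checkpoint(initial_path, resume_path=None):
    """Prefer a previously saved checkpoint when it exists, else the initial one."""
    if resume_path is not None and os.path.exists(resume_path):
        return resume_path
    return initial_path


def load_state(initial_path, resume_path=None):
    """Load a CheckpointState and make sure an epoch counter is present."""
    from onnxruntime.training.api import CheckpointState  # lazy import

    state = CheckpointState.load_checkpoint(pick_checkpoint(initial_path, resume_path))
    if "epoch" not in state.properties:   # "epoch" is an illustrative property name
        state.properties["epoch"] = 0
    return state
```

Storing the epoch counter in the checkpoint itself means a resumed run can pick up where the previous one stopped without a separate metadata file.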
Step 4: Create Training Components
Instantiate the Module (loads training and evaluation ONNX models), the Optimizer (initializes optimizer states from the optimizer ONNX model), and optionally a learning rate scheduler (e.g., LinearLRScheduler with warmup). In the C, C++, C#, and Java bindings these components are coordinated by a TrainingSession wrapper; the Python API exposes Module and Optimizer directly.
Key considerations:
- Module manages separate sessions for training and evaluation graphs
- Optimizer supports AdamW (two momentum tensors) and SGD (single momentum)
- LinearLRScheduler provides warmup followed by linear decay
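The components above can be created as in the sketch below. The artifact file names are assumed to be the defaults written by the artifact generation utility, and `create_components` is a hypothetical helper. `linear_warmup_decay` is a plain-Python illustration of the warmup-then-linear-decay shape, not ONNX Runtime's own implementation.

```python
import os


def linear_warmup_decay(step, warmup_steps, total_steps, initial_lr):
    """Illustrative schedule: ramp up linearly, then decay linearly to zero."""
    if step < warmup_steps:
        factor = step / max(1, warmup_steps)
    else:
        factor = max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
    return initial_lr * factor


def create_components(state, artifact_dir, warmup_steps, total_steps, initial_lr,
                      device="cpu"):
    """Build the Module, Optimizer, and LinearLRScheduler from generated artifacts."""
    from onnxruntime.training.api import LinearLRScheduler, Module, Optimizer

    # Assumption: default artifact file names, no custom prefix.
    module = Module(
        os.path.join(artifact_dir, "training_model.onnx"),
        state,
        os.path.join(artifact_dir, "eval_model.onnx"),
        device=device,
    )
    optimizer = Optimizer(os.path.join(artifact_dir, "optimizer_model.onnx"), module)
    scheduler = LinearLRScheduler(optimizer, warmup_steps, total_steps, initial_lr)
    return module, optimizer, scheduler
```

With 10 warmup steps out of 100 and an initial rate of 1e-3, the illustrative schedule starts at 0, reaches 1e-3 at step 10, and returns to 0 at step 100.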
Step 5: Execute Training Loop
Run the iterative training loop: for each batch, execute a train step (forward + backward pass), perform an optimizer step (weight update), apply learning rate scheduling, and reset gradients. Periodically evaluate on validation data using the eval step, which runs the forward pass without computing gradients.
Key considerations:
- TrainStep accumulates gradients in parameter objects
- LazyResetGrad efficiently defers gradient zeroing to the next train step
- OptimizerStep reads accumulated gradients and updates parameters
- Evaluation mode uses a separate optimized graph without gradient computation
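The loop described above can be sketched with the Python API's method names (`train`, `eval`, `step`, `lazy_reset_grad`). The sketch assumes each batch is an `(inputs, labels)` pair and that the loss is the single value returned by calling the module; `train_one_epoch` and `evaluate` are hypothetical helper names.

```python
def train_one_epoch(module, optimizer, scheduler, batches):
    """Per batch: train step, optimizer step, LR scheduler step, lazy grad reset."""
    module.train()                      # select the training graph
    losses = []
    for inputs, labels in batches:
        loss = module(inputs, labels)   # forward + backward; gradients accumulate
        optimizer.step()                # apply the accumulated gradients to the weights
        scheduler.step()                # advance the learning rate schedule
        module.lazy_reset_grad()        # gradients are cleared on the next train step
        losses.append(float(loss))      # assumption: loss is a scalar-like value
    return sum(losses) / len(losses) if losses else float("nan")


def evaluate(module, batches):
    """Average loss over validation batches using the evaluation graph."""
    module.eval()                       # select the evaluation graph (no gradients)
    losses = [float(module(inputs, labels)) for inputs, labels in batches]
    return sum(losses) / len(losses) if losses else float("nan")
```

To accumulate gradients over several batches, one option is to call the optimizer step and the gradient reset only every N batches instead of after each one.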
Step 6: Save Checkpoint
Persist the current training state to a checkpoint file for later resumption. The checkpoint contains all parameter values, optimizer states (momentum buffers), and custom properties. Optionally exclude optimizer state to create a smaller, parameters-only checkpoint; such a checkpoint restores the weights but not the optimizer's momentum, so use it when training will not be resumed from it.
Key considerations:
- Save periodically during training for fault tolerance
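- Excluding optimizer state shrinks the checkpoint; with AdamW, which keeps two momentum tensors per trainable parameter, the optimizer state is roughly twice the size of the trainable weights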
- Custom properties preserve training metadata (epoch, loss, etc.)
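A minimal sketch of periodic saving. The `should_save` and `save_state` helpers, the `checkpoint_epoch_N` naming scheme, and the `epoch` and `loss` property names are hypothetical; `CheckpointState.save_checkpoint` and its `include_optimizer_state` flag come from the Python API.

```python
import os


def should_save(step, save_every):
    """True on every save_every-th step, never on step 0 or when disabled."""
    return save_every > 0 and step > 0 and step % save_every == 0


def save_state(state, directory, epoch, loss, include_optimizer_state=True):
    """Record metadata as custom properties, then write the checkpoint."""
    from onnxruntime.training.api import CheckpointState  # lazy import

    state.properties["epoch"] = int(epoch)   # illustrative property names
    state.properties["loss"] = float(loss)
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, f"checkpoint_epoch_{epoch}")
    CheckpointState.save_checkpoint(
        state, path, include_optimizer_state=include_optimizer_state
    )
    return path
```

Passing `include_optimizer_state=False` produces the smaller parameters-only variant.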
Step 7: Export Model for Inference
After training completes, export the final trained model as an inference-only ONNX model. This strips the gradient computation graph, optimizer states, and training-specific operators, producing a compact model ready for deployment.
Key considerations:
- Specify output names to include in the inference model
- The exported model is a standard ONNX model loadable by any ONNX Runtime session
- Model can be further optimized with graph optimization passes
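The export step can be sketched as below, assuming the Module was created with an evaluation model and that the graph output is named `output` (an illustrative name). `export_for_inference` and `load_for_inference` are hypothetical helpers around `Module.export_model_for_inferencing` and `onnxruntime.InferenceSession`.

```python
def export_for_inference(module, output_path, output_names):
    """Write an inference-only ONNX model exposing the given graph outputs."""
    if not output_names:
        raise ValueError("at least one graph output name is required")
    module.export_model_for_inferencing(output_path, list(output_names))
    return output_path


def load_for_inference(model_path):
    """Open the exported model with a regular ONNX Runtime inference session."""
    import onnxruntime as ort  # lazy import

    # Assumption: CPU execution; choose another provider for your target device.
    return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
```

A typical call would be `export_for_inference(module, "inference_model.onnx", ["output"])`, followed by `load_for_inference("inference_model.onnx")` on the deployment side.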