Implementation:Microsoft Onnxruntime CPU TrainingSplit
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CPU kernel that splits a tensor along one axis for training workloads in the ONNX Runtime training framework, with the split sizes supplied at run time.
Description
This file implements the SplitTraining kernel, a training-oriented variant of the standard ONNX Split operator. Whereas Split before opset 13 takes its split sizes from an attribute, this kernel reads them from an input tensor (index 1), so they can be determined at run time and are compatible with the per-input-length output of ConcatTraining. The PrepareForTrainingCompute helper validates the axis, computes the dimension sizes before and after the split axis, and verifies that the split sizes sum to the input's size along that axis. The kernel then uses math::CopyMatrix to copy data from the input into each output along the split axis. Although the kernel is registered with a type constraint covering all tensor types, it supports float, int32, int64, and string data.
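The validation and copy steps described above can be sketched in standalone C++. This is a minimal illustration and not the ONNX Runtime source: `SplitLayout`, `PrepareSplit`, and `SplitAlongAxis` are hypothetical names, plain `std::vector` buffers stand in for tensors, a simple range insert stands in for math::CopyMatrix, and the handling of negative axes is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for the values PrepareForTrainingCompute is described
// as producing; field names mirror the parameters in the signature.
struct SplitLayout {
  int64_t axis;                             // normalized, non-negative axis
  int64_t before_dims;                      // product of dims before the axis
  int64_t after_dims_including_split_axis;  // product of dims from the axis onward
  int64_t after_dims_excluding_split;       // product of dims after the axis
};

inline SplitLayout PrepareSplit(const std::vector<int64_t>& shape, int64_t axis,
                                const std::vector<int64_t>& split_sizes) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  if (axis < -rank || axis >= rank) throw std::invalid_argument("axis out of range");
  if (axis < 0) axis += rank;  // assumption: negative axes count from the end

  int64_t before = 1, after = 1;
  for (int64_t i = 0; i < axis; ++i) before *= shape[i];
  for (int64_t i = axis + 1; i < rank; ++i) after *= shape[i];

  for (int64_t s : split_sizes)
    if (s < 0) throw std::invalid_argument("split sizes must be non-negative");
  const int64_t total = std::accumulate(split_sizes.begin(), split_sizes.end(), int64_t{0});
  if (total != shape[axis])
    throw std::invalid_argument("split sizes must sum to the axis dimension");

  return {axis, before, shape[axis] * after, after};
}

// Splits a row-major buffer into one buffer per entry of split_sizes.
template <typename T>
std::vector<std::vector<T>> SplitAlongAxis(const std::vector<T>& input,
                                           const std::vector<int64_t>& shape, int64_t axis,
                                           const std::vector<int64_t>& split_sizes) {
  const SplitLayout layout = PrepareSplit(shape, axis, split_sizes);
  const int64_t expected = layout.before_dims * layout.after_dims_including_split_axis;
  if (static_cast<int64_t>(input.size()) != expected)
    throw std::invalid_argument("input size does not match shape");

  std::vector<std::vector<T>> outputs;
  int64_t input_offset = 0;  // element offset of the current slice within one row
  for (int64_t split : split_sizes) {
    const int64_t cols = split * layout.after_dims_excluding_split;
    std::vector<T> out;
    out.reserve(static_cast<std::size_t>(layout.before_dims * cols));
    // View the input as a [before_dims x after_dims_including_split_axis]
    // matrix and copy a `cols`-wide strip out of every row.
    for (int64_t row = 0; row < layout.before_dims; ++row) {
      const int64_t start = row * layout.after_dims_including_split_axis + input_offset;
      out.insert(out.end(), input.begin() + start, input.begin() + start + cols);
    }
    outputs.push_back(std::move(out));
    input_offset += cols;
  }
  return outputs;
}
```

In this sketch, splitting a 2x5 buffer along axis 1 with sizes {2, 3} yields a 2x2 and a 2x3 buffer, and a set of sizes that does not sum to 5 is rejected before any data is copied.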
Usage
This kernel is used during the backward pass of ConcatTraining. It splits the concatenated gradient back into individual per-input gradients using the split sizes recorded during the forward pass.
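The forward/backward pairing can be illustrated with a small standalone sketch. It is not the ONNX Runtime graph or gradient-builder code: `ConcatResult`, `ConcatAxis0`, and `SplitGradAxis0` are hypothetical names, the example is restricted to axis 0 of row-major 2-D buffers so that each slice is contiguous, and every input is assumed to share the same column count.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical result of a forward concat: the joined data plus the
// length of each input along the concat axis.
struct ConcatResult {
  std::vector<float> data;
  std::vector<int64_t> per_input_length;
};

// Forward pass: concatenate inputs of shape [rows_i, cols] along axis 0
// and record rows_i for each input.
inline ConcatResult ConcatAxis0(const std::vector<std::vector<float>>& inputs, int64_t cols) {
  ConcatResult result;
  for (const auto& in : inputs) {
    result.per_input_length.push_back(static_cast<int64_t>(in.size()) / cols);
    result.data.insert(result.data.end(), in.begin(), in.end());
  }
  return result;
}

// Backward pass: slice the gradient of the concatenated tensor back into
// one gradient per input, using the lengths recorded in the forward pass.
inline std::vector<std::vector<float>> SplitGradAxis0(const std::vector<float>& grad,
                                                      int64_t cols,
                                                      const std::vector<int64_t>& lengths) {
  std::vector<std::vector<float>> grads;
  int64_t offset = 0;
  for (int64_t len : lengths) {
    const int64_t count = len * cols;
    grads.emplace_back(grad.begin() + offset, grad.begin() + offset + count);
    offset += count;
  }
  return grads;
}
```

Because the lengths come from the forward pass rather than from a fixed attribute, the same backward routine works when the inputs change size between training steps.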
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/tensor/split.cc
- Lines: 1-153
Signature
Status PrepareForTrainingCompute(const TensorShape& input_shape,
                                 int num_outputs, int64_t& axis, int& before_dims,
                                 int& after_dims_including_split_axis, int& after_dims_excluding_split,
                                 std::vector<int64_t>& split_sizes);

Status SplitTraining::Compute(OpKernelContext* context) const;

template <typename T>
Status SplitTraining::ComputeImpl(OpKernelContext& context, const Tensor& input) const;
Import
#include "orttraining/training_ops/cpu/tensor/split.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input | Tensor(T) | Yes | Input tensor to split |
| split | Tensor(int64) | Yes | 1-D tensor giving the size of each output along the split axis; the sizes must sum to the input's dimension on that axis |
Outputs
| Name | Type | Description |
|---|---|---|
| outputs (variadic) | Tensor(T) | One output tensor per entry in split |
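The relationship between the inputs and the output shapes can be sketched as follows. `OutputShapes` is a hypothetical helper written for illustration, not a function from the source file, and it assumes the usual Split convention that each output keeps the input's shape except on the split axis.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Returns one shape per entry of `split`: the input shape with the
// dimension on `axis` replaced by that entry.
inline std::vector<std::vector<int64_t>> OutputShapes(const std::vector<int64_t>& input_shape,
                                                      int64_t axis,
                                                      const std::vector<int64_t>& split) {
  const int64_t rank = static_cast<int64_t>(input_shape.size());
  if (axis < -rank || axis >= rank) throw std::invalid_argument("axis out of range");
  if (axis < 0) axis += rank;  // assumption: negative axes count from the end

  std::vector<std::vector<int64_t>> shapes;
  for (int64_t size : split) {
    std::vector<int64_t> shape = input_shape;
    shape[static_cast<std::size_t>(axis)] = size;
    shapes.push_back(shape);
  }
  return shapes;
}
```

For an input of shape [4, 6] split on axis 1 with sizes {1, 2, 3}, the sketch produces shapes [4, 1], [4, 2], and [4, 3].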
Usage Examples
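The kernel is registered for the CPU execution provider in the Microsoft operator domain (kMSDomain), version 1:

ONNX_OPERATOR_KERNEL_EX(
    SplitTraining, kMSDomain, 1, kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
    SplitTraining);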