Implementation:Microsoft Onnxruntime CUDA TrainingSplit
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool for splitting a tensor along an axis with runtime-specified split sizes in the ONNX Runtime CUDA training framework.
Description
Implements the SplitTraining operator for CUDA. The kernel splits an input tensor into multiple output tensors along a specified axis. The split sizes come from a runtime input rather than an attribute, and are read from a 1-D tensor that resides in CPU memory. Execution proceeds as follows:
- PrepareForTrainingCompute performs the dimension calculations.
- If all splits have equal size, SplitSameSplitDimImpl is used. When there are 32 or fewer outputs, a fast path passes the output pointers to the kernel by value.
- Otherwise, SplitImpl handles variable-size splits using metadata transferred to the GPU.
Outside the by-value fast path, output pointers are managed via CudaAsyncBuffer. All fixed-size tensor types are supported.
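The split logic above can be approximated on the host for a contiguous row-major tensor. The following is a minimal C++17 sketch, not the ONNX Runtime source: `SplitDims`, `ComputeSplitDims`, `ReferenceSplit`, `SplitPath`, and `ChoosePath` are hypothetical names. The three extents are the kind of quantities a routine like PrepareForTrainingCompute would produce (elements before the axis, the block size including the axis, and the block size inside it). `ChoosePath` restates the kernel selection described above, assuming that equal-size splits with more than 32 outputs fall back to a device-side pointer buffer.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical decomposition of a row-major shape around the split axis.
struct SplitDims {
  int64_t before_dims;                    // product of dims before the axis
  int64_t block_size_including_axis_dim;  // axis dim * product of dims after the axis
  int64_t block_size_inside_axis_dim;     // product of dims after the axis
};

SplitDims ComputeSplitDims(const std::vector<int64_t>& shape, size_t axis,
                           const std::vector<int64_t>& split_sizes) {
  if (axis >= shape.size()) throw std::invalid_argument("axis out of range");
  const int64_t total = std::accumulate(split_sizes.begin(), split_sizes.end(), int64_t{0});
  if (total != shape[axis]) throw std::invalid_argument("split sizes must sum to the axis dimension");
  SplitDims d{1, 0, 1};
  for (size_t i = 0; i < axis; ++i) d.before_dims *= shape[i];
  for (size_t i = axis + 1; i < shape.size(); ++i) d.block_size_inside_axis_dim *= shape[i];
  d.block_size_including_axis_dim = shape[axis] * d.block_size_inside_axis_dim;
  return d;
}

// CPU reference for a variable-size split: for each outer block, output i takes
// a contiguous run of split_sizes[i] * block_size_inside_axis_dim elements.
template <typename T>
std::vector<std::vector<T>> ReferenceSplit(const std::vector<T>& input,
                                           const std::vector<int64_t>& shape, size_t axis,
                                           const std::vector<int64_t>& split_sizes) {
  const SplitDims d = ComputeSplitDims(shape, axis, split_sizes);
  std::vector<std::vector<T>> outputs;
  int64_t axis_offset = 0;  // start of the current split along the axis
  for (int64_t len : split_sizes) {
    std::vector<T> out;
    out.reserve(static_cast<size_t>(d.before_dims * len * d.block_size_inside_axis_dim));
    for (int64_t outer = 0; outer < d.before_dims; ++outer) {
      const int64_t base = outer * d.block_size_including_axis_dim +
                           axis_offset * d.block_size_inside_axis_dim;
      for (int64_t j = 0; j < len * d.block_size_inside_axis_dim; ++j)
        out.push_back(input[static_cast<size_t>(base + j)]);
    }
    outputs.push_back(std::move(out));
    axis_offset += len;
  }
  assert(axis_offset == shape[axis]);  // every axis position was assigned to an output
  return outputs;
}

// Kernel selection as described in the text (the >32 equal-size branch is assumed).
enum class SplitPath { kSameSizeByValue, kSameSizeDeviceBuffer, kVariableSize };

SplitPath ChoosePath(const std::vector<int64_t>& split_sizes) {
  const bool all_equal = std::adjacent_find(split_sizes.begin(), split_sizes.end(),
                                            std::not_equal_to<>()) == split_sizes.end();
  if (!all_equal) return SplitPath::kVariableSize;
  return split_sizes.size() <= 32 ? SplitPath::kSameSizeByValue
                                  : SplitPath::kSameSizeDeviceBuffer;
}
```

The device kernels parallelize the inner copy across threads; the host loop only pins down the index arithmetic.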
Usage
Invoked during training when a tensor must be split using sizes known only at runtime, for example when computing the gradient of ConcatTraining.
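To illustrate why the sizes must be a runtime input, here is a minimal host-side C++ sketch of the ConcatTraining/SplitTraining pairing. For brevity it is restricted to axis 0 of row-major `[rows, cols]` matrices, so each split is a contiguous slice. It assumes the forward concat records how many rows each input contributed; `ConcatResult`, `ConcatAxis0`, and `ConcatGradAxis0` are hypothetical helpers, not ONNX Runtime APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical forward result: the concatenation plus per-input lengths.
struct ConcatResult {
  std::vector<float> data;                // concatenated rows
  std::vector<int64_t> per_input_length;  // rows contributed by each input, known only at runtime
};

ConcatResult ConcatAxis0(const std::vector<std::vector<float>>& inputs, int64_t cols) {
  ConcatResult r;
  for (const auto& in : inputs) {
    assert(static_cast<int64_t>(in.size()) % cols == 0);  // each input must hold whole rows
    r.data.insert(r.data.end(), in.begin(), in.end());
    r.per_input_length.push_back(static_cast<int64_t>(in.size()) / cols);
  }
  return r;
}

// Gradient of the concat: split the upstream gradient with the recorded lengths.
std::vector<std::vector<float>> ConcatGradAxis0(const std::vector<float>& grad_output,
                                                const std::vector<int64_t>& split_sizes,
                                                int64_t cols) {
  std::vector<std::vector<float>> grads;
  std::ptrdiff_t offset = 0;
  for (int64_t rows : split_sizes) {
    const auto n = static_cast<std::ptrdiff_t>(rows * cols);
    grads.emplace_back(grad_output.begin() + offset, grad_output.begin() + offset + n);
    offset += n;
  }
  assert(offset == static_cast<std::ptrdiff_t>(grad_output.size()));  // sizes cover the gradient
  return grads;
}
```

Because `per_input_length` depends on the actual inputs seen in the forward pass, it cannot be baked into a static attribute, which is the case a runtime `split_sizes` input covers.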
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/tensor/split.cc
- Lines: 1-129
Signature
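```cpp
class SplitTraining : public CudaKernel {
 public:
  // Reads the split sizes from input 1 (CPU) and launches the CUDA split kernels.
  Status ComputeInternal(OpKernelContext* ctx) const override;
};
```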
Import
```cpp
#include "orttraining/training_ops/cuda/tensor/split.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input | Tensor(T) | Yes | Input tensor to split |
| split_sizes | Tensor(int64_t) | Yes | 1-D tensor giving the size of each split along the axis; must reside in CPU memory |
Outputs
| Name | Type | Description |
|---|---|---|
| outputs | Tensor(T)... | Multiple output tensors, one per split |
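Under this contract, each output has the input's shape with the axis dimension replaced by the corresponding split size. The sketch below derives those shapes on the host. `InferSplitOutputShapes` is a hypothetical helper, and its checks (one size per output, non-negative sizes summing to the axis dimension, negative axes counted from the end) are assumptions about what a conforming implementation would enforce, not a transcription of the kernel's error handling.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper: shape of every output given the input shape, the split
// axis, and the split_sizes values already available on the host.
std::vector<std::vector<int64_t>> InferSplitOutputShapes(
    const std::vector<int64_t>& input_shape, int64_t axis,
    const std::vector<int64_t>& split_sizes, size_t num_outputs) {
  const int64_t rank = static_cast<int64_t>(input_shape.size());
  if (axis < 0) axis += rank;  // assumed: negative axes count from the end
  if (axis < 0 || axis >= rank) throw std::invalid_argument("axis out of range");
  if (split_sizes.size() != num_outputs)
    throw std::invalid_argument("split_sizes must have one entry per output");
  int64_t total = 0;
  std::vector<std::vector<int64_t>> shapes;
  for (int64_t s : split_sizes) {
    if (s < 0) throw std::invalid_argument("split sizes must be non-negative");
    total += s;
    std::vector<int64_t> shape = input_shape;
    shape[static_cast<size_t>(axis)] = s;  // only the split axis changes
    shapes.push_back(shape);
  }
  if (total != input_shape[static_cast<size_t>(axis)])
    throw std::invalid_argument("split sizes must sum to the axis dimension");
  assert(shapes.size() == num_outputs);
  return shapes;
}
```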
Usage Examples
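```cpp
// Registers SplitTraining (com.microsoft domain, version 1) for the CUDA EP.
ONNX_OPERATOR_KERNEL_EX(SplitTraining, kMSDomain, 1, kCudaExecutionProvider,
                        (*KernelDefBuilder::Create())
                            // Keep input 1 (split sizes) in CPU memory so the host can read it.
                            .InputMemoryType(OrtMemTypeCPUInput, 1)
                            // T may be any fixed-size tensor element type.
                            .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
                        SplitTraining);
```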