

Implementation:Microsoft Onnxruntime CUDA TrainingSplit

From Leeroopedia


Knowledge Sources
Domains: Training, CUDA_Kernels
Last Updated: 2026-02-10 04:00 GMT

Overview

Concrete tool for splitting a tensor along an axis with runtime-specified split sizes in the ONNX Runtime CUDA training framework.

Description

Implements the SplitTraining operator for CUDA. The operator splits an input tensor into multiple output tensors along a specified axis, using split sizes supplied as a runtime input rather than as an attribute; the sizes are read from a 1-D tensor in CPU memory. The kernel delegates dimension calculations to PrepareForTrainingCompute and then launches one of two GPU kernels: SplitSameSplitDimImpl when all splits have the same size (with a fast path that passes the output pointers by value when there are 32 or fewer outputs), or SplitImpl for variable-size splits, which uses split metadata copied to GPU memory. Output pointers are managed through CudaAsyncBuffer. All fixed-size tensor types are supported.

Usage

Invoked during training wherever a tensor must be split with sizes that are known only at runtime, for example when computing the gradient of ConcatTraining.

Code Reference

Source Location

Signature

class SplitTraining : public CudaKernel {
 public:
  // Reads the split sizes from input 1 and writes one output per split.
  Status ComputeInternal(OpKernelContext* ctx) const override;
};
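The dispatch performed by ComputeInternal, and the per-element index arithmetic that a variable-size split implies, can be sketched on the host. Everything here is hypothetical: `SplitPath`, `ChooseSplitPath`, `SplitMetadata`, `BuildSplitMetadata` and `LocateOutput` are illustrative names, and both the handling of more than 32 equal splits (pointers kept in a device buffer) and the metadata layout are assumptions, not taken from the ONNX Runtime source.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical labels for the code paths described for ComputeInternal.
enum class SplitPath {
  kSameSizeByValue,       // equal splits, <= 32 outputs: pointers passed by value
  kSameSizeDeviceBuffer,  // equal splits, more outputs: pointers in a device buffer (assumed)
  kVariableSize,          // unequal splits: per-split metadata copied to the GPU
};

SplitPath ChooseSplitPath(const std::vector<int64_t>& split_sizes) {
  const bool all_equal =
      std::all_of(split_sizes.begin(), split_sizes.end(),
                  [&](int64_t s) { return s == split_sizes.front(); });
  if (!all_equal) return SplitPath::kVariableSize;
  return split_sizes.size() <= 32 ? SplitPath::kSameSizeByValue
                                  : SplitPath::kSameSizeDeviceBuffer;
}

// Host-side metadata for the variable-size path (layout assumed for illustration).
struct SplitMetadata {
  std::vector<int64_t> split_sizes;        // size of each split along the axis
  std::vector<int64_t> split_sizes_range;  // inclusive prefix sums of split_sizes
  std::vector<int32_t> axis_to_output;     // output that owns each axis position
};

SplitMetadata BuildSplitMetadata(const std::vector<int64_t>& split_sizes) {
  SplitMetadata m;
  m.split_sizes = split_sizes;
  int64_t running = 0;
  for (size_t i = 0; i < split_sizes.size(); ++i) {
    running += split_sizes[i];
    m.split_sizes_range.push_back(running);
    m.axis_to_output.insert(m.axis_to_output.end(),
                            static_cast<size_t>(split_sizes[i]),
                            static_cast<int32_t>(i));
  }
  return m;
}

// What one GPU thread would compute for its flat input index: the owning
// output and the element offset inside that output.
struct OutputLocation {
  int32_t output;
  int64_t offset;
};

OutputLocation LocateOutput(int64_t input_index, const SplitMetadata& m,
                            int64_t block_size_including_axis_dim,
                            int64_t block_size_inside_axis_dim) {
  const int64_t outer = input_index / block_size_including_axis_dim;
  const int64_t within = input_index % block_size_including_axis_dim;
  const int64_t axis_pos = within / block_size_inside_axis_dim;
  const int64_t inner = within % block_size_inside_axis_dim;

  const int32_t out = m.axis_to_output[static_cast<size_t>(axis_pos)];
  const int64_t size = m.split_sizes[static_cast<size_t>(out)];
  const int64_t split_start = m.split_sizes_range[static_cast<size_t>(out)] - size;
  assert(axis_pos >= split_start && axis_pos < split_start + size);
  return {out, (outer * size + (axis_pos - split_start)) * block_size_inside_axis_dim + inner};
}
```

For a `[2, 3, 2]` input with split sizes `{1, 2}` on axis 1, flat input index 8 maps to output 1 at offset 4, and index 7 maps to output 0 at offset 3. Precomputing `axis_to_output` trades a small table for a constant-time lookup per element instead of a search over the prefix sums.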

Import

#include "orttraining/training_ops/cuda/tensor/split.h"

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| input | Tensor(T) | Yes | Input tensor to split (input index 0) |
| split_sizes | Tensor(int64_t) | Yes | 1-D tensor giving the size of each split; kept in CPU memory (input index 1) |
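Since split_sizes lives in host memory, the sizes can be inspected before anything is launched on the GPU. The function below is a hypothetical sketch of the consistency checks a 1-D size vector must pass for the split to be well defined; `ValidateSplitSizes` and its error strings are illustrative, and the exact checks and messages in ONNX Runtime are not reproduced here.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical checks on the split_sizes input, read through a host pointer.
// Returns an empty string on success, otherwise a description of the problem.
std::string ValidateSplitSizes(const std::vector<int64_t>& split_shape,
                               const int64_t* data, int64_t axis_dim,
                               size_t num_outputs) {
  if (split_shape.size() != 1) return "split_sizes must be a 1-D tensor";
  const int64_t count = split_shape[0];
  assert(data != nullptr || count == 0);
  if (count < 0 || static_cast<size_t>(count) != num_outputs)
    return "number of split sizes must equal the number of outputs";
  int64_t total = 0;
  for (int64_t i = 0; i < count; ++i) {
    if (data[i] < 0) return "split sizes must be non-negative";
    total += data[i];
  }
  if (total != axis_dim) return "split sizes must sum to the size of the split axis";
  return "";
}
```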

Outputs

| Name | Type | Description |
|---|---|---|
| outputs | Tensor(T), variadic | One output tensor per split |

Usage Examples

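// Registers SplitTraining (com.microsoft domain, version 1) for the CUDA execution provider.
ONNX_OPERATOR_KERNEL_EX(SplitTraining, kMSDomain, 1, kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        // Keep input 1 (split_sizes) in CPU memory so the host can read the sizes.
        .InputMemoryType(OrtMemTypeCPUInput, 1)
        // T may be any fixed-size tensor element type.
        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
    SplitTraining);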
