Implementation:Microsoft Onnxruntime CPU TrainingSplit

From Leeroopedia


Knowledge Sources
Domains Training, CPU_Kernels
Last Updated 2026-02-10 04:00 GMT

Overview

Concrete tool for training-specific tensor splitting on CPU in the ONNX Runtime training framework.

Description

This file implements the SplitTraining kernel, which extends the standard ONNX Split operator for training use. Unlike a Split that takes its split sizes from an attribute, this kernel reads them from an input tensor (input index 1), so the sizes can be determined at run time and fed directly from the per-input-length output of ConcatTraining. The PrepareForTrainingCompute helper validates the axis, computes the dimension products before and after the split axis, and verifies that the split sizes sum to the size of the input along that axis. The kernel then uses math::CopyMatrix to copy each slice of the input into its output tensor. It supports float, int32, int64, and string types.
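The validate-then-copy flow described above can be sketched in standalone C++. This is an illustrative sketch only, not the ONNX Runtime implementation: `SplitPlan`, `PlanSplit`, and `SplitAlongAxis` are hypothetical names, the tensor is assumed to be a dense row-major buffer in a `std::vector`, and a plain strided copy stands in for math::CopyMatrix.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical helper types; names are assumptions, not ONNX Runtime API.
struct SplitPlan {
  int64_t before_dims;                 // product of dims before the split axis
  int64_t after_dims_excluding_split;  // product of dims after the split axis
  int64_t split_dim;                   // size of the input along the split axis
};

// Validates the axis and checks that the split sizes sum to the axis size.
SplitPlan PlanSplit(const std::vector<int64_t>& shape, int64_t axis,
                    const std::vector<int64_t>& split_sizes) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  if (axis < -rank || axis >= rank) throw std::invalid_argument("axis out of range");
  if (axis < 0) axis += rank;  // negative axes count from the back
  SplitPlan plan{1, 1, shape[axis]};
  for (int64_t i = 0; i < axis; ++i) plan.before_dims *= shape[i];
  for (int64_t i = axis + 1; i < rank; ++i) plan.after_dims_excluding_split *= shape[i];
  for (int64_t s : split_sizes) {
    if (s < 0) throw std::invalid_argument("split sizes must be non-negative");
  }
  const int64_t total = std::accumulate(split_sizes.begin(), split_sizes.end(), int64_t{0});
  if (total != plan.split_dim) {
    throw std::invalid_argument("split sizes must sum to the axis dimension");
  }
  return plan;
}

// Copies each slice along the split axis into its own contiguous buffer.
template <typename T>
std::vector<std::vector<T>> SplitAlongAxis(const std::vector<T>& input,
                                           const std::vector<int64_t>& shape,
                                           int64_t axis,
                                           const std::vector<int64_t>& split_sizes) {
  const SplitPlan plan = PlanSplit(shape, axis, split_sizes);
  // Elements per "row" when the tensor is viewed as before_dims rows.
  const int64_t row_stride = plan.split_dim * plan.after_dims_excluding_split;
  assert(static_cast<int64_t>(input.size()) == plan.before_dims * row_stride);

  std::vector<std::vector<T>> outputs;
  int64_t offset = 0;  // column offset of the current slice within a row
  for (int64_t size : split_sizes) {
    const int64_t cols = size * plan.after_dims_excluding_split;
    std::vector<T> out;
    out.reserve(static_cast<size_t>(plan.before_dims * cols));
    for (int64_t r = 0; r < plan.before_dims; ++r) {
      const T* src = input.data() + r * row_stride + offset;
      out.insert(out.end(), src, src + cols);
    }
    outputs.push_back(std::move(out));
    offset += cols;
  }
  return outputs;
}
```

Viewing the tensor as `before_dims` rows of `split_dim * after_dims_excluding_split` elements reduces a split along any axis to copying a column range out of each row, which is why a matrix-copy primitive is sufficient for this job.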

Usage

This kernel is used during the backward pass of ConcatTraining. It splits the concatenated gradient back into individual per-input gradients using the split sizes recorded during the forward pass.
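The forward/backward pairing can be illustrated with a small round trip. The sketch below is not the ONNX Runtime kernels: `ConcatResult`, `ConcatForward`, and `ConcatBackward` are hypothetical names, and the tensors are assumed to be 1-D so that concatenation is a plain append.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the forward result; not an ONNX Runtime type.
struct ConcatResult {
  std::vector<float> output;              // concatenated data
  std::vector<int64_t> per_input_length;  // sizes recorded for the backward pass
};

// Forward pass: concatenate 1-D inputs and record each input's length.
ConcatResult ConcatForward(const std::vector<std::vector<float>>& inputs) {
  ConcatResult result;
  for (const auto& in : inputs) {
    result.output.insert(result.output.end(), in.begin(), in.end());
    result.per_input_length.push_back(static_cast<int64_t>(in.size()));
  }
  return result;
}

// Backward pass: split the gradient of the concatenated output back into
// one gradient per input, using the lengths recorded in the forward pass.
std::vector<std::vector<float>> ConcatBackward(
    const std::vector<float>& grad_output,
    const std::vector<int64_t>& per_input_length) {
  int64_t total = 0;
  for (int64_t n : per_input_length) {
    if (n < 0) throw std::invalid_argument("lengths must be non-negative");
    total += n;
  }
  if (total != static_cast<int64_t>(grad_output.size())) {
    throw std::invalid_argument("lengths must sum to the gradient size");
  }
  std::vector<std::vector<float>> grads;
  auto it = grad_output.begin();
  for (int64_t n : per_input_length) {
    grads.emplace_back(it, it + n);  // copy the next n gradient values
    it += n;
  }
  return grads;
}
```

Because the lengths travel as data rather than as a fixed attribute, the same backward step works even when the input sizes differ from one training step to the next.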

Code Reference

Source Location

Signature

Status PrepareForTrainingCompute(const TensorShape& input_shape,
    int num_outputs, int64_t& axis, int& before_dims,
    int& after_dims_including_split_axis, int& after_dims_excluding_split,
    std::vector<int64_t>& split_sizes);

Status SplitTraining::Compute(OpKernelContext* context) const;

template <typename T>
Status SplitTraining::ComputeImpl(OpKernelContext& context, const Tensor& input) const;

Import

#include "orttraining/training_ops/cpu/tensor/split.h"

I/O Contract

Inputs

Name  | Type          | Required | Description
input | Tensor(T)     | Yes      | Input tensor to split
split | Tensor(int64) | Yes      | 1D tensor specifying the size of each split

Outputs

Name               | Type      | Description
outputs (variadic) | Tensor(T) | One output tensor per split
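The shape side of this contract can be sketched as follows. `OutputShapes` is a hypothetical helper, not part of ONNX Runtime, and it assumes a non-negative axis that has already been validated: each output keeps the input shape except along the split axis, where it takes the corresponding entry of `split`.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: derives one output shape per entry of `split`.
std::vector<std::vector<int64_t>> OutputShapes(
    const std::vector<int64_t>& input_shape, std::size_t axis,
    const std::vector<int64_t>& split) {
  std::vector<std::vector<int64_t>> shapes;
  for (int64_t size : split) {
    std::vector<int64_t> shape = input_shape;  // start from the input shape
    shape.at(axis) = size;                     // replace only the split axis
    shapes.push_back(shape);
  }
  return shapes;
}
```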

Usage Examples

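// Registers the CPU kernel for the float, int32, int64, and string types.
ONNX_OPERATOR_KERNEL_EX(
    SplitTraining, kMSDomain, 1, kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", std::vector<MLDataType>{
                                 DataTypeImpl::GetTensorType<float>(),
                                 DataTypeImpl::GetTensorType<int32_t>(),
                                 DataTypeImpl::GetTensorType<int64_t>(),
                                 DataTypeImpl::GetTensorType<std::string>()}),
    SplitTraining);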
