Implementation:Microsoft Onnxruntime CPU TrainingConcat
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CPU kernel that performs training-specific tensor concatenation in the ONNX Runtime training framework.
Description
This file implements the ConcatTraining kernel, which extends the standard ONNX Concat operator with a second output: a 1D tensor holding the size of each input along the concatenation axis. The SplitTraining operator needs these sizes during the backward pass to split the incoming gradient back into one gradient tensor per input. The kernel collects all input tensors, validates them via PrepareForCompute, and writes both the concatenated tensor and the per-input lengths. If the output would be empty, the kernel returns early.
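The behaviour described above can be illustrated with a minimal standalone sketch. It is not the ONNX Runtime implementation: `Tensor2D`, `ConcatResult`, and `ConcatWithLengths` are hypothetical names introduced here, the tensors are restricted to two dimensions of `float`, and the validation is reduced to simple assertions standing in for what PrepareForCompute checks.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical row-major 2-D tensor; not an ONNX Runtime type.
struct Tensor2D {
  int64_t rows;
  int64_t cols;
  std::vector<float> data;  // rows * cols elements
};

// Hypothetical pair of outputs mirroring concat_result and per_input_length.
struct ConcatResult {
  Tensor2D concat_result;
  std::vector<int64_t> per_input_length;
};

// Concatenates along axis 0 or 1 and records each input's extent on that axis.
ConcatResult ConcatWithLengths(const std::vector<Tensor2D>& inputs, int axis) {
  assert(axis == 0 || axis == 1);
  assert(!inputs.empty());

  ConcatResult out{};
  int64_t total = 0;
  for (const Tensor2D& t : inputs) {
    // Every dimension except the concatenation axis must match.
    if (axis == 0) {
      assert(t.cols == inputs[0].cols);
    } else {
      assert(t.rows == inputs[0].rows);
    }
    const int64_t len = (axis == 0) ? t.rows : t.cols;
    out.per_input_length.push_back(len);
    total += len;
  }

  Tensor2D& y = out.concat_result;
  y.rows = (axis == 0) ? total : inputs[0].rows;
  y.cols = (axis == 0) ? inputs[0].cols : total;
  y.data.resize(static_cast<std::size_t>(y.rows * y.cols));
  if (y.data.empty()) {
    return out;  // nothing to copy when the output is empty
  }

  if (axis == 0) {
    // Row-major layout: stacking rows is a plain append of each buffer.
    auto dst = y.data.begin();
    for (const Tensor2D& t : inputs) {
      dst = std::copy(t.data.begin(), t.data.end(), dst);
    }
  } else {
    // Axis 1: copy each input into its column range of every output row.
    int64_t col_offset = 0;
    for (const Tensor2D& t : inputs) {
      for (int64_t r = 0; r < t.rows; ++r) {
        for (int64_t c = 0; c < t.cols; ++c) {
          y.data[static_cast<std::size_t>(r * total + col_offset + c)] =
              t.data[static_cast<std::size_t>(r * t.cols + c)];
        }
      }
      col_offset += t.cols;
    }
  }
  return out;
}
```

Concatenating a 2x1 tensor and a 2x2 tensor along axis 1 with this sketch yields a 2x3 result and the lengths `{1, 2}`.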
Usage
This kernel is used during the forward pass of training graphs wherever tensor concatenation is required. Its per_input_length output lets the corresponding backward operator (SplitTraining) split the gradient of the concatenated tensor into one gradient per original input.
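To show why the recorded lengths matter in the backward pass, here is a hedged sketch of splitting a gradient along axis 1. `SplitGradColumns` is a hypothetical helper written for this page, not the SplitTraining kernel or any ONNX Runtime API, and it assumes a row-major 2-D `float` gradient.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper: splits a row-major [rows x cols] gradient along axis 1
// into one row-major buffer per forward input, using the recorded lengths.
std::vector<std::vector<float>> SplitGradColumns(
    const std::vector<float>& grad, int64_t rows, int64_t cols,
    const std::vector<int64_t>& per_input_length) {
  assert(static_cast<int64_t>(grad.size()) == rows * cols);
  int64_t sum = 0;
  for (int64_t len : per_input_length) {
    sum += len;
  }
  assert(sum == cols);  // lengths must cover the concatenated axis exactly

  std::vector<std::vector<float>> pieces;
  int64_t col_offset = 0;
  for (int64_t len : per_input_length) {
    std::vector<float> piece;
    piece.reserve(static_cast<std::size_t>(rows * len));
    for (int64_t r = 0; r < rows; ++r) {
      for (int64_t c = 0; c < len; ++c) {
        piece.push_back(
            grad[static_cast<std::size_t>(r * cols + col_offset + c)]);
      }
    }
    pieces.push_back(std::move(piece));
    col_offset += len;
  }
  return pieces;
}
```

Without the lengths, the backward pass would have no way to tell from the gradient alone where one input's columns end and the next begin.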
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/tensor/concat.cc
- Lines: 1-56
Signature
Status ConcatTraining::Compute(OpKernelContext* ctx) const;
Import
#include "orttraining/orttraining/training_ops/cpu/tensor/concat.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs (variadic) | Tensor(T) | Yes | Variable number of tensors to concatenate |
Outputs
| Name | Type | Description |
|---|---|---|
| concat_result | Tensor(T) | Concatenated output tensor |
| per_input_length | Tensor(int64) | Size of each input along the concatenation axis |
Usage Examples
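The kernel is registered for the CPU execution provider under kMSDomain, version 1, accepting all tensor types for T:

    ONNX_OPERATOR_KERNEL_EX(
        ConcatTraining, kMSDomain, 1, kCpuExecutionProvider,
        KernelDefBuilder()
            .TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
        ConcatTraining);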