Implementation:Microsoft Onnxruntime CUDA TrainingConcat
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool for concatenating tensors along an axis while also reporting each input's length along that axis, in the ONNX Runtime CUDA training framework.
Description
Implements the ConcatTraining operator for CUDA. It extends the standard Concat with a second output, per_input_length, which records each input's size along the concatenation axis. The SplitTraining operator that computes the gradient of the concatenation uses these lengths to split the incoming gradient. The implementation chooses between two optimized GPU kernels:
- ConcatSameConcatDimImpl, when all inputs have the same size along the concatenation axis. With 32 or fewer inputs, the input pointers are passed to the kernel by value (fast path).
- ConcatImpl, when the inputs have different sizes along the concatenation axis.
Otherwise, input pointers and metadata are transferred to the GPU via CudaAsyncBuffer. The per_input_length output tensor is placed in CPU memory. All fixed-size tensor types are supported.
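The kernel selection and index mapping described above can be illustrated with a CPU reference. This is a minimal sketch under stated assumptions, not the ONNX Runtime code: ConcatPath, ChoosePath and ConcatReference are hypothetical names, each input is assumed to be a flat row-major float buffer viewed as [outer, concat_sizes[i], inner], and each loop iteration roughly stands in for the work of one GPU thread. When all axis sizes are equal, the owning input is found with a single division; otherwise the sketch scans the per-input lengths (the real ConcatImpl may locate the owning input differently).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical labels for the GPU paths described in the text.
enum class ConcatPath { kSameDimByValue, kSameDimAsyncBuffer, kVariableDim };

// Picks a path from the inputs' sizes along the concatenation axis.
ConcatPath ChoosePath(const std::vector<int64_t>& concat_sizes) {
  assert(!concat_sizes.empty());
  const bool all_same =
      std::all_of(concat_sizes.begin(), concat_sizes.end(),
                  [&](int64_t s) { return s == concat_sizes.front(); });
  if (!all_same) return ConcatPath::kVariableDim;  // ConcatImpl
  // ConcatSameConcatDimImpl: pointers by value for <= 32 inputs.
  return concat_sizes.size() <= 32 ? ConcatPath::kSameDimByValue
                                   : ConcatPath::kSameDimAsyncBuffer;
}

// CPU reference: input i is a flat row-major buffer viewed as
// [outer, concat_sizes[i], inner]; the result is [outer, sum(concat_sizes), inner].
std::vector<float> ConcatReference(const std::vector<std::vector<float>>& inputs,
                                   const std::vector<int64_t>& concat_sizes,
                                   int64_t outer, int64_t inner) {
  assert(!inputs.empty() && inputs.size() == concat_sizes.size());
  const bool same = ChoosePath(concat_sizes) != ConcatPath::kVariableDim;
  int64_t total = 0;
  for (int64_t s : concat_sizes) total += s;
  std::vector<float> out(static_cast<std::size_t>(outer * total * inner));
  // One iteration per output element.
  for (std::size_t idx = 0; idx < out.size(); ++idx) {
    const int64_t flat = static_cast<int64_t>(idx);
    const int64_t k = flat % inner;            // position inside the inner block
    const int64_t a = (flat / inner) % total;  // position along the concat axis
    const int64_t o = flat / (inner * total);  // outer index
    std::size_t input = 0;
    int64_t local = a;  // position along the axis within the owning input
    if (same) {
      // Equal sizes: one division identifies the owning input.
      input = static_cast<std::size_t>(a / concat_sizes[0]);
      local = a % concat_sizes[0];
    } else {
      // Variable sizes: subtract lengths until 'local' falls inside an input.
      while (local >= concat_sizes[input]) local -= concat_sizes[input++];
    }
    out[idx] = inputs[input][static_cast<std::size_t>(
        (o * concat_sizes[input] + local) * inner + k)];
  }
  return out;
}
```

For example, concatenating a [2, 2] tensor holding 1, 2, 3, 4 with a [2, 1] tensor holding 5, 6 along axis 1 (outer = 2, inner = 1) takes the variable-size path and yields 1, 2, 5, 3, 4, 6.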
Usage
Invoked during the training forward pass wherever tensors are concatenated. The per_input_length output is consumed by the corresponding SplitTraining operator in the gradient graph.
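The backward of a concatenation is a split of the output gradient at the original input boundaries; recording the lengths in the forward pass presumably lets the gradient graph split correctly even when input sizes are only known at run time. The following is a hypothetical CPU sketch of that split, not the SplitTraining implementation: SplitByLengths is an illustrative name, and the gradient is assumed to be a flat row-major float buffer viewed as [outer, sum(lengths), inner].

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU sketch: cut the gradient of the concatenated output back into
// one gradient per original input, using the recorded per-input lengths.
// grad is viewed as [outer, sum(lengths), inner]; piece i is [outer, lengths[i], inner].
std::vector<std::vector<float>> SplitByLengths(const std::vector<float>& grad,
                                               const std::vector<int64_t>& lengths,
                                               int64_t outer, int64_t inner) {
  int64_t total = 0;
  for (int64_t len : lengths) total += len;
  assert(static_cast<int64_t>(grad.size()) == outer * total * inner);
  std::vector<std::vector<float>> pieces(lengths.size());
  for (int64_t o = 0; o < outer; ++o) {
    int64_t offset = 0;  // start of piece i along the concatenation axis
    for (std::size_t i = 0; i < lengths.size(); ++i) {
      // For a fixed outer index, piece i is one contiguous run of lengths[i] * inner values.
      const auto begin = grad.begin() + (o * total + offset) * inner;
      pieces[i].insert(pieces[i].end(), begin, begin + lengths[i] * inner);
      offset += lengths[i];
    }
  }
  return pieces;
}
```

With the gradient 1, 2, 5, 3, 4, 6 of a [2, 3] output and lengths 2 and 1 along axis 1, the pieces are 1, 2, 3, 4 (shape [2, 2]) and 5, 6 (shape [2, 1]). Keeping the lengths on the host is plausibly why the output lives in CPU memory: output shapes of the split must be known before any device work is launched.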
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/tensor/concat.cc
- Lines: 1-119
Signature
class ConcatTraining : public CudaKernel {
 public:
  // Concatenates the inputs and writes both outputs (output, per_input_length).
  Status ComputeInternal(OpKernelContext* ctx) const override;
};
Import
#include "orttraining/training_ops/cuda/tensor/concat.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs | Tensor(T)... | Yes | Variadic input tensors to concatenate |
Outputs
| Name | Type | Description |
|---|---|---|
| output | Tensor(T) | Concatenated output tensor |
| per_input_length | Tensor(int64_t) | Per-input size along concatenation axis (CPU memory) |
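The output contract can be made concrete with a shape-only sketch. InferConcatShapes and ConcatShapes are hypothetical names; the sketch assumes all inputs share rank and every dimension except the concatenation axis, assumes negative axis values count from the end as in ONNX Concat, and models per_input_length as a 1-D list with one entry per input.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct ConcatShapes {
  std::vector<int64_t> output_shape;      // shape of "output"
  std::vector<int64_t> per_input_length;  // contents of "per_input_length"
};

// Hypothetical helper: validates the inputs and computes both output descriptions.
ConcatShapes InferConcatShapes(const std::vector<std::vector<int64_t>>& shapes, int64_t axis) {
  assert(!shapes.empty());
  const int64_t rank = static_cast<int64_t>(shapes[0].size());
  if (axis < 0) axis += rank;  // assumed: negative axis counts from the end
  assert(axis >= 0 && axis < rank);
  ConcatShapes result;
  result.output_shape = shapes[0];
  result.output_shape[axis] = 0;
  for (const auto& s : shapes) {
    assert(static_cast<int64_t>(s.size()) == rank);
    // Every dimension except the concatenation axis must match.
    for (int64_t d = 0; d < rank; ++d) {
      if (d != axis) assert(s[d] == shapes[0][d]);
    }
    result.per_input_length.push_back(s[axis]);  // this input's length along the axis
    result.output_shape[axis] += s[axis];
  }
  return result;
}
```

For inputs of shape [2, 3] and [2, 5] concatenated along axis 1, the output has shape [2, 8] and per_input_length holds 3 and 5.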
Usage Examples
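// Registers the CUDA kernel for ConcatTraining (com.microsoft domain, version 1).
ONNX_OPERATOR_KERNEL_EX(ConcatTraining, kMSDomain, 1, kCudaExecutionProvider,
                        (*KernelDefBuilder::Create())
                            // Output 1 (per_input_length) is allocated in CPU memory.
                            .OutputMemoryType(OrtMemTypeCPUInput, 1)
                            .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
                        ConcatTraining);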