
Implementation:Microsoft Onnxruntime CUDA TrainingConcat

From Leeroopedia
Revision as of 15:45, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Microsoft_Onnxruntime_CUDA_TrainingConcat.md)


Knowledge Sources
Domains Training, CUDA_Kernels
Last Updated 2026-02-10 04:00 GMT

Overview

Concrete tool for concatenating tensors along an axis in the ONNX Runtime CUDA training framework, with an extra output that records each input's length along that axis.

Description

Implements the ConcatTraining operator for CUDA, which extends the standard concat with an additional output holding each input's length along the concatenation axis. The gradient of a concat is a split, so the backward pass needs these lengths as its split sizes. The implementation uses two optimized GPU kernels: ConcatSameConcatDimImpl when all inputs have the same size along the concat axis (with a fast path that passes the input pointers by value when there are 32 or fewer inputs), and ConcatImpl for variable-size inputs. Input pointers and metadata are transferred to the GPU via CudaAsyncBuffer. The per_input_length output tensor contains the size of each input along the concatenation axis and is placed in CPU memory. All fixed-size tensor types are supported.
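The kernel selection described above can be sketched on the host side as follows. This is a minimal illustration, not the ONNX Runtime source: the enum, the function name ChooseConcatPath, and the constant kMaxPointersByValue are hypothetical names introduced here, and only the 32-input threshold and the two kernel names come from the description.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical labels for the three paths named in the description.
enum class ConcatPath {
  SameDimByValue,   // ConcatSameConcatDimImpl, pointers passed by value
  SameDimByBuffer,  // ConcatSameConcatDimImpl, pointers passed via a device buffer
  Variable          // ConcatImpl, inputs differ along the concat axis
};

// Assumption: the by-value fast path applies to 32 or fewer inputs.
constexpr std::size_t kMaxPointersByValue = 32;

// Picks a path from the per-input lengths along the concat axis.
ConcatPath ChooseConcatPath(const std::vector<int64_t>& per_input_length) {
  assert(!per_input_length.empty());
  for (int64_t len : per_input_length) {
    if (len != per_input_length.front()) {
      return ConcatPath::Variable;
    }
  }
  return per_input_length.size() <= kMaxPointersByValue
             ? ConcatPath::SameDimByValue
             : ConcatPath::SameDimByBuffer;
}
```

For example, three inputs of length 4 would take the by-value path, while inputs of lengths 2 and 1 would fall through to the variable-size path.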

Usage

Invoked during the training forward pass wherever tensors are concatenated. The per_input_length output is consumed in the backward pass by the corresponding SplitTraining gradient operator.
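To show why the backward pass needs the per-input lengths, here is a CPU-only sketch that splits an upstream gradient back into one piece per original input. It is an assumption-laden illustration, not the SplitTraining implementation: the function name SplitByLengths is hypothetical, the gradient is treated as a row-major buffer of shape [outer, total, inner], and float is used as a stand-in element type.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Splits a row-major gradient of shape [outer, sum(lengths), inner] into one
// buffer per original input, each of shape [outer, lengths[i], inner].
std::vector<std::vector<float>> SplitByLengths(const std::vector<float>& grad,
                                               int64_t outer, int64_t inner,
                                               const std::vector<int64_t>& lengths) {
  const int64_t total = std::accumulate(lengths.begin(), lengths.end(), int64_t{0});
  assert(grad.size() == static_cast<std::size_t>(outer * total * inner));

  std::vector<std::vector<float>> pieces;
  int64_t offset = 0;  // start of the current piece along the concat axis
  for (int64_t len : lengths) {
    std::vector<float> piece;
    piece.reserve(static_cast<std::size_t>(outer * len * inner));
    for (int64_t o = 0; o < outer; ++o) {
      const float* row = grad.data() + (o * total + offset) * inner;
      piece.insert(piece.end(), row, row + len * inner);
    }
    pieces.push_back(std::move(piece));
    offset += len;
  }
  return pieces;
}
```

Without the lengths, the split sizes could not be recovered from the concatenated shape alone, which is why the forward operator records them.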

Code Reference

Source Location

Signature

class ConcatTraining : public CudaKernel {
  Status ComputeInternal(OpKernelContext* ctx) const;
};

Import

#include "orttraining/training_ops/cuda/tensor/concat.h"

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| inputs | Tensor(T)... | Yes | Variadic input tensors to concatenate |

Outputs

| Name | Type | Description |
|------|------|-------------|
| output | Tensor(T) | Concatenated output tensor |
| per_input_length | Tensor(int64_t) | Per-input size along concatenation axis (CPU memory) |
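The contract above can be illustrated with a small CPU reference. This is a sketch under stated assumptions rather than the CUDA kernel: HostTensor, ConcatResult, and ConcatWithLengths are hypothetical names, tensors are row-major float buffers, and all inputs are assumed to agree on every dimension except the concat axis.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side tensor: a shape plus flat row-major data.
struct HostTensor {
  std::vector<int64_t> shape;
  std::vector<float> data;
};

// Mirrors the two outputs of the I/O contract.
struct ConcatResult {
  HostTensor output;
  std::vector<int64_t> per_input_length;
};

ConcatResult ConcatWithLengths(const std::vector<HostTensor>& inputs, std::size_t axis) {
  assert(!inputs.empty() && axis < inputs[0].shape.size());
  const std::vector<int64_t>& ref = inputs[0].shape;
  int64_t outer = 1, inner = 1;
  for (std::size_t d = 0; d < axis; ++d) outer *= ref[d];
  for (std::size_t d = axis + 1; d < ref.size(); ++d) inner *= ref[d];

  ConcatResult result;
  int64_t total = 0;
  for (const HostTensor& t : inputs) {
    result.per_input_length.push_back(t.shape[axis]);  // second output
    total += t.shape[axis];
  }
  result.output.shape = ref;
  result.output.shape[axis] = total;
  result.output.data.resize(static_cast<std::size_t>(outer * total * inner));

  int64_t offset = 0;  // position of the current input along the concat axis
  for (const HostTensor& t : inputs) {
    const int64_t len = t.shape[axis];
    for (int64_t o = 0; o < outer; ++o) {
      for (int64_t k = 0; k < len * inner; ++k) {
        result.output.data[(o * total + offset) * inner + k] = t.data[o * len * inner + k];
      }
    }
    offset += len;
  }
  return result;
}
```

Concatenating a 2x2 tensor and a 2x1 tensor along axis 1 with this sketch yields a 2x3 output and the lengths 2 and 1.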

Usage Examples

ONNX_OPERATOR_KERNEL_EX(ConcatTraining, kMSDomain, 1, kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .OutputMemoryType(OrtMemTypeCPUInput, 1)
        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
    ConcatTraining);
