Implementation:Microsoft Onnxruntime CUDA FlattenAndUnpad
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool for flattening and removing padding from tensors in the ONNX Runtime CUDA training framework.
Description
Implements the FlattenAndUnpad operator for CUDA. For an input of shape [D0, D1, D2, ..., Dk], the kernel views the tensor as [D0 × D1, D2, ..., Dk] by flattening the first two dimensions. It then gathers the rows listed in indices, which drops the padding rows. The input must have at least 2 dimensions and indices must be 1-D. Each index addresses a row of the flattened view, so valid values lie in [0, D0 × D1). The output shape is [num_indices, D2, ..., Dk], where num_indices is the length of indices.

The gather is performed by FlattenAndUnpadImpl. It uses fast_divmod for its index arithmetic, which divides by a precomputed multiplier and shift instead of a hardware integer divide. The kernel also emits the original first two dimensions [D0, D1] as unflatten_dims in CPU memory, so the inverse operation (PadAndUnflatten) can restore the padded shape. Supported element types T: int32_t, int64_t, MLFloat16, float, double, and BFloat16.
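The semantics can be checked against a small CPU reference. The C++ sketch below is illustrative only. FastDivmodSketch and FlattenAndUnpadReference are hypothetical names. The divider is assumed to follow the general multiply-high-and-shift technique behind fast_divmod; it does not reproduce ORT's source. Each loop iteration stands in for one CUDA thread that computes one output element.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a fast_divmod-style divider: division by a fixed
// divisor d in [1, INT32_MAX] via a precomputed multiplier and shift.
struct FastDivmodSketch {
  uint32_t d = 1;  // divisor
  uint32_t l = 0;  // ceil(log2(d))
  uint32_t m = 0;  // magic multiplier: floor(2^32 * (2^l - d) / d) + 1

  explicit FastDivmodSketch(int64_t divisor) {
    if (divisor < 1 || divisor > INT32_MAX) throw std::invalid_argument("divisor out of range");
    d = static_cast<uint32_t>(divisor);
    while ((uint64_t{1} << l) < d) ++l;
    const uint64_t one = 1;
    m = static_cast<uint32_t>(((one << 32) * ((one << l) - d)) / d + 1);
  }

  // Quotient n / d for 0 <= n <= INT32_MAX, computed without an integer divide.
  int32_t Div(int32_t n) const {
    assert(n >= 0);
    const uint64_t un = static_cast<uint32_t>(n);
    const uint64_t hi = (static_cast<uint64_t>(m) * un) >> 32;  // high word, like CUDA __umulhi
    return static_cast<int32_t>((hi + un) >> l);
  }

  void DivMod(int32_t n, int32_t& q, int32_t& r) const {
    q = Div(n);
    r = n - q * static_cast<int32_t>(d);
  }
};

// Hypothetical CPU reference for the operator's semantics.
// input: row-major data of shape [D0, D1, D2, ..., Dk]
// returns: row-major data of shape [indices.size(), D2, ..., Dk]
template <typename T>
std::vector<T> FlattenAndUnpadReference(const std::vector<T>& input,
                                        const std::vector<int64_t>& input_shape,
                                        const std::vector<int64_t>& indices) {
  if (input_shape.size() < 2) throw std::invalid_argument("input needs at least 2 dims");
  int64_t row_size = 1;  // D2 * ... * Dk, or 1 for a 2-D input
  for (size_t i = 2; i < input_shape.size(); ++i) row_size *= input_shape[i];
  const int64_t num_rows = input_shape[0] * input_shape[1];  // rows of the flattened view
  if (static_cast<int64_t>(input.size()) != num_rows * row_size) throw std::invalid_argument("size mismatch");
  if (row_size == 0 || indices.empty()) return {};

  const FastDivmodSketch row_div(row_size);
  const int64_t total = static_cast<int64_t>(indices.size()) * row_size;
  if (total > INT32_MAX) throw std::invalid_argument("sketch uses 32-bit element indices");
  std::vector<T> output(static_cast<size_t>(total));
  for (int64_t i = 0; i < total; ++i) {  // one iteration ~ one CUDA thread
    int32_t out_row = 0, offset = 0;
    row_div.DivMod(static_cast<int32_t>(i), out_row, offset);  // i = out_row * row_size + offset
    const int64_t src_row = indices[static_cast<size_t>(out_row)];
    if (src_row < 0 || src_row >= num_rows) throw std::out_of_range("index outside [0, D0*D1)");
    output[static_cast<size_t>(i)] = input[static_cast<size_t>(src_row * row_size + offset)];
  }
  return output;
}
```

Precomputing the multiplier once per launch turns each per-element division by the row size into a multiply, an add, and a shift. This matters on the GPU, where 32-bit integer division is considerably more expensive than multiplication.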
Usage
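Used during training to strip padding from batched sequences before further processing. For example, it can turn a [batch, seq_len, hidden] activation into [num_tokens, hidden], with one row per non-padding token. It is commonly used with attention mechanisms to avoid wasting computation on padding tokens.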
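In practice the indices usually come from a padding mask. The C++ sketch below assumes a [batch, seq_len] mask that holds 1 for real tokens and 0 for padding. NonPaddingIndices, GatherRows, and ScatterRows are hypothetical helpers. ScatterRows assumes that the inverse operation (PadAndUnflatten) writes each row back to its index in a zero-filled buffer whose leading dimensions come from unflatten_dims.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Flattened positions b * seq_len + s of non-padding tokens, in ascending order.
std::vector<int64_t> NonPaddingIndices(const std::vector<int>& mask, int64_t batch, int64_t seq_len) {
  if (static_cast<int64_t>(mask.size()) != batch * seq_len) throw std::invalid_argument("mask size mismatch");
  std::vector<int64_t> indices;
  for (int64_t flat = 0; flat < batch * seq_len; ++flat) {
    if (mask[static_cast<size_t>(flat)] != 0) indices.push_back(flat);
  }
  return indices;
}

// Unpad: keep only the listed rows of a [batch * seq_len, hidden] buffer.
std::vector<float> GatherRows(const std::vector<float>& padded, const std::vector<int64_t>& indices,
                              size_t hidden) {
  std::vector<float> out;
  out.reserve(indices.size() * hidden);
  for (int64_t row : indices) {
    const auto base = static_cast<std::ptrdiff_t>(static_cast<size_t>(row) * hidden);
    out.insert(out.end(), padded.begin() + base, padded.begin() + base + static_cast<std::ptrdiff_t>(hidden));
  }
  return out;
}

// Re-pad (assumed inverse): scatter rows into a zero-filled [dim0 * dim1, hidden]
// buffer, where dim0 and dim1 are the values reported in unflatten_dims.
std::vector<float> ScatterRows(const std::vector<float>& unpadded, const std::vector<int64_t>& indices,
                               int64_t dim0, int64_t dim1, size_t hidden) {
  std::vector<float> padded(static_cast<size_t>(dim0 * dim1) * hidden, 0.0f);
  for (size_t r = 0; r < indices.size(); ++r) {
    for (size_t h = 0; h < hidden; ++h) {
      padded[static_cast<size_t>(indices[r]) * hidden + h] = unpadded[r * hidden + h];
    }
  }
  return padded;
}
```

With the mask {1, 1, 0, 1, 0, 0} (batch 2, seq_len 3), the indices are {0, 1, 3}. Only 3 of the 6 token rows survive, so downstream layers process half as many rows, and scattering them back leaves zeros in the padding positions.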
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc
- Lines: 1-91
Signature
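```cpp
class FlattenAndUnpad : public CudaKernel {
 public:
  FlattenAndUnpad(const OpKernelInfo& info) : CudaKernel(info) {}
  // Validates shapes, launches the CUDA gather, and writes unflatten_dims.
  Status ComputeInternal(OpKernelContext* context) const override;
};
```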
Import
```cpp
#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input | Tensor(T) | Yes | Input tensor of shape [D0, D1, D2, ..., Dk]; must have at least 2 dimensions |
| indices | Tensor(int64_t) | Yes | 1-D tensor of row indices into the flattened [D0 × D1, D2, ..., Dk] view; values in [0, D0 × D1) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | Tensor(T) | Unpadded output of shape [num_indices, D2, ..., Dk] |
| unflatten_dims | Tensor(int64_t) | Original first two dimensions [D0, D1], placed in CPU memory |
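The shape rules in the tables above can be written as a small shape function. The C++ sketch below is hypothetical, and InferFlattenAndUnpadShapes is not an ORT API. It enforces the same preconditions as the description: at least 2 input dimensions and 1-D indices.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Shapes produced by the operator for a given input shape and indices shape.
struct FlattenAndUnpadShapes {
  std::vector<int64_t> output;          // [num_indices, D2, ..., Dk]
  std::vector<int64_t> unflatten_dims;  // [D0, D1]
};

// Hypothetical shape function mirroring the I/O contract; not an ORT API.
FlattenAndUnpadShapes InferFlattenAndUnpadShapes(const std::vector<int64_t>& input_shape,
                                                  const std::vector<int64_t>& indices_shape) {
  if (input_shape.size() < 2) throw std::invalid_argument("input must have at least 2 dimensions");
  if (indices_shape.size() != 1) throw std::invalid_argument("indices must be 1-D");
  FlattenAndUnpadShapes shapes;
  shapes.output.push_back(indices_shape[0]);  // num_indices
  shapes.output.insert(shapes.output.end(), input_shape.begin() + 2, input_shape.end());  // D2..Dk
  shapes.unflatten_dims = {input_shape[0], input_shape[1]};
  return shapes;
}
```

Consider a batch of 4 sequences of length 128 with hidden size 768, containing 300 real tokens. The input is then [4, 128, 768] and indices is [300], which yields an output of [300, 768] and unflatten_dims = [4, 128].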
Usage Examples
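The kernel is registered for the CUDA execution provider in the Microsoft domain (kMSDomain), opset version 1. Output 1 (unflatten_dims) is placed in CPU memory through OutputMemoryType:

```cpp
ONNX_OPERATOR_KERNEL_EX(FlattenAndUnpad, kMSDomain, 1, kCudaExecutionProvider,
                        (*KernelDefBuilder::Create())
                            // Element types accepted for input/output T.
                            .TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, MLFloat16, float, double, BFloat16>())
                            // Type of the indices input.
                            .TypeConstraint("T_INT", DataTypeImpl::GetTensorType<int64_t>())
                            // Output 1 (unflatten_dims) lives in CPU memory.
                            .OutputMemoryType(OrtMemTypeCPUOutput, 1),
                        FlattenAndUnpad);
```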