Implementation:Microsoft Onnxruntime CUDA PadAndUnflatten
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool for padding and unflattening tensors in the ONNX Runtime CUDA training framework.
Description
Implements the PadAndUnflatten operator for CUDA, the inverse of FlattenAndUnpad. The inputs are a flattened tensor of shape [N, D1, ..., Dk], a 1-D index tensor of length N, and the original 2-D unflatten dimensions [dim0, dim1]. From these, the operator creates a zero-initialized output tensor with the original padded shape [dim0, dim1, D1, ..., Dk]. It then scatters input row i to position indices[i] of the flattened leading dimensions, which contain dim0 * dim1 positions. Positions not named by any index stay zero, which restores the padding. dim0 and dim1 come from unflatten_dims, which is a CPU input. D1..Dk are the trailing dimensions of the input. The implementation first zeros the output buffer with cudaMemsetAsync. It then calls PadAndUnflattenImpl, which uses fast_divmod to split each input element's linear offset into a row (selecting the entry of indices) and an offset within that row, avoiding costly integer division on the GPU. Supported element types: MLFloat16, float, double, and BFloat16.
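A CPU reference makes the scatter semantics concrete. The sketch below illustrates the semantics under some assumptions and is not the ONNX Runtime implementation. `PadAndUnflattenReference` is a hypothetical name. The trailing dimensions D1..Dk are collapsed into a single `row_size`. Plain `/` and `%` stand in for the device-side fast_divmod, and each loop iteration plays the role of one CUDA thread.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical CPU reference for the PadAndUnflatten scatter (not ORT code).
// input:    N rows of row_size elements each (row_size = D1 * ... * Dk).
// indices:  N entries, each expected in [0, dim0 * dim1).
// Returns a buffer of dim0 * dim1 * row_size elements, i.e. [dim0, dim1, D1..Dk].
template <typename T>
std::vector<T> PadAndUnflattenReference(const std::vector<T>& input,
                                        const std::vector<int64_t>& indices,
                                        int64_t dim0, int64_t dim1,
                                        int64_t row_size) {
  const int64_t n = static_cast<int64_t>(indices.size());
  if (static_cast<int64_t>(input.size()) != n * row_size) {
    throw std::invalid_argument("input size must equal indices.size() * row_size");
  }
  const int64_t upper_bound = dim0 * dim1;

  // Step 1: zero-initialize the whole padded output (cudaMemsetAsync on the GPU).
  std::vector<T> output(static_cast<std::size_t>(upper_bound * row_size), T{});

  // Step 2: one iteration per input element. A divmod by row_size yields the
  // input row (which selects indices[row]) and the offset within that row.
  for (int64_t e = 0; e < n * row_size; ++e) {
    const int64_t row = e / row_size;
    const int64_t col = e % row_size;
    const int64_t target = indices[static_cast<std::size_t>(row)];
    if (target < 0 || target >= upper_bound) {
      throw std::out_of_range("index outside [0, dim0 * dim1)");
    }
    output[static_cast<std::size_t>(target * row_size + col)] =
        input[static_cast<std::size_t>(e)];
  }
  return output;
}
```

Take dim0 = 2, dim1 = 3, and two rows of width 2 scattered to indices {0, 4}. The 12-element output holds the first row at offsets 0 to 1 and the second at offsets 8 to 9. Every other element is zero.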
Usage
Used during the training backward pass to restore the padding of sequences that were unpadded in the forward pass, typically in attention mechanisms.
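In a padded-sequence setting, one plausible source of `indices` is the list of non-padding positions in a [batch, seq_len] grid. The sketch below assumes that setup and several other details. `NonPaddedIndices` and `GatherRows` are hypothetical helpers, not ONNX Runtime functions. A nonzero mask entry is assumed to mark a real token. `GatherRows` models the forward unpad step on the assumption that FlattenAndUnpad gathers rows at the same indices. PadAndUnflatten is the later step that writes those rows back and leaves the masked-out slots zero.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of ONNX Runtime): builds a 1-D indices tensor
// from a row-major [batch, seq_len] mask, where nonzero marks a real token.
// Each index addresses the flattened [batch * seq_len] grid.
std::vector<int64_t> NonPaddedIndices(const std::vector<int>& mask) {
  std::vector<int64_t> indices;
  for (std::size_t i = 0; i < mask.size(); ++i) {
    if (mask[i] != 0) indices.push_back(static_cast<int64_t>(i));
  }
  return indices;
}

// Hypothetical model of the forward unpad step (assumed FlattenAndUnpad
// semantics): copies the rows selected by indices out of a
// [batch * seq_len, row_size] buffer into a dense [N, row_size] buffer.
template <typename T>
std::vector<T> GatherRows(const std::vector<T>& padded,
                          const std::vector<int64_t>& indices,
                          int64_t row_size) {
  std::vector<T> dense;
  dense.reserve(indices.size() * static_cast<std::size_t>(row_size));
  for (int64_t idx : indices) {
    for (int64_t c = 0; c < row_size; ++c) {
      dense.push_back(padded[static_cast<std::size_t>(idx * row_size + c)]);
    }
  }
  return dense;
}
```

For a 2 × 3 mask {1, 1, 0, 1, 0, 0}, the indices are {0, 1, 3}. These cover positions 0 and 1 of the first sequence and position 0 of the second. Passing these indices, the dense rows, and unflatten_dims = [2, 3] to PadAndUnflatten would place the rows back at flat slots 0, 1, and 3.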
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc
- Lines: 1-92
Signature
```cpp
class PadAndUnflatten : public CudaKernel {
 public:
  PadAndUnflatten(const OpKernelInfo& info) : CudaKernel(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};
```
Import
```cpp
#include "orttraining/training_ops/cuda/tensor/pad_and_unflatten.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
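| input | Tensor(T) | Yes | Flattened input tensor of shape [N, D1, ..., Dk] |
| indices | Tensor(int64_t) | Yes | 1-D tensor of length N; entry i is the position in the flattened [dim0 * dim1] leading dimensions that receives input row i |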
| unflatten_dims | Tensor(int64_t) | Yes | 2-element tensor with [dim0, dim1] (CPU memory) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | Tensor(T) | Padded and unflattened output with shape [dim0, dim1, D1, ..., Dk]; positions not covered by indices are zero |
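The output shape depends only on the input's trailing dimensions and the two values in unflatten_dims. It can therefore be computed on the host before any device work, which is consistent with the kernel registration keeping input 2 in CPU memory (`OrtMemTypeCPUInput`). The helper `PadAndUnflattenOutputShape` is hypothetical. Its checks are assumptions drawn from the I/O contract, not ONNX Runtime's exact validation.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical shape helper following the I/O contract (not ORT code).
// input_shape:    [N, D1, ..., Dk]
// unflatten_dims: [dim0, dim1], read on the host
// indices_length: number of entries in the 1-D indices tensor
// Returns the output shape [dim0, dim1, D1, ..., Dk].
std::vector<int64_t> PadAndUnflattenOutputShape(const std::vector<int64_t>& input_shape,
                                                const std::vector<int64_t>& unflatten_dims,
                                                int64_t indices_length) {
  if (input_shape.empty()) {
    throw std::invalid_argument("input must have rank >= 1");
  }
  if (unflatten_dims.size() != 2) {
    throw std::invalid_argument("unflatten_dims must hold exactly two values");
  }
  // One index per input row.
  if (indices_length != input_shape[0]) {
    throw std::invalid_argument("indices length must match input dimension 0");
  }
  // Leading [dim0, dim1], followed by the input's trailing dimensions D1..Dk.
  std::vector<int64_t> output_shape{unflatten_dims[0], unflatten_dims[1]};
  output_shape.insert(output_shape.end(), input_shape.begin() + 1, input_shape.end());
  return output_shape;
}
```

For example, an input of shape [5, 4, 8] with unflatten_dims = [2, 3] and five indices gives an output of shape [2, 3, 4, 8].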
Usage Examples
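```cpp
// Register the CUDA kernel for PadAndUnflatten in the Microsoft domain, version 1.
ONNX_OPERATOR_KERNEL_EX(PadAndUnflatten, kMSDomain, 1, kCudaExecutionProvider,
                        (*KernelDefBuilder::Create())
                            // Element types accepted for the data input and output.
                            .TypeConstraint("T", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
                            // Both index-like inputs use int64.
                            .TypeConstraint("T_INT", DataTypeImpl::GetTensorType<int64_t>())
                            .TypeConstraint("T_INDEX", DataTypeImpl::GetTensorType<int64_t>())
                            // Input 2 (unflatten_dims) stays in CPU memory so the host can read it.
                            .InputMemoryType(OrtMemTypeCPUInput, 2),
                        PadAndUnflatten);
```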