Implementation: Microsoft Onnxruntime CPU Dropout7
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CPU kernel implementing ONNX opset 7 Dropout with training support in the ONNX Runtime training framework.
Description
This file implements the TrainableDropout kernel for ONNX opset 7 compatibility in training mode. Unlike the inference-mode Dropout, which simply passes data through, this kernel applies random dropout during training by generating a boolean mask from a uniform random distribution. Each element is independently dropped with probability `ratio` (the complement of the keep probability). The kernel uses the PhiloxGenerator for reproducible random number generation. Elements that survive dropout are scaled by 1/(1-ratio) so the expected value of the output matches the input. The mask is emitted as a second output so the gradient computation can reuse it.
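The forward computation described above can be sketched in standalone C++. This is an illustrative stand-in, not the ORT kernel itself: `std::mt19937` replaces ORT's PhiloxGenerator, plain `std::vector`s replace `Tensor`, and the function and struct names are hypothetical.

```cpp
// Sketch of opset-7 training dropout: generate a boolean keep-mask from a
// uniform distribution, zero dropped elements, and scale survivors by
// 1/(1-ratio) to preserve the expected value.
#include <cassert>
#include <cstdint>
#include <random>
#include <vector>

struct DropoutResult {
  std::vector<float> output;
  std::vector<bool> mask;  // true where the element survived
};

DropoutResult trainable_dropout_forward(const std::vector<float>& data,
                                        float ratio, uint64_t seed) {
  std::mt19937 gen(seed);  // ORT uses PhiloxGenerator for reproducibility
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  const float scale = 1.0f / (1.0f - ratio);  // keeps E[output] == E[data]
  DropoutResult r{std::vector<float>(data.size()),
                  std::vector<bool>(data.size())};
  for (size_t i = 0; i < data.size(); ++i) {
    const bool keep = dist(gen) >= ratio;  // drop with probability `ratio`
    r.mask[i] = keep;
    r.output[i] = keep ? data[i] * scale : 0.0f;
  }
  return r;
}
```

Note that surviving elements are scaled at training time (inverted dropout), so no rescaling is needed at inference.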
Usage
This kernel is invoked when opset 7 Dropout nodes appear in a training graph. It generates a dropout mask during the forward pass, which the corresponding DropoutGrad kernel uses during backpropagation to zero out the same elements.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/nn/dropout_7.cc
- Lines: 1-57
Signature
template <typename T>
Status TrainableDropout<T>::Compute(OpKernelContext* context) const;
Import
#include "orttraining/orttraining/training_ops/cpu/nn/dropout_7.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| data | Tensor(float) | Yes | Input tensor to apply dropout on |
Outputs
| Name | Type | Description |
|---|---|---|
| output | Tensor(float) | Output with dropout applied (scaled by 1/(1-ratio)) |
| mask | Tensor(bool) | Boolean mask indicating which elements survived |
Usage Examples
ONNX_OPERATOR_TYPED_KERNEL_EX(
TrainableDropout, kOnnxDomain, 7, float, kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
TrainableDropout<float>);