Overview
CPU kernels for computing the cross entropy loss and its gradient in the ONNX Runtime training framework.
Description
This file implements four cross entropy loss kernels for CPU training: SoftmaxCrossEntropy, SoftmaxCrossEntropyGrad, SparseSoftmaxCrossEntropy, and SparseSoftmaxCrossEntropyGrad. The shared helper ComputeShareSoftmaxCrossEntropyCPU computes numerically stable log-probabilities with the log-sum-exp trick (subtracting the row-wise maximum before exponentiation). The forward kernels fuse softmax and cross entropy into a single pass and support both dense (one-hot) labels and sparse (integer index) labels. The gradient kernels compute the backpropagated gradient as prob - label (dense) or exp(log_prob) - one_hot(label) (sparse), scaled by the upstream loss gradient. Both MEAN and SUM reduction modes are supported, and the sparse variants optionally accept per-sample weights.
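The log-sum-exp trick mentioned above can be sketched in plain C++ as follows. This is an illustrative standalone function over `std::vector` buffers, not the actual ORT helper (which writes into Eigen-backed tensors); the function name and signature are assumptions for demonstration only.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable log-softmax over each row of an [n, c] matrix.
// Subtracting the row-wise maximum before exponentiation keeps std::exp
// from overflowing for large logits, which is the log-sum-exp trick the
// shared helper uses.
std::vector<float> LogSoftmaxRows(const std::vector<float>& logits, int n, int c) {
  std::vector<float> log_prob(logits.size());
  for (int i = 0; i < n; ++i) {
    const float* row = &logits[i * c];
    // Shift by the row maximum so the largest exponent argument is 0.
    const float max_val = *std::max_element(row, row + c);
    float sum_exp = 0.0f;
    for (int j = 0; j < c; ++j) sum_exp += std::exp(row[j] - max_val);
    const float log_sum = std::log(sum_exp);
    // log_softmax(x_j) = x_j - max - log(sum_k exp(x_k - max))
    for (int j = 0; j < c; ++j) log_prob[i * c + j] = row[j] - max_val - log_sum;
  }
  return log_prob;
}
```

Each output row sums to 1 after exponentiation, so `exp(log_prob)` recovers the softmax probabilities the gradient kernels need.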
Usage
These kernels are invoked during training when the computation graph contains SoftmaxCrossEntropy or SparseSoftmaxCrossEntropy nodes. The forward kernels compute the batch loss (and the log-probabilities, which are reused downstream), and the gradient kernels propagate the loss gradient back through the fused softmax in the backward pass.
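The backward computation for the sparse variant, exp(log_prob) - one_hot(label) scaled by the upstream gradient, can be sketched as below. This is a hypothetical standalone function for illustration, not the ORT kernel interface; the MEAN-reduction scaling by 1/n is an assumption drawn from the reduction modes described above.

```cpp
#include <cmath>
#include <vector>

// Backward-pass sketch for the sparse variant:
//   d_logit[i, j] = (exp(log_prob[i, j]) - one_hot(label[i])[j]) * scale
// where scale folds in the upstream scalar gradient dY, divided by n
// when MEAN reduction was used in the forward pass.
std::vector<float> SparseSoftmaxCrossEntropyGradSketch(
    const std::vector<float>& log_prob,   // [n, c] from the forward pass
    const std::vector<long long>& label,  // [n] integer class indices
    int n, int c, float dY, bool mean_reduction) {
  const float scale = mean_reduction ? dY / static_cast<float>(n) : dY;
  std::vector<float> d_logit(log_prob.size());
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < c; ++j) {
      const float one_hot = (label[i] == j) ? 1.0f : 0.0f;
      d_logit[i * c + j] = (std::exp(log_prob[i * c + j]) - one_hot) * scale;
    }
  }
  return d_logit;
}
```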
Code Reference
Source Location
Signature
template <typename T>
void ComputeShareSoftmaxCrossEntropyCPU(const int nd, const int c,
const Eigen::Index nd_c,
const T* logit_data,
T* shifted_logit,
T* log_prob_data);
template <typename T>
Status SoftmaxCrossEntropy<T>::Compute(OpKernelContext* context) const;
template <typename T>
Status SoftmaxCrossEntropyGrad<T>::Compute(OpKernelContext* context) const;
template <typename T>
Status SparseSoftmaxCrossEntropy<T>::Compute(OpKernelContext* context) const;
template <typename T>
Status SparseSoftmaxCrossEntropyGrad<T>::Compute(OpKernelContext* context) const;
Import
#include "orttraining/orttraining/training_ops/cpu/loss/cross_entropy.h"
I/O Contract
Inputs (SoftmaxCrossEntropy)

| Name | Type | Required | Description |
|------|------|----------|-------------|
| logit | Tensor(float) | Yes | Raw logit values with shape [N, D] |
| label | Tensor(float) | Yes | One-hot encoded labels with the same shape as logit |
Outputs (SoftmaxCrossEntropy)

| Name | Type | Description |
|------|------|-------------|
| loss | Tensor(float) | Scalar loss value (MEAN or SUM reduction) |
| log_prob | Tensor(float) | Log probabilities with the same shape as logit |
Inputs (SparseSoftmaxCrossEntropy)

| Name | Type | Required | Description |
|------|------|----------|-------------|
| logit | Tensor(float) | Yes | Raw logit values with shape [N, D] |
| label | Tensor(int64) | Yes | Sparse integer class labels with shape [N] |
| weight | Tensor(float) | No | Optional per-sample weights with shape [N] |
Outputs (SparseSoftmaxCrossEntropy)

| Name | Type | Description |
|------|------|-------------|
| loss | Tensor(float) | Scalar loss value |
| log_prob | Tensor(float) | Log probabilities with the same shape as logit |
Usage Examples
// Kernel registration for SoftmaxCrossEntropy (forward)
ONNX_OPERATOR_KERNEL_EX(
SoftmaxCrossEntropy,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
SoftmaxCrossEntropy<float>);
// Kernel registration for SparseSoftmaxCrossEntropy (forward)
ONNX_OPERATOR_KERNEL_EX(
SparseSoftmaxCrossEntropy,
kOnnxDomain,
9,
kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
SparseSoftmaxCrossEntropy<float>);