Implementation:Microsoft Onnxruntime CUDA SoftmaxGrad

From Leeroopedia


Knowledge Sources
Domains: Training, CUDA_Kernels
Last Updated: 2026-02-10 04:00 GMT

Overview

Concrete tool for computing softmax and log-softmax gradients in the ONNX Runtime CUDA training framework.
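For reference, these are the standard closed-form gradients along the softmax axis. Y is the forward output and dY is the upstream gradient, and indices i and j run over that axis. For LogSoftmax, Y holds log-probabilities, so e^{Y_i} recovers the softmax probability.

```latex
\begin{aligned}
\text{Softmax: }    & dX_i = Y_i \Big( dY_i - \sum_j dY_j \, Y_j \Big) \\
\text{LogSoftmax: } & dX_i = dY_i - e^{Y_i} \sum_j dY_j
\end{aligned}
```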

Description

Implements the SoftmaxGrad operator for CUDA, which computes gradients for both Softmax and LogSoftmax. Four kernel variants are registered: SoftmaxGrad, SoftmaxGrad_13 (opset 13+), LogSoftmaxGrad, and LogSoftmaxGrad_13 (opset 13+). The opset-13 variants support axis selection. When the softmax axis is not the innermost dimension, the input tensors (dY, Y) are transposed so that the axis becomes innermost, the gradient is computed, and the result (dX) is transposed back. The axis permutation swaps the target axis with the last dimension. The computation is dispatched through DispatchSoftmaxGradImpl to SoftmaxGradImpl, which uses cuDNN's softmax backward. Supported types are float, double, MLFloat16, and BFloat16.
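The axis handling described above can be sketched on the CPU as a minimal illustration, not ONNX Runtime code. `SwapWithLastPermutation`, `Strides`, and `Transpose` are hypothetical helpers. They operate on a contiguous row-major `std::vector<float>`. Swapping two axes is its own inverse, so the same permutation moves the softmax axis to the innermost position and later moves dX back.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical helper (not ONNX Runtime code): permutation that swaps `axis`
// with the last dimension and leaves every other dimension in place.
std::vector<size_t> SwapWithLastPermutation(size_t rank, size_t axis) {
  assert(axis < rank);
  std::vector<size_t> perm(rank);
  std::iota(perm.begin(), perm.end(), size_t{0});  // identity: 0, 1, ..., rank-1
  perm[axis] = rank - 1;
  perm[rank - 1] = axis;
  return perm;
}

// Row-major strides for a given shape (the last dimension has stride 1).
std::vector<size_t> Strides(const std::vector<size_t>& shape) {
  std::vector<size_t> strides(shape.size(), 1);
  for (size_t i = shape.size(); i-- > 1;) {
    strides[i - 1] = strides[i] * shape[i];
  }
  return strides;
}

// Generic transpose: output dimension i is input dimension perm[i].
std::vector<float> Transpose(const std::vector<float>& in,
                             const std::vector<size_t>& shape,
                             const std::vector<size_t>& perm,
                             std::vector<size_t>* out_shape) {
  const size_t rank = shape.size();
  out_shape->resize(rank);
  for (size_t i = 0; i < rank; ++i) (*out_shape)[i] = shape[perm[i]];
  const std::vector<size_t> in_strides = Strides(shape);
  const std::vector<size_t> out_strides = Strides(*out_shape);
  std::vector<float> out(in.size());
  for (size_t flat = 0; flat < out.size(); ++flat) {
    // Decode the output coordinates and map each one to input dimension perm[i].
    size_t rem = flat;
    size_t src = 0;
    for (size_t i = 0; i < rank; ++i) {
      const size_t coord = rem / out_strides[i];
      rem %= out_strides[i];
      src += coord * in_strides[perm[i]];
    }
    out[flat] = in[src];
  }
  return out;
}
```

For a [2, 3, 4] tensor with axis 0, the permutation is {2, 1, 0} and the transposed shape is [4, 3, 2]. Applying the same permutation a second time restores the original shape and data.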

Usage

Invoked during the backward pass when the model uses Softmax or LogSoftmax layers.

Code Reference

Source Location

Signature

class SoftmaxGrad : public CudaKernel {
 public:
  // Reads dY and Y from ctx and writes dX (see I/O Contract).
  Status ComputeInternal(OpKernelContext* ctx) const override;
};

Import

#include "orttraining/training_ops/cuda/math/softmax_grad.h"

I/O Contract

Inputs

Name Type Required Description
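dY Tensor(T) Yes Upstream gradient (gradient of the loss with respect to Y)
Y Tensor(T) Yes Output of the forward pass (Softmax output, or LogSoftmax output for the LogSoftmaxGrad variants)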

Outputs

Name Type Description
dX Tensor(T) Gradient with respect to the Softmax (or LogSoftmax) input
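The I/O contract can be checked against a plain CPU reference. `SoftmaxGradReference` below is a hypothetical helper, not part of ONNX Runtime. It assumes contiguous row-major [rows, cols] buffers with the reduction axis innermost and applies the closed-form Softmax and LogSoftmax gradients row by row.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference (not ONNX Runtime code). dY and Y are row-major
// [rows, cols] with the softmax axis innermost. When log_softmax is true, Y
// holds LogSoftmax outputs (log-probabilities).
std::vector<float> SoftmaxGradReference(const std::vector<float>& dY,
                                        const std::vector<float>& Y,
                                        size_t rows, size_t cols,
                                        bool log_softmax) {
  assert(dY.size() == rows * cols && Y.size() == rows * cols);
  std::vector<float> dX(rows * cols);
  for (size_t r = 0; r < rows; ++r) {
    const size_t base = r * cols;
    // Per-row reduction, accumulated in double:
    // sum(dY * Y) for Softmax, sum(dY) for LogSoftmax.
    double sum = 0.0;
    for (size_t c = 0; c < cols; ++c) {
      sum += log_softmax ? dY[base + c] : dY[base + c] * Y[base + c];
    }
    for (size_t c = 0; c < cols; ++c) {
      const size_t i = base + c;
      dX[i] = log_softmax
                  ? static_cast<float>(dY[i] - std::exp(Y[i]) * sum)  // dY - softmax * sum(dY)
                  : static_cast<float>(Y[i] * (dY[i] - sum));         // Y * (dY - sum(dY * Y))
    }
  }
  return dX;
}
```

With Y = {0.5, 0.5} and dY = {1, 0}, the Softmax branch returns dX = {0.25, -0.25}. Each row of a Softmax gradient sums to zero because the entries of Y in that row sum to one.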

Usage Examples

// All four variants registered with same kernel class
REGISTER_SOFTMAX_GRAD_KERNEL(SoftmaxGrad)
REGISTER_SOFTMAX_GRAD_KERNEL(SoftmaxGrad_13)
REGISTER_SOFTMAX_GRAD_KERNEL(LogSoftmaxGrad)
REGISTER_SOFTMAX_GRAD_KERNEL(LogSoftmaxGrad_13)
// Supports float, double, MLFloat16, BFloat16
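One way a single kernel class could serve all four registrations is to derive two flags from the node's op type when the kernel is constructed. This sketch is hypothetical. `SoftmaxGradVariant` and `ClassifyOpType` are illustrative names, not ONNX Runtime APIs, and the string matching is an assumption about how the variant might be identified.

```cpp
#include <cassert>
#include <string>

// Hypothetical flags (illustrative names, not ONNX Runtime's).
struct SoftmaxGradVariant {
  bool is_log_softmax;     // LogSoftmaxGrad or LogSoftmaxGrad_13
  bool is_since_opset_13;  // *_13 variants: transpose when the axis is not innermost
};

// Hypothetical: classify a registered op type by prefix and suffix.
SoftmaxGradVariant ClassifyOpType(const std::string& op_type) {
  assert(!op_type.empty());
  SoftmaxGradVariant v{};
  v.is_log_softmax = op_type.rfind("LogSoftmaxGrad", 0) == 0;  // prefix match
  const std::string suffix = "_13";
  v.is_since_opset_13 =
      op_type.size() >= suffix.size() &&
      op_type.compare(op_type.size() - suffix.size(), suffix.size(), suffix) == 0;
  return v;
}
```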
