Implementation:Microsoft Onnxruntime CPU GRU Grad
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CPU kernel that computes GRU gradients during the backward pass in the ONNX Runtime training framework.
Description
This file implements the GRUGrad kernel, which orchestrates the backward pass for GRU training. It parses gradient inputs using gru::GRUGradInputs, allocates gradient outputs via gru::GRUGradOutputs, and delegates the actual gradient computation to gru::GRUGradImpl. The kernel is registered under kMSDomain with opset version 1 and supports float type only.
Usage
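This kernel is invoked during the backward pass of GRU training. It consumes the gate activations (ZRH) and the hidden states (HAll) produced by the GRUTraining forward kernel, together with the incoming gradients grad_HAll and grad_Ht, and computes the gradients with respect to the input, the weights, the recurrence weights, the bias, and the initial hidden state.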
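
To make the role of the saved gate activations more concrete, the sketch below applies the chain rule to the final blending step of a single GRU time step, H_t = (1 - z) * h + z * H_prev, as defined by the ONNX GRU operator. It is an illustration under simplifying assumptions, not the code in gru::GRUGradImpl: the names `GateGrads`, `BlendForward`, and `BlendBackward` are hypothetical, and the propagation through W, R, and the reset gate is left out.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical container for the element-wise gradients of one time step.
struct GateGrads {
  std::vector<float> dz_pre;   // w.r.t. the update gate's pre-activation
  std::vector<float> dh_pre;   // w.r.t. the candidate state's pre-activation
  std::vector<float> dH_prev;  // direct path to H_prev only (no R terms)
};

// Forward blend of one element: H_t = (1 - z) * h + z * H_prev.
float BlendForward(float z, float h, float H_prev) {
  return (1.0f - z) * h + z * H_prev;
}

// Backward pass through the blend, given dH = dLoss/dH_t.
// z is the sigmoid output of the update gate, h the tanh output of the
// candidate state; both are assumed to come from the saved activations.
GateGrads BlendBackward(const std::vector<float>& z,
                        const std::vector<float>& h,
                        const std::vector<float>& H_prev,
                        const std::vector<float>& dH) {
  const std::size_t n = z.size();
  GateGrads g{std::vector<float>(n), std::vector<float>(n),
              std::vector<float>(n)};
  for (std::size_t i = 0; i < n; ++i) {
    const float dz = dH[i] * (H_prev[i] - h[i]);  // dH_t/dz = H_prev - h
    const float dh = dH[i] * (1.0f - z[i]);       // dH_t/dh = 1 - z
    g.dz_pre[i] = dz * z[i] * (1.0f - z[i]);      // sigmoid derivative
    g.dh_pre[i] = dh * (1.0f - h[i] * h[i]);      // tanh derivative
    g.dH_prev[i] = dH[i] * z[i];                  // dH_t/dH_prev = z
  }
  return g;
}
```

Because the derivatives of sigmoid and tanh can be written in terms of their outputs, the backward pass needs only the stored activations, which is one reason a forward kernel would save them rather than have the backward kernel recompute them.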
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/rnn/gru_grad.cc
- Lines: 1-46
Signature
template <typename T>
Status GRUGrad<T>::Compute(OpKernelContext* context) const;
Import
#include "orttraining/orttraining/training_ops/cpu/rnn/gru_grad.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| X | Tensor(float) | Yes | Input sequence [seq_length, batch_size, input_size] |
| W | Tensor(float) | Yes | Weights [directions, 3*H, input_size], where H is the hidden size |
| R | Tensor(float) | Yes | Recurrence weights [directions, 3*H, H] |
| B | Tensor(float) | No | Bias [directions, 6*H] |
| SL | Tensor(int32) | No | Sequence lengths [batch_size] |
| H0 | Tensor(float) | No | Initial hidden state |
| HAll | Tensor(float) | Yes | All hidden states from forward |
| ZRH | Tensor(float) | Yes | Gate activations from forward |
| grad_HAll | Tensor(float) | No | Gradient w.r.t. all hidden states |
| grad_Ht | Tensor(float) | No | Gradient w.r.t. final hidden state |
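
The shapes listed above for X, W, R, and B can be derived from five dimensions. The following standalone sketch shows one way to compute them, for example to size test buffers. `GruDims` and the `Shape*` helpers are hypothetical names introduced for illustration and are not part of ONNX Runtime.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical bundle of the dimensions used in the Inputs table.
struct GruDims {
  int64_t seq_length;
  int64_t batch_size;
  int64_t input_size;
  int64_t hidden_size;  // "H" in the table
  int64_t directions;
};

// X: [seq_length, batch_size, input_size]
std::array<int64_t, 3> ShapeX(const GruDims& d) {
  return {d.seq_length, d.batch_size, d.input_size};
}

// W: [directions, 3*H, input_size]
std::array<int64_t, 3> ShapeW(const GruDims& d) {
  return {d.directions, 3 * d.hidden_size, d.input_size};
}

// R: [directions, 3*H, H]
std::array<int64_t, 3> ShapeR(const GruDims& d) {
  return {d.directions, 3 * d.hidden_size, d.hidden_size};
}

// B: [directions, 6*H]
std::array<int64_t, 2> ShapeB(const GruDims& d) {
  return {d.directions, 6 * d.hidden_size};
}

// Total number of elements in a shape.
template <std::size_t N>
int64_t NumElements(const std::array<int64_t, N>& shape) {
  int64_t count = 1;
  for (int64_t dim : shape) count *= dim;
  return count;
}
```

Since a gradient has the same shape as the tensor it is taken with respect to, the same helpers also give the sizes of dX, dW, dR, and dB.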
Outputs
| Name | Type | Description |
|---|---|---|
| dX | Tensor(float) | Gradient w.r.t. input X |
| dW | Tensor(float) | Gradient w.r.t. weights W |
| dR | Tensor(float) | Gradient w.r.t. recurrence weights R |
| dB | Tensor(float) | Gradient w.r.t. bias |
| dH0 | Tensor(float) | Gradient w.r.t. initial hidden state |
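
The ONNX GRU operator defines B as the concatenation [Wb[zrh], Rb[zrh]], which is why each direction holds 6*H values. Assuming dB follows the same layout as B, the sketch below computes where each gate's slice starts within one direction's row. The `Gate` enum and the two offset helpers are hypothetical and only illustrate the indexing.

```cpp
#include <cstdint>

// Gate order within each half of the bias row: update (z), reset (r),
// candidate hidden state (h).
enum class Gate : int64_t { Z = 0, R = 1, H = 2 };

// Start of the input-side bias slice Wb[gate] in a row of 6*H values.
int64_t InputBiasOffset(Gate gate, int64_t hidden_size) {
  return static_cast<int64_t>(gate) * hidden_size;
}

// Start of the recurrence-side bias slice Rb[gate]; the Rb half begins
// after the three Wb slices, at offset 3*H.
int64_t RecurrentBiasOffset(Gate gate, int64_t hidden_size) {
  return (3 + static_cast<int64_t>(gate)) * hidden_size;
}
```

With a hidden size of 4, for instance, the Wb slices start at 0, 4, and 8, and the Rb slices at 12, 16, and 20.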
Usage Examples
ONNX_OPERATOR_TYPED_KERNEL_EX(
    GRUGrad, kMSDomain, 1, float, kCpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    GRUGrad<float>);