Implementation: Microsoft Onnxruntime CPU GRU GradCompute
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CPU kernel that computes the backward pass (gradient backpropagation) of a GRU (Gated Recurrent Unit) layer in the ONNX Runtime training framework.
Description
This file implements the GRUGradImpl class template which performs the backward pass for a GRU cell. The gradient computation iterates in reverse through the time steps, computing gradients for the update gate (z), reset gate (r), and hidden state candidate (h) using the chain rule. It supports both the standard GRU formulation and the linear_before_reset variant. The implementation computes gradients for all learnable parameters: input weights (W), recurrence weights (R), biases (Wb, Rb), the input sequence (X), and the initial hidden state (H0). Matrix operations are performed using GEMM calls, and gradients are accumulated across batch elements and time steps using elementwise summation helpers from the deepcpu namespace.
Usage
This kernel is used during the backward pass of GRU training. It is called by the GRUGrad operator to compute parameter gradients after the forward pass has produced the intermediate gate activations (z, r, h) and all hidden states.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/rnn/gru_grad_compute.cc
- Lines: 1-443
Signature
template <typename T>
GRUGradImpl<T>::GRUGradImpl(int sequence_length, int batch_size,
                            int hidden_size, int input_size,
                            bool linear_before_reset,
                            concurrency::ThreadPool* thread_pool,
                            AllocatorPtr allocator);

template <typename T>
void GRUGradImpl<T>::ComputeGradient(const GRUGradInputs<T>& inputs,
                                     GRUGradOutputs<T>& outputs);
Import
#include "orttraining/orttraining/training_ops/cpu/rnn/gru_grad_compute.h"
I/O Contract
Inputs (GRUGradInputs)
| Name | Type | Required | Description |
|---|---|---|---|
| weights | span<const T> | Yes | Input weights W [3*H, I] |
| recurrence_weights | span<const T> | Yes | Recurrence weights R [3*H, H] |
| bias | span<const T> | No | Bias [6*H] |
| input | span<const T> | Yes | Input sequence [seq_len * batch * input_size] |
| zrh | span<const T> | Yes | Gate activations from forward pass [seq_len * batch * 3*H] |
| all_hidden_states | span<const T> | Yes | All hidden states from forward [seq_len * batch * H] |
| initial_hidden_state | span<const T> | Yes | Initial hidden state [batch * H] |
| grad_all_hidden_states | span<const T> | No | Gradient w.r.t. all hidden states |
| grad_final_hidden_state | span<const T> | No | Gradient w.r.t. final hidden state |
Outputs (GRUGradOutputs)
| Name | Type | Description |
|---|---|---|
| grad_input | span<T> | Gradient w.r.t. input X |
| grad_weights | span<T> | Gradient w.r.t. weights W |
| grad_recurrence_weights | span<T> | Gradient w.r.t. recurrence weights R |
| grad_bias | span<T> | Gradient w.r.t. bias |
| grad_initial_hidden_state | span<T> | Gradient w.r.t. initial hidden state H0 |
Usage Examples
// Create the gradient kernel for the given problem dimensions.
gru::GRUGradImpl<float> gru_cell(sequence_length, batch_size,
                                 hidden_size, input_size,
                                 linear_before_reset,
                                 thread_pool, allocator);

// Compute all parameter and input gradients in one call.
gru_cell.ComputeGradient(grugrad_inputs, grugrad_outputs);